# ray/rllib/models/distributions.py
"""This is the next version of action distribution base class."""
import abc
from typing import Optional, Tuple

import gymnasium as gym

from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.typing import TensorType, Union
@ExperimentalAPI
class Distribution(abc.ABC):
    """The base class for distribution over a random variable.

    Examples:

    .. testcode::

        import torch
        from ray.rllib.core.models.configs import MLPHeadConfig
        from ray.rllib.models.torch.torch_distributions import TorchCategorical

        model = MLPHeadConfig(input_dims=[1]).build(framework="torch")

        # Create an action distribution from model logits
        action_logits = model(torch.Tensor([[1]]))
        action_dist = TorchCategorical.from_logits(action_logits)
        action = action_dist.sample()

        # Create another distribution from a dummy Tensor
        action_dist2 = TorchCategorical.from_logits(torch.Tensor([0]))

        # Compute some common metrics
        logp = action_dist.logp(action)
        kl = action_dist.kl(action_dist2)
        entropy = action_dist.entropy()
    """

    @abc.abstractmethod
    def sample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs,
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a sample from the distribution.

        Args:
            sample_shape: The shape of the sample to draw.
            return_logp: Whether to return the logp of the sampled values.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled values. If return_logp is True, returns a tuple of the
            sampled values and its logp.
        """

    @abc.abstractmethod
    def rsample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs,
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a re-parameterized sample from the action distribution.

        If this method is implemented, we can take gradients of samples w.r.t. the
        distribution parameters.

        Args:
            sample_shape: The shape of the sample to draw.
            return_logp: Whether to return the logp of the sampled values.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled values. If return_logp is True, returns a tuple of the
            sampled values and its logp.
        """

    @abc.abstractmethod
    def logp(self, value: TensorType, **kwargs) -> TensorType:
        """The log-likelihood of the distribution computed at `value`

        Args:
            value: The value to compute the log-likelihood at.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The log-likelihood of the value.
        """

    @abc.abstractmethod
    def kl(self, other: "Distribution", **kwargs) -> TensorType:
        """The KL-divergence between two distributions.

        Args:
            other: The other distribution.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The KL-divergence between the two distributions.
        """

    @abc.abstractmethod
    def entropy(self, **kwargs) -> TensorType:
        """The entropy of the distribution.

        Args:
            **kwargs: Forward compatibility placeholder.

        Returns:
            The entropy of the distribution.
        """

    @staticmethod
    @abc.abstractmethod
    def required_input_dim(space: gym.Space, **kwargs) -> int:
        """Returns the required length of an input parameter tensor.

        Args:
            space: The space this distribution will be used for,
                whose shape attributes will be used to determine the required shape of
                the input parameter tensor.
            **kwargs: Forward compatibility placeholder.

        Returns:
            size of the required input vector (minus leading batch dimension).
        """

    @classmethod
    def from_logits(cls, logits: TensorType, **kwargs) -> "Distribution":
        """Creates a Distribution from logits.

        The caller does not need to have knowledge of the distribution class in order
        to create it and sample from it. The passed batched logits vectors might be
        split up and are passed to the distribution class' constructor as kwargs.

        Args:
            logits: The logits to create the distribution from.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The created distribution.

        .. testcode::

            import numpy as np
            from ray.rllib.models.distributions import Distribution

            class Uniform(Distribution):
                def __init__(self, lower, upper):
                    self.lower = lower
                    self.upper = upper

                def sample(self):
                    return self.lower + (self.upper - self.lower) * np.random.rand()

                def logp(self, x):
                    ...

                def kl(self, other):
                    ...

                def entropy(self):
                    ...

                @staticmethod
                def required_input_dim(space):
                    ...

                def rsample(self):
                    ...

                @classmethod
                def from_logits(cls, logits, **kwargs):
                    return Uniform(logits[:, 0], logits[:, 1])

            logits = np.array([[0.0, 1.0], [2.0, 3.0]])
            my_dist = Uniform.from_logits(logits)
            sample = my_dist.sample()
        """
        raise NotImplementedError

    @classmethod
    def get_partial_dist_cls(
        parent_cls: "Distribution", **partial_kwargs
    ) -> "Distribution":
        """Returns a partial subclass of this Distribution class.

        This is useful if inputs needed to instantiate the Distribution from logits
        are available, but the logits are not.

        Args:
            **partial_kwargs: Kwargs to pre-bind on the returned partial class.
                They are merged into (and must not clash with) the kwargs of later
                `from_logits` / `required_input_dim` calls on that class.

        Returns:
            A subclass of this Distribution class with `partial_kwargs` pre-bound.
        """

        class DistributionPartial(parent_cls):
            @staticmethod
            def _merge_kwargs(**kwargs):
                """Checks if keys in kwargs don't clash with partial_kwargs."""
                overlap = set(kwargs) & set(partial_kwargs)
                if overlap:
                    raise ValueError(
                        f"Cannot override the following kwargs: {overlap}.\n"
                        f"This is because they were already set at the time this "
                        f"partial class was defined."
                    )
                merged_kwargs = {**partial_kwargs, **kwargs}
                return merged_kwargs

            @classmethod
            @override(parent_cls)
            def required_input_dim(cls, space: gym.Space, **kwargs) -> int:
                merged_kwargs = cls._merge_kwargs(**kwargs)
                assert space == merged_kwargs["space"]
                return parent_cls.required_input_dim(**merged_kwargs)

            @classmethod
            @override(parent_cls)
            def from_logits(
                cls,
                logits: TensorType,
                **kwargs,
            ) -> "DistributionPartial":
                merged_kwargs = cls._merge_kwargs(**kwargs)
                distribution = parent_cls.from_logits(logits, **merged_kwargs)
                # Replace the class of the returned distribution with this partial
                # This makes it so that we can use type() on this distribution and
                # get back the partial class.
                distribution.__class__ = cls
                return distribution

        # Substitute name of this partial class to match the original class.
        # Use `parent_cls.__name__` (not the class object itself) so we get e.g.
        # "TorchCategoricalPartial" rather than "<class '...'>Partial".
        DistributionPartial.__name__ = f"{parent_cls.__name__}Partial"
        return DistributionPartial

    def to_deterministic(self) -> "Distribution":
        """Returns a deterministic equivalent for this distribution.

        Specifically, the deterministic equivalent for a Categorical distribution is a
        Deterministic distribution that selects the action with maximum logit value.
        Generally, the choice of the deterministic replacement is informed by
        established conventions.
        """
        raise NotImplementedError