# transformed_distribution.py
from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.transforms import Transform
from torch.distributions.utils import _sum_rightmost


class TransformedDistribution(Distribution):
r"""
Extension of the Distribution class, which applies a sequence of Transforms
to a base distribution. Let f be the composition of transforms applied::
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum event shape of its base distribution and its transforms, since
    transforms can introduce correlations among events.
An example for the usage of :class:`TransformedDistribution` would be::
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
        # Y = f(X) ~ Logistic(a, b)
        a, b = 0.0, 1.0  # location and scale of the resulting Logistic
        base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)
For more examples, please look at the implementations of
:class:`~torch.distributions.gumbel.Gumbel`,
:class:`~torch.distributions.half_cauchy.HalfCauchy`,
:class:`~torch.distributions.half_normal.HalfNormal`,
:class:`~torch.distributions.log_normal.LogNormal`,
:class:`~torch.distributions.pareto.Pareto`,
:class:`~torch.distributions.weibull.Weibull`,
:class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
:class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
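
    A further example (a minimal sketch; it assumes ``Normal`` and
    ``ExpTransform`` are imported from :mod:`torch.distributions`, and the
    result is a standard log-normal)::

        base = Normal(torch.zeros(3), torch.ones(3))
        log_normal = TransformedDistribution(base, [ExpTransform()])
        y = log_normal.sample()
        log_normal.log_prob(y)  # agrees with LogNormal(0, 1).log_prob(y)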
"""
arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
self.base_dist = base_distribution
if isinstance(transforms, Transform):
            self.transforms = [transforms]
elif isinstance(transforms, list):
if not all(isinstance(t, Transform) for t in transforms):
raise ValueError("transforms must be a Transform or a list of Transforms")
self.transforms = transforms
else:
raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
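        # Transforms may fold trailing batch dimensions of the base
        # distribution into the event, so the event_dim here is the largest
        # event_dim among the base distribution and all of the transforms.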
shape = self.base_dist.batch_shape + self.base_dist.event_shape
event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms])
batch_shape = shape[:len(shape) - event_dim]
event_shape = shape[len(shape) - event_dim:]
super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(TransformedDistribution, _instance)
batch_shape = torch.Size(batch_shape)
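        # Only batch dimensions change under expand: grow the base
        # distribution's batch shape to match and reuse the same transforms.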
base_dist_batch_shape = batch_shape + self.base_dist.batch_shape[len(self.batch_shape):]
new.base_dist = self.base_dist.expand(base_dist_batch_shape)
new.transforms = self.transforms
super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

    @constraints.dependent_property
def support(self):
return self.transforms[-1].codomain if self.transforms else self.base_dist.support

    @property
def has_rsample(self):
return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution and applies `transform()` for every transform in
        the list.
"""
with torch.no_grad():
x = self.base_dist.sample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x

    def rsample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped reparameterized sample or sample_shape
shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution and applies
        `transform()` for every transform in the list.
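
        A minimal sketch of reparameterization (it assumes ``Normal`` and
        ``AffineTransform`` are imported from :mod:`torch.distributions`)::

            loc = torch.tensor(0.0, requires_grad=True)
            d = TransformedDistribution(Normal(0.0, 1.0), [AffineTransform(loc, 1.0)])
            d.rsample().backward()  # gradients flow back to loc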
"""
x = self.base_dist.rsample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x

    def log_prob(self, value):
"""
        Scores the sample by inverting the transform(s) and computing the score
        using the log probability of the base distribution and the log abs det
        jacobian.
"""
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
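        # Walk the transforms in reverse, inverting one at a time; by the
        # change-of-variables formula each step subtracts log|det dy/dx|,
        # summed over the event dims the transform does not itself handle.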
for transform in reversed(self.transforms):
x = transform.inv(y)
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
event_dim - transform.event_dim)
y = x
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
event_dim - len(self.base_dist.event_shape))
return log_prob

    def _monotonize_cdf(self, value):
"""
This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
monotone increasing.
"""
sign = 1
for transform in self.transforms:
sign = sign * transform.sign
if isinstance(sign, int) and sign == 1:
return value
return sign * (value - 0.5) + 0.5

    def cdf(self, value):
"""
        Computes the cumulative distribution function by inverting the
        transform(s) and computing the cdf of the base distribution.
"""
for transform in self.transforms[::-1]:
value = transform.inv(value)
if self._validate_args:
self.base_dist._validate_sample(value)
value = self.base_dist.cdf(value)
value = self._monotonize_cdf(value)
return value

    def icdf(self, value):
"""
        Computes the inverse cumulative distribution function by applying the
        transform(s) to the icdf of the base distribution.
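
        A round-trip sketch exercising both :meth:`cdf` and :meth:`icdf` (it
        assumes ``Normal`` and ``ExpTransform`` are imported from
        :mod:`torch.distributions`)::

            d = TransformedDistribution(Normal(0.0, 1.0), [ExpTransform()])
            y = torch.tensor(2.0)
            d.icdf(d.cdf(y))  # recovers y up to numerical error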
"""
value = self._monotonize_cdf(value)
if self._validate_args:
self.base_dist._validate_sample(value)
value = self.base_dist.icdf(value)
for transform in self.transforms:
value = transform(value)
return value