-
Notifications
You must be signed in to change notification settings - Fork 222
/
flows.py
165 lines (131 loc) · 5.26 KB
/
flows.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from jax import lax
import jax.numpy as jnp
from numpyro.distributions.constraints import real_vector
from numpyro.distributions.transforms import Transform
from numpyro.util import fori_loop
def _clamp_preserve_gradients(x, min, max):
return x + lax.stop_gradient(jnp.clip(x, a_min=min, a_max=max) - x)
# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py
class InverseAutoregressiveTransform(Transform):
    """
    Inverse Autoregressive Flow, following Eq (10) of Kingma et al., 2016,

    :math:`\\mathbf{y} = \\mu_t + \\sigma_t\\odot\\mathbf{x}`

    where :math:`\\mathbf{x}` is the input, :math:`\\mathbf{y}` the output, and
    :math:`\\mu_t,\\sigma_t` are produced by an autoregressive network applied
    to :math:`\\mathbf{x}`, with :math:`\\sigma_t>0`.

    **References**

    1. *Improving Variational Inference with Inverse Autoregressive Flow* [arXiv:1606.04934],
       Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling
    """

    domain = real_vector
    codomain = real_vector

    def __init__(
        self, autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0
    ):
        """
        :param autoregressive_nn: an autoregressive neural network whose forward call returns a real-valued
            mean and log scale as a tuple
        """
        self.arn = autoregressive_nn
        self.log_scale_min_clip = log_scale_min_clip
        self.log_scale_max_clip = log_scale_max_clip

    def __call__(self, x):
        """
        :param numpy.ndarray x: the input into the transform
        """
        y, _ = self.call_with_intermediates(x)
        return y

    def call_with_intermediates(self, x):
        # One forward pass of the autoregressive net gives mean and (raw)
        # log scale; the log scale is clamped for numerical stability.
        mean, raw_log_scale = self.arn(x)
        log_scale = _clamp_preserve_gradients(
            raw_log_scale, self.log_scale_min_clip, self.log_scale_max_clip
        )
        y = jnp.exp(log_scale) * x + mean
        return y, log_scale

    def _inverse(self, y):
        """
        :param numpy.ndarray y: the output of the transform to be inverted
        """
        # NOTE: Inversion is an expensive operation that scales in the dimension of the input

        def _sweep(_, x_current):
            mean, log_scale = self.arn(x_current)
            clamped = _clamp_preserve_gradients(
                log_scale, self.log_scale_min_clip, self.log_scale_max_clip
            )
            return (y - mean) * jnp.exp(-clamped)

        # Each sweep fixes one more coordinate of the autoregressive inverse,
        # so dim(y) sweeps starting from zeros recover x.
        return fori_loop(0, y.shape[-1], _sweep, jnp.zeros(y.shape))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """
        Calculates the elementwise determinant of the log jacobian.

        :param numpy.ndarray x: the input to the transform
        :param numpy.ndarray y: the output of the transform
        """
        # The cached intermediate from call_with_intermediates is the clamped
        # log scale itself, so it can be summed directly.
        if intermediates is not None:
            return intermediates.sum(-1)
        log_scale = _clamp_preserve_gradients(
            self.arn(x)[1], self.log_scale_min_clip, self.log_scale_max_clip
        )
        return log_scale.sum(-1)

    def tree_flatten(self):
        dynamic = (self.log_scale_min_clip, self.log_scale_max_clip)
        static = (("log_scale_min_clip", "log_scale_max_clip"), {"arn": self.arn})
        return dynamic, static

    def __eq__(self, other):
        if not isinstance(other, InverseAutoregressiveTransform):
            return False
        same_network = self.arn is other.arn
        return (
            same_network
            & jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip)
            & jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip)
        )
class BlockNeuralAutoregressiveTransform(Transform):
    """
    Block Neural Autoregressive flow: the wrapped network produces both the
    transformed output and its per-dimension log-determinant in one pass.

    **References**

    1. *Block Neural Autoregressive Flow*,
       Nicola De Cao, Ivan Titov, Wilker Aziz
    """

    domain = real_vector
    codomain = real_vector

    def __init__(self, bn_arn):
        self.bn_arn = bn_arn

    def __call__(self, x):
        """
        :param numpy.ndarray x: the input into the transform
        """
        out, _ = self.call_with_intermediates(x)
        return out

    def call_with_intermediates(self, x):
        # The network computes the output and the per-dimension log-det
        # jacobian together; the log-det is the cached intermediate.
        out, logdet = self.bn_arn(x)
        return out, logdet

    def _inverse(self, y):
        raise NotImplementedError(
            "Block neural autoregressive transform does not have an analytic"
            " inverse implemented."
        )

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """
        Calculates the elementwise determinant of the log jacobian.

        :param numpy.ndarray x: the input to the transform
        :param numpy.ndarray y: the output of the transform
        """
        # Reuse the cached log-det when available; otherwise recompute it
        # with a fresh network pass.
        if intermediates is not None:
            return intermediates.sum(-1)
        return self.bn_arn(x)[1].sum(-1)

    def tree_flatten(self):
        dynamic = ()
        static = ((), {"bn_arn": self.bn_arn})
        return dynamic, static

    def __eq__(self, other):
        if not isinstance(other, BlockNeuralAutoregressiveTransform):
            return False
        return self.bn_arn is other.bn_arn