joint_distribution_coroutine.py
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The `JointDistributionCoroutine` class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import warnings
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import joint_distribution as joint_distribution_lib
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import structural_tuple
from tensorflow_probability.python.util.seed_stream import SeedStream
from tensorflow_probability.python.util.seed_stream import TENSOR_SEED_MSG_PREFIX
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import

__all__ = [
    'JointDistributionCoroutine',
]

JAX_MODE = False

# Cause all warnings to always be triggered.
# Not having this means subsequent calls won't trigger the warning.
warnings.filterwarnings(
    'always',
    module='tensorflow_probability.*joint_distribution_coroutine',
    append=True)  # Don't override user-set filters.


class JointDistributionCoroutine(joint_distribution_lib.JointDistribution):
  """Joint distribution parameterized by a distribution-making generator.

  This distribution enables both sampling and joint probability computation
  from a single model specification.

  A joint distribution is a collection of possibly interdependent
  distributions. The `JointDistributionCoroutine` is specified by a generator
  that generates the elements of this collection.

  #### Mathematical Details

  The `JointDistributionCoroutine` implements the chain rule of probability.
  That is, the probability function of a length-`d` vector `x` is,

  ```none
  p(x) = prod{ p(x[i] | x[:i]) : i = 0, ..., (d - 1) }
  ```

  The `JointDistributionCoroutine` is parameterized by a generator that
  yields `tfp.distributions.Distribution`-like instances. Each element
  yielded implements the `i`-th *full conditional distribution*,
  `p(x[i] | x[:i])`. Within the generator, the return value from the yield
  is a sample from the distribution that may be used to construct subsequent
  yielded `Distribution`-like instances. This allows later instances to be
  conditional on earlier ones.

  When the `sample` method for a `JointDistributionCoroutine` is called with
  a `sample_shape`, the `sample` method for each of the yielded distributions
  is called. The distributions that have been wrapped in the
  `JointDistributionCoroutine.Root` class will be called with `sample_shape`
  as the `sample_shape` argument, and the unwrapped distributions will be
  called with `()` as the `sample_shape` argument. It is the user's
  responsibility to ensure that each of the distributions generates samples
  with the specified sample shape.

  **Name resolution**: The names of `JointDistributionCoroutine` components
  may be specified by passing `name` arguments to distribution constructors
  (`tfd.Normal(0., 1., name='x')`). Components without an explicit name will
  be assigned a dummy name.

  #### Examples

  ```python
  tfd = tfp.distributions

  # Consider the following generative model:
  #   e ~ Exponential(rate=[100, 120])
  #   g ~ Gamma(concentration=e[0], rate=e[1])
  #   n ~ Normal(loc=0, scale=2.)
  #   m ~ Normal(loc=n, scale=g)

  # In TFP, we can write this as:
  Root = tfd.JointDistributionCoroutine.Root  # Convenient alias.
  def model():
    e = yield Root(tfd.Independent(tfd.Exponential(rate=[100, 120]), 1))
    g = yield tfd.Gamma(concentration=e[..., 0], rate=e[..., 1])
    n = yield Root(tfd.Normal(loc=0, scale=2.))
    m = yield tfd.Normal(loc=n, scale=g)

  joint = tfd.JointDistributionCoroutine(model)

  x = joint.sample()
  # ==> x is a length-4 tuple of Tensors representing a draw/realization
  #     from each distribution.
  joint.log_prob(x)
  # ==> A scalar `Tensor` representing the total log prob under all four
  #     distributions.
  ```
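
  Sampling with a nontrivial `sample_shape` makes the role of `Root`
  concrete: only the `Root`-wrapped distributions receive the `sample_shape`
  directly; the other components inherit it through their dependence on
  earlier samples. A minimal sketch, continuing the model above (the shapes
  shown follow from the model definition):

  ```python
  x = joint.sample([5])
  # ==> The `e` component has shape [5, 2]; the `g`, `n`, and `m` components
  #     each have shape [5].
  joint.log_prob(x)
  # ==> A `Tensor` of shape [5]: one total log prob per joint draw.
  ```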

  For improved readability of sampled values, the yielded distributions can
  also be named:

  ```python
  tfd = tfp.distributions
  Root = tfd.JointDistributionCoroutine.Root  # Convenient alias.
  def model():
    e = yield Root(tfd.Independent(
        tfd.Exponential(rate=[100, 120], name='e'), 1))
    g = yield tfd.Gamma(concentration=e[..., 0], rate=e[..., 1], name='g')
    n = yield Root(tfd.Normal(loc=0, scale=2., name='n'))
    m = yield tfd.Normal(loc=n, scale=g, name='m')

  joint = tfd.JointDistributionCoroutine(model)

  x = joint.sample()
  # ==> x is a namedtuple with fields (in order) 'e', 'g', 'n', 'm' and
  #     values representing the draw/realization from each corresponding
  #     distribution.
  joint.log_prob(x)
  # ==> A scalar `Tensor` representing the total log prob under all four
  #     distributions.

  # Passing dictionaries via `kwargs` also works.
  joint.log_prob(**x._asdict())
  # Or:
  joint.log_prob(e=..., g=..., n=..., m=...)
  ```

  If any of the yielded distributions are not explicitly named, they will
  automatically be given a name of the form `var#` where `#` is the index of
  the associated distribution. E.g. the first yielded distribution will have
  a default name of `var0`.
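
  For instance, the components of the first (unnamed) model above can be
  accessed as fields of the returned structure. A minimal sketch:

  ```python
  x = joint.sample()
  x.var0  # The draw from the `Independent(Exponential(...))` component.
  x.var3  # The draw from the final `Normal` component.
  ```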

  #### Discussion

  Each element yielded by the generator must be a `tfd.Distribution`-like
  instance. An object is deemed '`tfd.Distribution`-like' if it has `sample`
  and `log_prob` methods and distribution properties, e.g., `batch_shape`,
  `event_shape`, `dtype`.

  Consider the following fragment from a generator:

  ```python
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=1.0)
  ```

  The random variable `n` has no dependence on earlier random variables and
  `Root` is used to indicate that its distribution needs to be passed a
  `sample_shape`. On the other hand, the distribution of `m` is constructed
  using the value of `n`. This means that `n` is already shaped according to
  the `sample_shape` and there is no need to pass `m`'s distribution a
  `sample_shape`. So `Root` is not used to wrap `m`'s distribution.
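
  To see the distinction concretely, here is a sketch of the shapes this
  fragment produces (the shapes are inferred from the model, not taken from
  a specific run):

  ```python
  def fragment():
    n = yield Root(tfd.Normal(loc=0., scale=2.))
    m = yield tfd.Normal(loc=n, scale=1.)

  fragment_joint = tfd.JointDistributionCoroutine(fragment)
  n, m = fragment_joint.sample([7])
  # ==> `n` has shape [7] because its `Root`-wrapped distribution received
  #     the `sample_shape`; `m` also has shape [7], but only because its
  #     `loc` (the sampled `n`) already carries that shape.
  ```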

  **Note**: unlike most other distributions in `tfp.distributions`,
  `JointDistributionCoroutine.sample` returns a `tuple` of `Tensor`s rather
  than a `Tensor`. Accordingly, `joint.batch_shape` returns a `tuple` of
  `TensorShape`s for each of the distributions' batch shapes and
  `joint.batch_shape_tensor()` returns a `tuple` of `Tensor`s for each of
  the distributions' batch shapes. (The `event_shape` analogues behave the
  same way.)
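
  For the named model above, a sketch of these structured shape properties
  (the values shown are what the model definition implies):

  ```python
  joint.event_shape
  # ==> A structure of `TensorShape`s, e.g. ([2], [], [], []) for
  #     (e, g, n, m).
  joint.batch_shape
  # ==> A structure of batch `TensorShape`s, one per component (all scalar
  #     here).
  ```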
"""

  class Root(collections.namedtuple('Root', ['distribution'])):
    """Wrapper for coroutine distributions which lack distribution parents."""
    __slots__ = ()

  def __init__(
      self,
      model,
      sample_dtype=None,
      validate_args=False,
      name=None,
  ):
    """Construct the `JointDistributionCoroutine` distribution.

    Args:
      model: A generator that yields a sequence of `tfd.Distribution`-like
        instances.
      sample_dtype: Samples from this distribution will be structured like
        `tf.nest.pack_sequence_as(sample_dtype, list_)`. `sample_dtype` is
        only used for `tf.nest.pack_sequence_as` structuring of outputs,
        never casting (which is the responsibility of the component
        distributions).
        Default value: `None` (i.e. `namedtuple`).
      validate_args: Python `bool`. Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid, correct
        behavior is not guaranteed.
        Default value: `False`.
      name: The name for ops managed by the distribution.
        Default value: `None` (i.e., `JointDistributionCoroutine`).
    """
    parameters = dict(locals())
    with tf.name_scope(name or 'JointDistributionCoroutine') as name:
      self._model_coroutine = model
      # Hint `no_dependency` to tell tf.Module not to screw up the sample
      # dtype with extraneous wrapping (list => ListWrapper, etc.).
      self._sample_dtype = self._no_dependency(sample_dtype)
      self._single_sample_distributions = {}
      super(JointDistributionCoroutine, self).__init__(
          dtype=sample_dtype,
          reparameterization_type=None,  # Ignored; we'll override.
          validate_args=validate_args,
          allow_nan_stats=False,
          parameters=parameters,
          name=name)

  # TODO(b/166658748): Once the bug is resolved, we should be able to
  # eliminate this workaround that disables sanitize_seed for JD*AB.
  _stateful_to_stateless = True

  @property
  def _require_root(self):
    return True

  @property
  def model(self):
    return self._model_coroutine

  def _assert_compatible_shape(self, index, sample_shape, samples):
    requested_shape, _ = self._expand_sample_shape_to_vector(
        tf.convert_to_tensor(sample_shape, dtype=tf.int32),
        name='requested_shape')
    actual_shape = prefer_static.shape(samples)
    actual_rank = prefer_static.rank_from_shape(actual_shape)
    requested_rank = prefer_static.rank_from_shape(requested_shape)

    # We test for two properties we expect of yielded distributions:
    #   (1) The rank of the tensor of generated samples must be at least
    #       as large as the rank requested.
    #   (2) The requested shape must be a prefix of the shape of the
    #       generated tensor of samples.
    # We attempt to perform test (1) statically first. We don't need to do
    # this explicitly for test (2) because `assert_equal` evaluates
    # statically if it can.
    static_actual_rank = tf.get_static_value(actual_rank)
    static_requested_rank = tf.get_static_value(requested_rank)

    assertion_message = ('Samples yielded by distribution #{} are not '
                         'consistent with `sample_shape` passed to '
                         '`JointDistributionCoroutine` '
                         'distribution.'.format(index))

    # TODO(b/138738650): Remove this static check.
    if (static_actual_rank is not None and
        static_requested_rank is not None):
      # We're able to statically check the rank.
      if static_actual_rank < static_requested_rank:
        raise ValueError(assertion_message)
      else:
        control_dependencies = []
    else:
      # We're not able to statically check the rank.
      control_dependencies = [
          assert_util.assert_greater_equal(
              actual_rank, requested_rank,
              message=assertion_message)
      ]

    with tf.control_dependencies(control_dependencies):
      trimmed_actual_shape = actual_shape[:requested_rank]

    control_dependencies = [
        assert_util.assert_equal(
            requested_shape, trimmed_actual_shape,
            message=assertion_message)
    ]
    return control_dependencies

  def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    """Executes `model`, creating both samples and distributions."""
    ds = []
    values_out = []
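    # Seed handling: a stateful seed (a Python int or `None`) drives a
    # `SeedStream`, while a stateless seed (a Tensor) is split once per
    # yielded distribution below.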
    if samplers.is_stateful_seed(seed):
      seed_stream = SeedStream(seed, salt='JointDistributionCoroutine')
      if not self._stateful_to_stateless:
        seed = None
    else:
      seed_stream = None  # We got a stateless seed for seed=.

    # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate
    # it).
    if self._stateful_to_stateless and (seed is not None or not JAX_MODE):
      seed = samplers.sanitize_seed(seed, salt='JointDistributionCoroutine')
    gen = self._model_coroutine()
    index = 0
    d = next(gen)
    if self._require_root and not isinstance(d, self.Root):
      raise ValueError('First distribution yielded by coroutine must '
                       'be wrapped in `Root`.')
    try:
      while True:
        actual_distribution = d.distribution if isinstance(d, self.Root) else d
        ds.append(actual_distribution)

        # Ensure reproducibility even when xs are (partially) set. Always
        # split.
        stateful_sample_seed = None if seed_stream is None else seed_stream()
        if seed is None:
          stateless_sample_seed = None
        else:
          stateless_sample_seed, seed = samplers.split_seed(seed)

        if (value is not None and len(value) > index and
            value[index] is not None):
          def convert_tree_to_tensor(x, dtype_hint):
            return tf.convert_to_tensor(x, dtype_hint=dtype_hint)
          # `nest.map_structure_up_to` passes arguments positionally (its
          # signature does not allow kwarg names); this applies
          # `convert_to_tensor` to the user-provided value for this
          # component.
          next_value = nest.map_structure_up_to(
              ds[-1].dtype,            # shallow_tree
              convert_tree_to_tensor,  # func
              value[index],            # x
              ds[-1].dtype)            # dtype_hint
        else:
          try:
            next_value = actual_distribution.sample(
                sample_shape=sample_shape if isinstance(d, self.Root) else (),
                seed=(stateful_sample_seed if stateless_sample_seed is None
                      else stateless_sample_seed))
          except TypeError as e:
            if ('Expected int for argument' not in str(e) and
                TENSOR_SEED_MSG_PREFIX not in str(e)) or (
                    stateful_sample_seed is None):
              raise
            msg = (
                'Falling back to stateful sampling for distribution #{index} '
                '(0-based) of type `{dist_cls}` with component name '
                '{component_name} and `dist.name` "{dist_name}". Please '
                'update to use `tf.random.stateless_*` RNGs. This fallback '
                'may be removed after 20-Dec-2020. ({exc})')
            component_name = (
                joint_distribution_lib.get_explicit_name_for_component(
                    ds[-1]))
            if component_name is None:
              component_name = '[None specified]'
            else:
              component_name = '"{}"'.format(component_name)
            warnings.warn(msg.format(
                index=index,
                component_name=component_name,
                dist_name=ds[-1].name,
                dist_cls=type(ds[-1]),
                exc=str(e)))
            next_value = actual_distribution.sample(
                sample_shape=sample_shape if isinstance(d, self.Root) else (),
                seed=stateful_sample_seed)

        if self._validate_args:
          with tf.control_dependencies(
              self._assert_compatible_shape(
                  index, sample_shape, next_value)):
            values_out.append(tf.nest.map_structure(tf.identity, next_value))
        else:
          values_out.append(next_value)

        index += 1
        d = gen.send(next_value)
    except StopIteration:
      pass
    return ds, values_out

  def _model_unflatten(self, xs):
    # Packs a flat sequence of component values into the model's structured
    # sample type.
    if self._sample_dtype is None:
      return structural_tuple.structtuple(self._flat_resolve_names())(*xs)
    # Cast `xs` as `tuple` so we can handle generators.
    return tf.nest.pack_sequence_as(self._sample_dtype, tuple(xs))

  def _model_flatten(self, xs):
    # Flattens a structured value into the order in which the components
    # were yielded.
    if self._sample_dtype is None:
      return tuple((xs[k] for k in self._flat_resolve_names())
                   if isinstance(xs, collections.abc.Mapping) else xs)
    return nest.flatten_up_to(self._sample_dtype, xs)