-
Notifications
You must be signed in to change notification settings - Fork 223
/
batch_util.py
578 lines (472 loc) · 16.8 KB
/
batch_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import copy
from functools import singledispatch
from typing import Union
import jax.numpy as jnp
from jax.tree_util import tree_map
from numpyro.distributions import constraints
from numpyro.distributions.conjugate import (
BetaBinomial,
DirichletMultinomial,
GammaPoisson,
NegativeBinomial2,
NegativeBinomialLogits,
NegativeBinomialProbs,
)
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import (
CAR,
LKJ,
AsymmetricLaplaceQuantile,
Beta,
BetaProportion,
Chi2,
Gamma,
HalfCauchy,
HalfNormal,
InverseGamma,
Kumaraswamy,
LKJCholesky,
LogNormal,
LogUniform,
LowRankMultivariateNormal,
MultivariateStudentT,
Pareto,
RelaxedBernoulliLogits,
StudentT,
Uniform,
)
from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
from numpyro.distributions.discrete import (
CategoricalProbs,
DiscreteUniform,
OrderedLogistic,
ZeroInflatedPoisson,
ZeroInflatedProbs,
)
from numpyro.distributions.distribution import (
Distribution,
ExpandedDistribution,
Independent,
MaskedDistribution,
Unit,
)
from numpyro.distributions.transforms import (
AffineTransform,
CorrCholeskyTransform,
PowerTransform,
Transform,
)
from numpyro.distributions.truncated import (
LeftTruncatedDistribution,
RightTruncatedDistribution,
TwoSidedTruncatedDistribution,
)
@singledispatch
def vmap_over(d: Union[Distribution, Transform, Constraint], **kwargs):
raise NotImplementedError
@vmap_over.register
def _vmap_over_affine_transform(
dist: AffineTransform, loc=None, scale=None, domain=None
):
dist_axes = copy.copy(dist)
dist_axes.loc = loc
dist_axes.scale = scale
dist_axes.domain = domain
return dist_axes
@vmap_over.register
def _vmap_over_greater_than(dist: constraints._GreaterThan, lower_bound=None):
axes = copy.copy(dist)
axes.lower_bound = lower_bound
return axes
@vmap_over.register
def _vmap_over_less_than(dist: constraints._LessThan, upper_bound=None):
axes = copy.copy(dist)
axes.upper_bound = upper_bound
return axes
@vmap_over.register
def _vmap_over_interval(
dist: constraints._Interval, lower_bound=None, upper_bound=None
):
axes = copy.copy(dist)
axes.lower_bound = lower_bound
axes.upper_bound = upper_bound
return axes
@vmap_over.register
def _vmap_over_integer_interval(
dist: constraints._IntegerInterval, lower_bound=None, upper_bound=None
):
dist_axes = copy.copy(dist)
dist_axes.lower_bound = lower_bound
dist_axes.upper_bound = upper_bound
return dist_axes
@vmap_over.register
def _vmap_over_corr_cholesky_transform(dist: CorrCholeskyTransform):
return None
@vmap_over.register
def _vmap_over_power_transform(dist: PowerTransform, exponent=None):
axes = copy.copy(dist)
axes.exponent = exponent
return axes
@vmap_over.register
def _default_vmap_over(d: Distribution, **kwargs):
pytree_fields = type(d).gather_pytree_data_fields()
dist_axes = copy.copy(d)
for f in pytree_fields:
setattr(dist_axes, f, kwargs.get(f, None))
return dist_axes
@vmap_over.register
def _(dist: AsymmetricLaplaceQuantile, loc=None, scale=None, quantile=None):
dist_axes = _default_vmap_over(
dist,
loc=loc,
scale=scale,
quantile=quantile,
_ald=vmap_over(
dist._ald,
loc=loc,
scale=scale if scale is not None else quantile,
asymmetry=quantile,
),
)
return dist_axes
@vmap_over.register
def _vmap_over_beta(dist: Beta, concentration1=None, concentration0=None):
dist_axes = _default_vmap_over(
dist, concentration1=concentration1, concentration0=concentration0
)
if concentration1 is not None or concentration0 is not None:
dist_axes._dirichlet = 0
else:
dist_axes._dirichlet = None
return dist_axes
@vmap_over.register
def _vmap_over_beta_proportion(dist: BetaProportion, mean=None, concentration=None):
dist_axes = vmap_over.dispatch(Beta)(
dist,
concentration1=concentration if concentration is not None else mean,
concentration0=concentration if concentration is not None else mean,
)
dist_axes.concentration = concentration
return dist_axes
@vmap_over.register
def _vmap_over_chi2(dist: Chi2, df=None):
dist_axes = vmap_over.dispatch(Gamma)(dist, rate=df, concentration=df)
dist_axes.df = df
return dist_axes
@vmap_over.register
def _vmap_over_gaussian_copula(
dist: GaussianCopula,
marginal_dist=None,
correlation_matrix=None,
correlation_cholesky=None,
):
dist_axes = _default_vmap_over(
dist, marginal_dist=marginal_dist, correlation_matrix=correlation_matrix
)
dist_axes.base_dist = vmap_over(
dist.base_dist,
loc=correlation_matrix if correlation_matrix == 0 else correlation_cholesky,
scale_tril=correlation_matrix
if correlation_matrix == 0
else correlation_cholesky,
covariance_matrix=correlation_matrix,
)
return dist_axes
@vmap_over.register
def _vmap_over_gausian_copula_beta(
dist: GaussianCopulaBeta,
concentration1=None,
concentration0=None,
correlation_matrix=None,
correlation_cholesky=None,
):
d = vmap_over.dispatch(GaussianCopula)(
dist,
vmap_over(
dist.marginal_dist,
concentration1=concentration1,
concentration0=concentration0,
),
correlation_matrix=correlation_matrix,
correlation_cholesky=correlation_cholesky,
)
d.concentration1 = concentration1
d.concentration0 = concentration0
return d
@vmap_over.register
def _vmap_over_half_cauchy(dist: HalfCauchy, scale=None):
dist_axes = _default_vmap_over(dist, scale=scale)
dist_axes._cauchy = vmap_over(dist._cauchy, loc=scale, scale=scale)
return dist_axes
@vmap_over.register
def _vmap_over_inverse_gamma(dist: InverseGamma, concentration=None, rate=None):
dist_axes = _default_vmap_over(dist, concentration=concentration, rate=rate)
dist_axes.base_dist = vmap_over(
dist.base_dist, concentration=concentration, rate=rate
)
dist_axes.transforms = None
return dist_axes
@vmap_over.register
def _vmap_over_uniform(dist: Uniform, low=None, high=None):
dist_axes = _default_vmap_over(dist, low=low, high=high)
dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high)
return dist_axes
@vmap_over.register
def _vmap_over_kumaraswamy(dist: Kumaraswamy, concentration0=None, concentration1=None):
dist_axes = _default_vmap_over(
dist, concentration0=concentration0, concentration1=concentration1
)
dist_axes.concentration0 = concentration0
dist_axes.concentration1 = concentration1
return dist_axes
@vmap_over.register
def _vmap_over_lkj(dist: LKJ, concentration=None):
dist_axes = _default_vmap_over(dist, concentration=concentration)
dist_axes.base_dist = vmap_over(dist.base_dist, concentration=concentration)
dist_axes.transforms = None
return dist_axes
@vmap_over.register
def _vmap_over_lkj_cholesky(dist: LKJCholesky, concentration):
dist_axes = _default_vmap_over(dist, concentration=concentration)
if dist_axes.sample_method == "onion":
dist_axes._beta = vmap_over(
dist._beta, concentration1=None, concentration0=concentration
)
elif dist_axes.sample_method == "cvine":
dist_axes._beta = vmap_over(
dist._beta, concentration1=concentration, concentration0=concentration
)
return dist_axes
@vmap_over.register
def _vmap_over_lognormal(dist: LogNormal, loc=None, scale=None):
dist_axes = _default_vmap_over(dist, loc=loc, scale=scale)
dist_axes.transforms = None
dist_axes.base_dist = vmap_over(dist.base_dist, loc=loc, scale=scale)
return dist_axes
@vmap_over.register
def _vmap_over_loguniform(dist: LogUniform, low=None, high=None):
dist_axes = _default_vmap_over(dist, low=low, high=high)
dist_axes.base_dist = vmap_over(dist.base_dist, low=low, high=high)
dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high)
return dist_axes
@vmap_over.register
def _vmap_over_car(
dist: CAR, loc=None, correlation=None, conditional_precision=None, adj_matrix=None
):
dist_axes = _default_vmap_over(
dist,
loc=loc,
correlation=correlation,
conditional_precision=conditional_precision,
)
if not dist.is_sparse:
dist_axes.adj_matrix = adj_matrix
dist_axes.precision_matrix = adj_matrix
else:
assert adj_matrix is None
return dist_axes
@vmap_over.register
def _vmap_over_multivariate_student_t(
dist: MultivariateStudentT, df=None, loc=None, scale_tril=None
):
dist_axes = _default_vmap_over(dist, df=df, loc=loc, scale_tril=scale_tril)
dist_axes._chi2 = vmap_over(dist._chi2, df=df)
return dist_axes
@vmap_over.register
def _vmap_over_low_rank_multivariate_normal(
dist: LowRankMultivariateNormal, loc=None, cov_factor=None, cov_diag=None
):
dist_axes = _default_vmap_over(
dist, loc=loc, cov_factor=cov_factor, cov_diag=cov_diag
)
dist_axes._capacitance_tril = cov_diag if cov_diag is not None else cov_factor
return dist_axes
@vmap_over.register
def _vmap_over_pareto(dist: Pareto, scale=None, alpha=None):
dist_axes = _default_vmap_over(dist, scale=scale, alpha=alpha)
dist_axes.base_dist = vmap_over(dist.base_dist, rate=alpha)
dist_axes.transforms = [None, vmap_over(dist.transforms[1], loc=None, scale=scale)]
return dist_axes
@vmap_over.register
def _vmap_over_relaxed_bernoulli_logits(
dist: RelaxedBernoulliLogits, temperature=None, logits=None
):
dist_axes = _default_vmap_over(dist, temperature=temperature, logits=logits)
dist_axes.transforms = None
dist_axes.base_dist = vmap_over(
dist.base_dist,
loc=logits if logits is not None else temperature,
scale=temperature,
)
return dist_axes
@vmap_over.register
def _vmap_over_student_t(dist: StudentT, df=None, loc=None, scale=None):
dist_axes = _default_vmap_over(dist, df=df, loc=loc, scale=scale)
dist_axes._chi2 = vmap_over(dist._chi2, df=df)
return dist_axes
@vmap_over.register
def _vmap_over_two_sided_truncated_distribution(
dist: TwoSidedTruncatedDistribution, low=None, high=None
):
dist_axes = _default_vmap_over(dist, low=low, high=high)
dist_axes.base_dist = None
dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high)
return dist_axes
@vmap_over.register
def _vmap_over_left_truncated_distribution(dist: LeftTruncatedDistribution, low=None):
dist_axes = _default_vmap_over(dist, low=low)
dist_axes.base_dist = None
dist_axes._support = vmap_over(dist._support, lower_bound=low)
return dist_axes
@vmap_over.register
def _vmap_over_right_truncated_distribution(
dist: RightTruncatedDistribution, high=None
):
dist_axes = _default_vmap_over(dist, high=high)
dist_axes.base_dist = None
dist_axes._support = vmap_over(dist._support, upper_bound=high)
return dist_axes
@vmap_over.register
def _vmap_over_beta_binomial(
dist: BetaBinomial, concentration1=None, concentration0=None, total_count=None
):
dist_axes = _default_vmap_over(
dist,
concentration1=concentration1,
concentration0=concentration0,
total_count=total_count,
)
dist_axes._beta = vmap_over(
dist._beta, concentration1=concentration1, concentration0=concentration0
)
return dist_axes
@vmap_over.register
def _vmap_over_dirichlet_multinomial(dist: DirichletMultinomial, concentration=None):
dist_axes = _default_vmap_over(dist, concentration=concentration)
dist_axes._dirichlet = vmap_over(dist._dirichlet, concentration=concentration)
return dist_axes
@vmap_over.register
def _vmap_over_gamma_poisson(dist: GammaPoisson, concentration=None, rate=None):
dist_axes = _default_vmap_over(dist, concentration=concentration, rate=rate)
dist_axes._gamma = vmap_over(dist._gamma, concentration=concentration, rate=rate)
return dist_axes
@vmap_over.register
def _vmap_over_negative_binomial_probs(
dist: NegativeBinomialProbs, total_count=None, probs=None
):
dist_axes = vmap_over.dispatch(GammaPoisson)(
dist, concentration=total_count, rate=probs
)
dist_axes.total_count = total_count
dist_axes.probs = probs
return dist_axes
@vmap_over.register
def _vmap_over_negative_binomial_logits(
dist: NegativeBinomialLogits, total_count=None, logits=None
):
dist_axes = vmap_over.dispatch(GammaPoisson)(
dist, concentration=total_count, rate=logits
)
dist_axes.total_count = total_count
dist_axes.logits = logits
return dist_axes
@vmap_over.register
def _vmap_over_negative_binomial_2(
dist: NegativeBinomial2, mean=None, concentration=None
):
return vmap_over.dispatch(GammaPoisson)(
dist,
concentration=concentration,
rate=concentration if concentration is not None else mean,
)
@vmap_over.register
def _vmap_over_ordered_logistic(dist: OrderedLogistic, predictor=None, cutpoints=None):
dist_axes = vmap_over.dispatch(CategoricalProbs)(
dist, probs=predictor if predictor is not None else cutpoints
)
dist_axes.predictor = predictor
dist_axes.cutpoints = cutpoints
return dist_axes
@vmap_over.register
def _vmap_over_discrete_uniform(dist: DiscreteUniform, low=None, high=None):
dist_axes = _default_vmap_over(dist, low=low, high=high)
dist_axes._support = vmap_over(dist._support, lower_bound=low, upper_bound=high)
return dist_axes
@vmap_over.register
def _vmap_over_zero_inflated_poisson(dist: ZeroInflatedPoisson, gate=None, rate=None):
dist_axes = vmap_over.dispatch(ZeroInflatedProbs)(
dist, base_dist=vmap_over(dist.base_dist, rate=rate), gate=gate
)
dist_axes.rate = rate
return dist_axes
@vmap_over.register
def _vmap_over_half_normal(dist: HalfNormal, scale=None):
dist_axes = _default_vmap_over(dist, scale=scale)
dist_axes._normal = vmap_over(dist._normal, loc=scale, scale=scale)
return dist_axes
@singledispatch
def promote_batch_shape(d: Distribution):
raise NotImplementedError
@promote_batch_shape.register
def _default_promote_batch_shape(d: Distribution):
attr_name = list(d.arg_constraints.keys())[0]
attr_event_dim = d.arg_constraints[attr_name].event_dim
attr = getattr(d, attr_name)
resolved_batch_shape = attr.shape[
: max(0, attr.ndim - d.event_dim - attr_event_dim)
]
new_self = copy.deepcopy(d)
new_self._batch_shape = resolved_batch_shape
return new_self
@promote_batch_shape.register
def _promote_batch_shape_expanded(d: ExpandedDistribution):
orig_delta_batch_shape = d.batch_shape[
: len(d.batch_shape) - len(d.base_dist.batch_shape)
]
new_self = copy.deepcopy(d)
# new dimensions coming from a vmap or numpyro scan/enum operation
promoted_base_dist = promote_batch_shape(new_self.base_dist)
new_shapes_elems = promoted_base_dist.batch_shape[
: len(promoted_base_dist.batch_shape) - len(d.base_dist.batch_shape)
]
# The new dimensions are appended in front of the previous ExpandedDistribution
# batch dimensions. However, these batch dimensions are now present in
# the base distribution. Thus the dimensions present in the original
# ExpandedDistribution batch_shape, but not in the original base distribution
# batch_shape are now intermediate dimensions: to maintain broadcastability,
# the attribute of the batch distribution are expanded with such intermediate
# dimensions.
new_self._batch_shape = (*new_shapes_elems, *d.batch_shape)
new_self.base_dist._batch_shape = (
*new_shapes_elems,
*tuple(1 for _ in orig_delta_batch_shape),
*d.base_dist.batch_shape,
)
new_axes_locs = range(
len(new_shapes_elems),
len(new_shapes_elems) + len(orig_delta_batch_shape),
)
new_base_dist = tree_map(
lambda x: jnp.expand_dims(x, axis=new_axes_locs), new_self.base_dist
)
new_self.base_dist = new_base_dist
return new_self
@promote_batch_shape.register
def _promote_batch_shape_masked(d: MaskedDistribution):
new_self = copy.copy(d)
new_base_dist = promote_batch_shape(d.base_dist)
new_self._batch_shape = new_base_dist.batch_shape
new_self.base_dist = new_base_dist
return new_self
@promote_batch_shape.register
def _promote_batch_shape_independent(d: Independent):
new_self = copy.copy(d)
new_base_dist = promote_batch_shape(d.base_dist)
new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim]
new_self.base_dist = new_base_dist
return new_self
@promote_batch_shape.register
def _promote_batch_shape_unit(d: Unit):
return d