Skip to content

Commit

Permalink
Do not mutate shapes of ExpandedDistribution for map-free ops (#1574)
Browse files Browse the repository at this point in the history
  • Loading branch information
pierreglaser committed May 18, 2023
1 parent a8004f0 commit db2913b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
16 changes: 14 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,15 +613,27 @@ def tree_flatten(self):
self.base_dist,
)
base_flatten, base_aux = base_dist.tree_flatten()
return base_flatten, (type(self.base_dist), base_aux, self.batch_shape)
return base_flatten, (
type(self.base_dist),
base_aux,
self.batch_shape,
prepend_ndim,
)

@classmethod
def tree_unflatten(cls, aux_data, params):
base_cls, base_aux, batch_shape = aux_data
base_cls, base_aux, batch_shape, prepend_ndim = aux_data
base_dist = base_cls.tree_unflatten(base_aux, params)
prepend_shape = base_dist.batch_shape[
: len(base_dist.batch_shape) - len(batch_shape)
]
if len(prepend_shape) == 0:
# in that case, no additional dimension was added
# to the flattened distribution, and the batch_shape
# manipulation happening during the flattening can be
# reverted
base_dist._batch_shape = base_dist.batch_shape[prepend_ndim:]
return cls(base_dist, batch_shape=batch_shape)
return cls(base_dist, batch_shape=prepend_shape + batch_shape)


Expand Down
37 changes: 37 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2666,6 +2666,43 @@ def g(x):
assert tree_map(lambda x: x[None], g(0)).batch_shape == (1, 10, 3)


def test_expand_no_unnecessary_batch_shape_expansion():
# ExpandedDistribution can mutate the `batch_shape` of
# its base distribution in order to make ExpandedDistribution
# mappable, see #684. However, this mutation should not take
# place if no mapping operation is performed.

for arg in (jnp.array(1.0), jnp.ones((2,)), jnp.ones((2, 2))):
# Low level test: ensure that (tree_flatten o tree_unflatten)(expanded_dist)
# amounts to an identity operation.
d = dist.Normal(arg, arg).expand([10, 3, *arg.shape])
roundtripped_d = type(d).tree_unflatten(*d.tree_flatten()[::-1])
assert d.batch_shape == roundtripped_d.batch_shape
assert d.base_dist.batch_shape == roundtripped_d.base_dist.batch_shape
assert d.base_dist.event_shape == roundtripped_d.base_dist.event_shape
assert jnp.allclose(d.base_dist.loc, roundtripped_d.base_dist.loc)
assert jnp.allclose(d.base_dist.scale, roundtripped_d.base_dist.scale)

# High-level test: `jax.jit`ting a function returning an ExpandedDistribution
# (which involves an instance of the low-level case as it will transform
# the original function by adding some flattening and unflattening steps)
# should return same object as its non-jitted equivalent.
def bs(arg):
return dist.Normal(arg, arg).expand([10, 3, *arg.shape])

d = bs(arg)
dj = jax.jit(bs)(arg)

assert isinstance(d, dist.ExpandedDistribution)
assert isinstance(dj, dist.ExpandedDistribution)

assert d.batch_shape == dj.batch_shape
assert d.base_dist.batch_shape == dj.base_dist.batch_shape
assert d.base_dist.event_shape == dj.base_dist.event_shape
assert jnp.allclose(d.base_dist.loc, dj.base_dist.loc)
assert jnp.allclose(d.base_dist.scale, dj.base_dist.scale)


@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str)
def test_kl_delta_normal_shape(batch_shape):
v = np.random.normal(size=batch_shape)
Expand Down

0 comments on commit db2913b

Please sign in to comment.