Skip to content

Commit

Permalink
Fix incorrect unflattenning of inverse transforms (#1600)
Browse files Browse the repository at this point in the history
  • Loading branch information
pierreglaser committed Jun 6, 2023
1 parent 4e3c50e commit 523162f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 4 additions & 3 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,10 @@ def inverse_shape(self, shape):
def tree_flatten(self):
return (self._inv,), (("_inv",), dict())

@classmethod
def tree_unflatten(cls, aux_data, params):
return cls(params)
def __eq__(self, other):
if not isinstance(other, _InverseTransform):
return False
return self._inv == other._inv


class AbsTransform(ParameterFreeTransform):
Expand Down
2 changes: 2 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def out_t(transform, x):
assert jitted_in_t(transform, 1.0) == 1.0
assert jitted_out_t(transform, 1.0) == transform

assert jitted_out_t(transform.inv, 1.0) == transform.inv

assert jnp.allclose(
vmap(in_t, in_axes=(None, 0), out_axes=0)(transform, jnp.ones(3)),
jnp.ones(3),
Expand Down

0 comments on commit 523162f

Please sign in to comment.