add forward shape for SimplexToOrderTransform (#1583)
* add forward shape for SimplexToOrderTransform

* run black
fehiepsi committed May 6, 2023
1 parent d63dae4 commit a2c28c1
Showing 3 changed files with 24 additions and 8 deletions.
6 changes: 6 additions & 0 deletions numpyro/distributions/transforms.py
@@ -871,6 +871,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
        J_logdet = (softplus(y) + softplus(-y)).sum(-1)
        return J_logdet

    def forward_shape(self, shape):
        return shape[:-1] + (shape[-1] - 1,)

    def inverse_shape(self, shape):
        return shape[:-1] + (shape[-1] + 1,)


def _softplus_inv(y):
    return jnp.log(-jnp.expm1(-y)) + y
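For context, the two new methods encode the shape contract of the transform: a probability vector over K categories maps to K - 1 ordered cutpoints, and back. A minimal sketch of the round trip, shapes only, assuming the hunk above sits inside SimplexToOrderedTransform as the new test below suggests:

    from numpyro.distributions.transforms import SimplexToOrderedTransform

    t = SimplexToOrderedTransform()
    # a batch of 3 simplex vectors over 5 categories -> 3 vectors of 4 cutpoints
    assert t.forward_shape((3, 5)) == (3, 4)
    # and the inverse direction restores the original shape
    assert t.inverse_shape((3, 4)) == (3, 5)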
16 changes: 8 additions & 8 deletions numpyro/primitives.py
@@ -135,14 +135,14 @@ def sample(
:param str name: name of the sample site.
:param fn: a stochastic function that returns a sample.
-:param numpy.ndarray obs: observed value
+:param jnp.ndarray obs: observed value
:param jax.random.PRNGKey rng_key: an optional random key for `fn`.
:param sample_shape: Shape of samples to be drawn.
:param dict infer: an optional dictionary containing additional information
for inference algorithms. For example, if `fn` is a discrete distribution,
setting `infer={'enumerate': 'parallel'}` tells MCMC to marginalize
this discrete latent site.
-:param numpy.ndarray obs_mask: Optional boolean array mask of shape
+:param jnp.ndarray obs_mask: Optional boolean array mask of shape
broadcastable with ``fn.batch_shape``. If provided, events with
mask=True will be conditioned on ``obs`` and remaining events will be
imputed by sampling. This introduces a latent sample site named ``name
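As the docstring above explains, obs_mask splits one site into an observed part (conditioned on obs where the mask is True) and an imputed part at an auxiliary "<name>_unobserved" site. A minimal sketch under those assumptions; the model, data, and site names are illustrative, not from the source:

    import numpyro
    import numpyro.distributions as dist

    def partially_observed_model(y, y_mask):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        with numpyro.plate("data", y.shape[0]):
            # entries with y_mask=True condition on y; the rest are imputed
            # through the auxiliary latent site "y_unobserved"
            numpyro.sample("y", dist.Normal(loc, 1.0), obs=y, obs_mask=y_mask)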
@@ -235,7 +235,7 @@ def param(name, init_value=None, **kwargs):
Note that the onus of using this to initialize the optimizer is
on the user inference algorithm, since there is no global parameter
store in NumPyro.
-:type init_value: numpy.ndarray or callable
+:type init_value: jnp.ndarray or callable
:param constraint: NumPyro constraint, defaults to ``constraints.real``.
:type constraint: numpyro.distributions.constraints.Constraint
:param int event_dim: (optional) number of rightmost dimensions unrelated
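A small usage sketch for param, assuming it runs inside a model or guide executed by an inference algorithm such as SVI (the name, shape, and constraint are illustrative):

    import jax.numpy as jnp
    import numpyro
    from numpyro.distributions import constraints

    def guide():
        # a learnable positive scale; the optimizer, not a global store, owns its value
        scale = numpyro.param("scale", jnp.ones(3), constraint=constraints.positive)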
@@ -289,7 +289,7 @@ def deterministic(name, value):
values in the model execution trace.
:param str name: name of the deterministic site.
-:param numpy.ndarray value: deterministic value to record in the trace.
+:param jnp.ndarray value: deterministic value to record in the trace.
"""
if not _PYRO_STACK:
    return value
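A short sketch of recording a derived quantity with deterministic (site names illustrative):

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def model():
        rate = numpyro.sample("rate", dist.Gamma(2.0, 2.0))
        # stored in the trace so MCMC/Predictive output can report it directly
        numpyro.deterministic("log_rate", jnp.log(rate))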
@@ -376,7 +376,7 @@ def model():
# ...
:returns: The mask.
-:rtype: None, bool, or numpy.ndarray
+:rtype: None, bool, or jnp.ndarray
"""
return _inspect()["mask"]

@@ -607,7 +607,7 @@ def factor(name, log_factor):
probabilistic model.
:param str name: Name of the trivial sample.
-:param numpy.ndarray log_factor: A possibly batched log probability factor.
+:param jnp.ndarray log_factor: A possibly batched log probability factor.
"""
unit_dist = numpyro.distributions.distribution.Unit(log_factor)
unit_value = unit_dist.sample(None)
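A minimal sketch of factor adding an extra term to the joint log density; the penalty term is illustrative, not from the source:

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def model(x_obs):
        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
        # soft constraint: penalize x for straying far from the observed value
        numpyro.factor("closeness_penalty", -jnp.square(x - x_obs))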
@@ -657,11 +657,11 @@ def model(data):
data = numpyro.subsample(data, event_dim=0)
# ...
-:param numpy.ndarray data: A tensor of batched data.
+:param jnp.ndarray data: A tensor of batched data.
:param int event_dim: The event dimension of the data tensor. Dimensions to
the left are considered batch dimensions.
:returns: A subsampled version of ``data``
-:rtype: ~numpy.ndarray
+:rtype: ~jnp.ndarray
"""
if not _PYRO_STACK:
    return data
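The docstring's data = numpyro.subsample(data, event_dim=0) line only makes sense inside a subsampled plate; a fuller sketch (sizes and names illustrative):

    import numpyro
    import numpyro.distributions as dist

    def model(data):
        loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
        with numpyro.plate("batch", data.shape[0], subsample_size=100):
            # draws a random minibatch of 100 rows; the log density is rescaled accordingly
            batch = numpyro.subsample(data, event_dim=0)
            numpyro.sample("obs", dist.Normal(loc, 1.0), obs=batch)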
10 changes: 10 additions & 0 deletions test/test_distributions.py
@@ -2300,6 +2300,16 @@ def test_composed_transform_1(batch_shape):
    assert_allclose(log_det, expected_log_det)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
def test_simplex_to_order_transform(batch_shape):
    simplex = jnp.arange(5.0) / jnp.arange(5.0).sum()
    simplex = jnp.broadcast_to(simplex, batch_shape + simplex.shape)
    transform = SimplexToOrderedTransform()
    out = transform(simplex)
    assert out.shape == transform.forward_shape(simplex.shape)
    assert simplex.shape == transform.inverse_shape(out.shape)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
@pytest.mark.parametrize("prepend_event_shape", [(), (4,)])
@pytest.mark.parametrize("sample_shape", [(), (7,)])