Skip to content

Commit

Permalink
raise error when reparameterize lognormal (#1548)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Mar 9, 2023
1 parent 28d7f63 commit 9a43b19
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ def __init__(self, centered=None, shape_params=()):

def __call__(self, name, fn, obs):
assert obs is None, "LocScaleReparam does not support observe statements"
support = fn.support
if isinstance(support, constraints.independent):
support = fn.support.base_constraint
if support is not constraints.real:
raise ValueError(
"LocScaleReparam only supports distributions with real "
f"support, but got {support} support at site {name}."
)

centered = self.centered
if is_identically_one(centered):
return fn, obs
Expand Down
9 changes: 9 additions & 0 deletions test/infer/test_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,12 @@ def model():
)
log_det_site = handlers.trace(substituted_model).get_trace()["_x_log_det"]
assert log_det_site["scale"] == 5.0


def test_loc_scale_reparam_raise_for_log_normal():
def model():
numpyro.sample("x", dist.LogNormal(0, 1))

reparam_model = handlers.reparam(model, config={"x": LocScaleReparam(0)})
with pytest.raises(ValueError, match="LocScaleReparam.*"):
handlers.seed(reparam_model, rng_seed=0)()

0 comments on commit 9a43b19

Please sign in to comment.