Skip to content

Commit

Permalink
Include deterministic variables in AutoDelta's sample_posterior (#1584)
Browse files Browse the repository at this point in the history
Co-authored-by: Niklas M <michel.niklas@gmail.com>
  • Loading branch information
nikmich1 and Niklas M committed May 9, 2023
1 parent a2c28c1 commit c6fb104
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
18 changes: 16 additions & 2 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
periodic_repeat,
sum_rightmost,
)
from numpyro.infer import Predictive
from numpyro.infer.elbo import Trace_ELBO
from numpyro.infer.initialization import init_to_median, init_to_uniform
from numpyro.infer.util import helpful_support_errors, initialize_model
Expand Down Expand Up @@ -455,12 +456,25 @@ def __call__(self, *args, **kwargs):

return result

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
latent_samples = {
k: jnp.broadcast_to(v, sample_shape + jnp.shape(v)) for k, v in locs.items()
}
return latent_samples
deterministic_vars = [
k for k, v in self.prototype_trace.items() if v["type"] == "deterministic"
]
if not deterministic_vars:
return latent_samples
else:
predictive = Predictive(
model=self.model,
posterior_samples=latent_samples,
return_sites=deterministic_vars,
batch_ndims=len(sample_shape),
)
deterministic_samples = predictive(rng_key, *args, **kwargs)
return {**latent_samples, **deterministic_samples}

def median(self, params):
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
Expand Down
44 changes: 38 additions & 6 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ def test_logistic_regression(auto_class, Elbo):
logits = jnp.sum(true_coefs * data, axis=-1)
labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

def model(data, labels):
def model(data=None, labels=None):
coefs = numpyro.sample("coefs", dist.Normal(0, 1).expand([dim]).to_event())
logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
with numpyro.plate("N", len(data)):
return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
if data is not None:
logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
with numpyro.plate("N", len(data)):
return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

adam = optim.Adam(0.01)
rng_key_init = random.PRNGKey(1)
Expand Down Expand Up @@ -507,12 +508,12 @@ def create_plates(batch, subsample, full_size):
)
def test_autoguide_deterministic(auto_class):
def model(y=None):
n = y.size if y is not None else 1
n, len_y = (y.size, len(y)) if y is not None else (1, 1)

mu = numpyro.sample("mu", dist.Normal(0, 5))
sigma = numpyro.param("sigma", 1, constraint=constraints.positive)

with numpyro.plate("N", len(y)):
with numpyro.plate("N", len_y):
y = numpyro.sample("y", dist.Normal(mu, sigma).expand((n,)), obs=y)
numpyro.deterministic("z", (y - mu) / sigma)

Expand Down Expand Up @@ -931,3 +932,34 @@ def create_plates():
mf_elbo = -mf_elbo.item()

assert dais_elbo > mf_elbo + 0.1


def test_autodelta_capture_deterministic_variables():
def model():
x = numpyro.sample("x", dist.Normal())
numpyro.deterministic("x2", x**2)

guide = AutoDelta(model)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), num_steps=1_000)
guide_samples = guide.sample_posterior(
rng_key=random.PRNGKey(1), params=svi_result.params
)
assert "x2" in guide_samples


@pytest.mark.parametrize("shape", [(), (1,), (2, 3)])
@pytest.mark.parametrize("sample_shape", [(), (1,), (2, 3)])
def test_autodelta_sample_posterior_with_sample_shape(shape, sample_shape):
def model():
x = numpyro.sample("x", dist.Normal().expand(shape))
numpyro.deterministic("x2", x**2)

guide = AutoDelta(model)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), num_steps=1_000)
guide_samples = guide.sample_posterior(
rng_key=random.PRNGKey(1), params=svi_result.params, sample_shape=sample_shape
)
assert guide_samples["x"].shape == sample_shape + shape
assert guide_samples["x2"].shape == sample_shape + shape

0 comments on commit c6fb104

Please sign in to comment.