Skip to content

Commit

Permalink
Bug/steinvi reinit (#1626)
Browse files Browse the repository at this point in the history
* added separate guide for reinitialization.

* added test case for reinit.
  • Loading branch information
OlaRonning committed Aug 15, 2023
1 parent 62cce3d commit 4e37df3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
3 changes: 2 additions & 1 deletion numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
self._inference_model = model
self.model = model
self.guide = guide
self._init_guide = deepcopy(guide)
self.optim = optim
self.stein_loss = SteinLoss( # TODO: @OlaRonning handle enum
elbo_num_particles=num_elbo_particles,
Expand Down Expand Up @@ -388,7 +389,7 @@ def init(self, rng_key: KeyArray, *args, **kwargs):
)

guide_init_params = self._find_init_params(
particle_seed, self.guide, args, kwargs
particle_seed, self._init_guide, args, kwargs
)

guide_init = handlers.seed(self.guide, guide_seed)
Expand Down
23 changes: 22 additions & 1 deletion test/contrib/einstein/test_steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numpy.testing import assert_allclose
import pytest

from jax import random
from jax import numpy as jnp, random

import numpyro
from numpyro import handlers
Expand Down Expand Up @@ -193,6 +193,27 @@ def model():
return


def test_stein_reinit():
num_particles = 4

def model():
numpyro.sample("x", Normal(0, 1.0))

stein = SteinVI(
model,
AutoDelta(model),
Adam(1.0),
RBFKernel(),
num_stein_particles=num_particles,
)

for i in range(2):
with handlers.seed(rng_seed=i):
params = stein.get_params(stein.init(numpyro.prng_key()))
xs = params["x_auto_loc"]
assert jnp.unique(xs).shape == xs.shape


@pytest.mark.parametrize(
"auto_class",
[
Expand Down

0 comments on commit 4e37df3

Please sign in to comment.