Skip to content

Commit

Permalink
Support model without global variables in AutoSemiDAIS (#1610)
Browse files Browse the repository at this point in the history
* support model without global variables in AutoSemiDAIS

* add test for autosemidais local only

* black

* fix typo in docs
  • Loading branch information
fehiepsi committed Jun 19, 2023
1 parent 7291cba commit 0e50bac
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 24 deletions.
60 changes: 36 additions & 24 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,8 +1119,8 @@ def local_model(theta):
numpyro.sample("obs", dist.Normal(0.0, tau), obs=jnp.ones(2))
model = lambda: local_model(global_model())
base_guide = AutoNormal(global_model)
guide = AutoSemiDAIS(model, local_model, base_guide, K=4)
global_guide = AutoNormal(global_model)
guide = AutoSemiDAIS(model, local_model, global_guide, K=4)
svi = SVI(model, guide, ...)
# sample posterior for particular data subset {3, 7}
Expand All @@ -1131,8 +1131,9 @@ def local_model(theta):
:param callable local_model: The portion of `model` that includes the local latent variables only.
The signature of `local_model` should be the return type of the global model with global latent
variables only.
:param callable base_guide: A guide for the global latent variables, e.g. an autoguide.
:param callable global_guide: A guide for the global latent variables, e.g. an autoguide.
The return type should be a dictionary of latent sample sites names and corresponding samples.
If there is no global variable in the model, we can set this to None.
:param str prefix: A prefix that will be prefixed to all internal sites.
:param int K: A positive integer that controls the number of HMC steps used.
Defaults to 4.
Expand All @@ -1150,7 +1151,7 @@ def __init__(
self,
model,
local_model,
base_guide,
global_guide,
*,
prefix="auto",
K=4,
Expand All @@ -1175,7 +1176,7 @@ def __init__(
raise ValueError("init_scale must be positive.")

self.local_model = local_model
self.base_guide = base_guide
self.global_guide = global_guide
self.eta_init = eta_init
self.eta_max = eta_max
self.gamma_init = gamma_init
Expand Down Expand Up @@ -1237,25 +1238,30 @@ def _setup_prototype(self, *args, **kwargs):
self._local_latent_dim = jnp.size(local_init_latent) // plate_subsample_size
self._local_plate = (plate_name, plate_full_size, plate_subsample_size)

rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key):
global_output = self.base_guide.model(*args, **kwargs)
if self.global_guide is not None:
with handlers.block(), handlers.seed(rng_seed=0):
local_args = (self.global_guide.model(*args, **kwargs),)
local_kwargs = {}
else:
local_args = args
local_kwargs = kwargs.copy()

with handlers.block():
local_kwargs["_subsample_idx"] = {
plate_name: subsample_plates[plate_name]["value"]
}
(
_,
self._local_potential_fn_gen,
self._local_postprecess_fn,
_,
) = initialize_model(
numpyro.prng_key(),
random.PRNGKey(0),
partial(_subsample_model, self.local_model),
init_strategy=self.init_loc_fn,
dynamic_args=True,
model_args=(global_output,),
model_kwargs={
"_subsample_idx": {
plate_name: subsample_plates[plate_name]["value"]
}
},
model_args=local_args,
model_kwargs=local_kwargs,
)

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -1309,12 +1315,19 @@ def fn(x):

return fn

global_latents = self.base_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
data=global_latents
):
global_output = self.base_guide.model(*args, **kwargs)
if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute(
data=global_latents
):
global_outputs = self.global_guide.model(*args, **kwargs)
local_args = (global_ouputs,)
local_kwargs = {}
else:
global_latents = {}
local_args = args
local_kwargs = kwargs.copy()

plate_name, N, subsample_size = self._local_plate
D, K = self._local_latent_dim, self.K
Expand Down Expand Up @@ -1383,9 +1396,8 @@ def base_z_dist_log_prob(x):
infer={"is_auxiliary": True},
)

local_log_density = make_local_log_density(
global_output, _subsample_idx={plate_name: idx}
)
local_kwargs["_subsample_idx"] = {plate_name: idx}
local_log_density = make_local_log_density(*local_args, **local_kwargs)

def scan_body(carry, eps_beta):
eps, beta = eps_beta
Expand Down
17 changes: 17 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,23 @@ def model():
assert samples["sigma"].shape == (5,) and samples["log_sigma"].shape == (5, 2)


def test_autosemidais_local_only():
data = jnp.linspace(0, 1, 10)

def model():
with numpyro.plate("N", 10, subsample_size=5, dim=-1):
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("x", dist.Normal(batch, 1))

guide = AutoSemiDAIS(model, model, None)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10)
samples = guide.sample_posterior(
random.PRNGKey(1), svi_result.params, sample_shape=(100,)
)
assert samples["x"].shape == (100, 5)


def test_autosemidais_inadmissible_smoke():
def global_model():
return numpyro.sample("theta", dist.Normal(0, 1))
Expand Down

0 comments on commit 0e50bac

Please sign in to comment.