Support for VAE in AutoSemiDAIS (#1619)
* support model without global variables in AutoSemiDAIS

* add test for autosemidais local only

* black

* fix typo in docs

* support for vae in semidais

* fix bug using wrong sign of potential energy

* no need to store prototype local model trace

* add docs for local_guide in semidais

* allow params in local model

* fix wrong scale at z0

* add comment for why we divide by subsample_size at z_0 log prob

* address comment
fehiepsi committed Jul 10, 2023
1 parent 9618266 commit 428dee9
Showing 3 changed files with 119 additions and 50 deletions.
161 changes: 112 additions & 49 deletions numpyro/infer/autoguide.py
@@ -47,7 +47,11 @@
from numpyro.infer import Predictive
from numpyro.infer.elbo import Trace_ELBO
from numpyro.infer.initialization import init_to_median, init_to_uniform
from numpyro.infer.util import helpful_support_errors, initialize_model
from numpyro.infer.util import (
helpful_support_errors,
initialize_model,
potential_energy,
)
from numpyro.nn.auto_reg_nn import AutoregressiveNN
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN
from numpyro.util import not_jax_tracer
@@ -1134,6 +1138,8 @@ def local_model(theta):
:param callable global_guide: A guide for the global latent variables, e.g. an autoguide.
The return type should be a dictionary of latent sample site names and corresponding samples.
If there is no global variable in the model, we can set this to None.
:param callable local_guide: An optional guide for specifying the DAIS base distribution for
local latent variables.
:param str prefix: A prefix that will be prepended to all internal sites.
:param int K: A positive integer that controls the number of HMC steps used.
Defaults to 4.
@@ -1152,6 +1158,7 @@ def __init__(
model,
local_model,
global_guide,
local_guide=None,
*,
prefix="auto",
K=4,
@@ -1177,6 +1184,7 @@ def __init__(

self.local_model = local_model
self.global_guide = global_guide
self.local_guide = local_guide
self.eta_init = eta_init
self.eta_max = eta_max
self.gamma_init = gamma_init
@@ -1186,6 +1194,7 @@ def _setup_prototype(self, *args, **kwargs):
def _setup_prototype(self, *args, **kwargs):
super()._setup_prototype(*args, **kwargs)
# extract global/local/local_dim/plates
assert self.prototype_trace is not None
subsample_plates = {
name: site
for name, site in self.prototype_trace.items()
@@ -1225,9 +1234,10 @@ def _setup_prototype(self, *args, **kwargs):
for k, v in local_init_locs.items()
}
_, shape_dict = _ravel_dict(one_sample)
local_init_latent = jax.vmap(
self._pack_local_latent = jax.vmap(
lambda x: _ravel_dict(x)[0], in_axes=(subsample_axes,)
)(local_init_locs)
)
local_init_latent = self._pack_local_latent(local_init_locs)
unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
# this is to match the behavior of Pyro, where we can apply
# unpack_latent for a batch of samples
@@ -1246,23 +1256,14 @@ def _setup_prototype(self, *args, **kwargs):
local_args = args
local_kwargs = kwargs.copy()

with handlers.block():
local_kwargs["_subsample_idx"] = {
plate_name: subsample_plates[plate_name]["value"]
}
(
_,
self._local_potential_fn_gen,
self._local_postprecess_fn,
_,
) = initialize_model(
random.PRNGKey(0),
partial(_subsample_model, self.local_model),
init_strategy=self.init_loc_fn,
dynamic_args=True,
model_args=local_args,
model_kwargs=local_kwargs,
)
if self.local_guide is not None:
with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
self.local_guide(*local_args, **local_kwargs)
self.prototype_local_guide_trace = tr

with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
self.local_model(*local_args, **local_kwargs)
self.prototype_local_model_trace = tr

def __call__(self, *args, **kwargs):
if self.prototype_trace is None:
@@ -1305,16 +1306,6 @@ def _get_posterior(self):
def _sample_latent(self, *args, **kwargs):
kwargs.pop("sample_shape", ())

def make_local_log_density(*local_args, **local_kwargs):
def fn(x):
x_unpack = self._unpack_local_latent(x)
with numpyro.handlers.block():
return -self._local_potential_fn_gen(*local_args, **local_kwargs)(
x_unpack
)

return fn

if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
@@ -1329,6 +1320,34 @@ def fn(x):
local_args = args
local_kwargs = kwargs.copy()

local_guide_params = {}
if self.local_guide is not None:
for name, site in self.prototype_local_guide_trace.items():
if site["type"] == "param":
local_guide_params[name] = numpyro.param(
name, site["value"], **site["kwargs"]
)

local_model_params = {}
for name, site in self.prototype_local_model_trace.items():
if site["type"] == "param":
local_model_params[name] = numpyro.param(
name, site["value"], **site["kwargs"]
)

def make_local_log_density(*local_args, **local_kwargs):
def fn(x):
x_unpack = self._unpack_local_latent(x)
with numpyro.handlers.block():
return -potential_energy(
partial(_subsample_model, self.local_model),
local_args,
local_kwargs,
{**x_unpack, **local_model_params},
)

return fn

plate_name, N, subsample_size = self._local_plate
D, K = self._local_latent_dim, self.K

@@ -1366,25 +1385,70 @@ def fn(x):
)
inv_mass_matrix = 0.5 / mass_matrix
assert inv_mass_matrix.shape == (subsample_size, D)
z_0_loc_init = jnp.zeros((N, D))
z_0_loc = numpyro.param(
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1
)
z_0_scale_init = jnp.ones((N, D)) * self.init_scale
z_0_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
z_0_scale_init,
constraint=constraints.positive,
event_dim=1,
)
base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1)
assert base_z_dist.shape() == (subsample_size, D)
z_0 = numpyro.sample(
"{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True}
)

def base_z_dist_log_prob(x):
return base_z_dist.log_prob(x).sum()
local_kwargs["_subsample_idx"] = {plate_name: idx}
if self.local_guide is not None:
key = numpyro.prng_key()
subsample_guide = partial(_subsample_model, self.local_guide)
with handlers.block(), handlers.trace() as tr, handlers.seed(
rng_seed=key
), handlers.substitute(data=local_guide_params):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
subsample_guide(*local_args, **local_kwargs)
latent = {
name: biject_to(site["fn"].support).inv(site["value"])
for name, site in tr.items()
if site["type"] == "sample"
and not site.get("is_observed", False)
}
z_0 = self._pack_local_latent(latent)

def base_z_dist_log_prob(z):
latent = self._unpack_local_latent(z)
assert isinstance(latent, dict)
with handlers.block():
with warnings.catch_warnings():
warnings.simplefilter("ignore")
scale = N / subsample_size
return (
-potential_energy(
subsample_guide,
local_args,
local_kwargs,
{**local_guide_params, **latent},
)
/ scale
)

# The log_prob of z_0 will be broadcasted to `subsample_size` because this statement
# is run under the subsample plate. Hence we divide the log_prob by `subsample_size`.
numpyro.factor(
"{}_z_0_factor".format(self.prefix),
base_z_dist_log_prob(z_0) / subsample_size,
)
else:
z_0_loc_init = jnp.zeros((N, D))
z_0_loc = numpyro.param(
"{}_z_0_loc".format(self.prefix), z_0_loc_init, event_dim=1
)
z_0_scale_init = jnp.ones((N, D)) * self.init_scale
z_0_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
z_0_scale_init,
constraint=constraints.positive,
event_dim=1,
)
base_z_dist = dist.Normal(z_0_loc, z_0_scale).to_event(1)
assert base_z_dist.shape() == (subsample_size, D)
z_0 = numpyro.sample(
"{}_z_0".format(self.prefix),
base_z_dist,
infer={"is_auxiliary": True},
)

def base_z_dist_log_prob(x):
return base_z_dist.log_prob(x).sum()

momentum_dist = dist.Normal(0, mass_matrix).to_event(1)
eps = numpyro.sample(
@@ -1396,7 +1460,6 @@ def base_z_dist_log_prob(x):
infer={"is_auxiliary": True},
)

local_kwargs["_subsample_idx"] = {plate_name: idx}
local_log_density = make_local_log_density(*local_args, **local_kwargs)

def scan_body(carry, eps_beta):
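One of the commit messages above is "fix bug using wrong sign of potential energy": the new make_local_log_density negates potential_energy, because NumPyro's potential_energy returns the negative unnormalized log joint of the unconstrained parameters. A minimal sketch of that sign convention, using a hypothetical toy model that is not part of this commit:

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer.util import potential_energy

    def toy_model():
        # a single unconstrained latent, so there is no change-of-variables term
        numpyro.sample("z", dist.Normal(0.0, 1.0))

    pe = potential_energy(toy_model, (), {}, {"z": jnp.array(0.5)})
    # potential_energy is the negative (unnormalized) log joint, so the local
    # log density used by the guide is its negation:
    log_density = -pe  # equals dist.Normal(0.0, 1.0).log_prob(0.5)

The same sign appears in base_z_dist_log_prob above, where the guide's potential energy is negated and rescaled by N / subsample_size before being added via numpyro.factor.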
2 changes: 2 additions & 0 deletions numpyro/infer/util.py
@@ -236,6 +236,8 @@ def unconstrain_fn(model, model_args, model_kwargs, params):
def _unconstrain_reparam(params, site):
name = site["name"]
if name in params:
if site["type"] != "sample":
return params[name]
p = params[name]
support = site["fn"].support
with helpful_support_errors(site):
Expand Down
6 changes: 5 additions & 1 deletion test/infer/test_autoguide.py
@@ -837,7 +837,11 @@ def model():
batch = numpyro.subsample(data, event_dim=0)
numpyro.sample("x", dist.Normal(batch, 1))

guide = AutoSemiDAIS(model, model, None)
def create_plates():
return numpyro.plate("N", 10, subsample_size=5, dim=-1)

local_guide = AutoNormal(model, create_plates=create_plates)
guide = AutoSemiDAIS(model, model, None, local_guide=local_guide)
svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10)
samples = guide.sample_posterior(
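Putting the pieces together, a minimal end-to-end sketch of the new local_guide argument, assembled from the test diff above. The data array and the model's plate body are illustrative assumptions rather than the exact test fixture:

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro import optim
    from numpyro.infer import SVI, Trace_ELBO
    from numpyro.infer.autoguide import AutoNormal, AutoSemiDAIS

    data = jnp.linspace(0.0, 1.0, 10)

    def model():
        # local-only model: every latent lives inside the subsample plate
        with numpyro.plate("N", 10, subsample_size=5, dim=-1):
            batch = numpyro.subsample(data, event_dim=0)
            numpyro.sample("x", dist.Normal(batch, 1.0))

    def create_plates():
        return numpyro.plate("N", 10, subsample_size=5, dim=-1)

    # The local guide supplies the DAIS base distribution for the local latents.
    local_guide = AutoNormal(model, create_plates=create_plates)
    # No global latents, hence global_guide=None; the model doubles as local_model.
    guide = AutoSemiDAIS(model, model, None, local_guide=local_guide)

    svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 10)

The test then draws posterior samples with guide.sample_posterior; its arguments are truncated in the diff above.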
