Error when using obs_mask and predictive with different input shape #1847

Closed
felipeangelimvieira opened this issue Aug 9, 2024 · 2 comments
Labels
question Further information is requested

Comments

@felipeangelimvieira

felipeangelimvieira commented Aug 9, 2024

First of all, thank you for this amazing library.

I've found that Predictive raises an unexpected error when the model uses obs_mask. It happens when the data passed during SVI inference has one shape and the data passed to Predictive has another. This may be related to #1772.

Here is code to reproduce it:

import numpyro 
import numpy as np
import jax
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_mean
from numpyro.infer.svi import SVIRunResult

def model(y, x, obs_mask):
    a = numpyro.sample('a', numpyro.distributions.Normal(0, 1))
    b = numpyro.sample('b', numpyro.distributions.Normal(0, 1))
    sigma = numpyro.sample('sigma', numpyro.distributions.HalfNormal(1))

    mu = a + b * x
    # obs_mask marks which entries of y are treated as observed
    numpyro.sample('y', numpyro.distributions.Normal(mu, sigma), obs=y, obs_mask=obs_mask)

x = np.random.normal(0, 1, 100)
y = 1 + 2 * x + np.random.normal(0, 1, 100)
obs_mask = np.ones_like(y, dtype=bool)
obs_mask[-20:] = False

guide_ = AutoDelta(model, init_loc_fn=init_to_mean())
svi_ = SVI(model, guide_, numpyro.optim.Adam(step_size=1e-4), loss=Trace_ELBO())
run_results_: SVIRunResult = svi_.run(
    rng_key=jax.random.PRNGKey(24), num_steps=1000, y=y, x=x, obs_mask=obs_mask
)


posterior_samples_ = guide_.sample_posterior(
    jax.random.PRNGKey(24), params=run_results_.params, y=y, x=x, obs_mask=obs_mask
)


predictive = numpyro.infer.Predictive(
    model,
    params=run_results_.params,
    guide=guide_,
    num_samples=1000,
)

# Use only the last 50 points, a different shape from the 100 points used in SVI
start_idx = 50
predictive_samples = predictive(
    rng_key=jax.random.PRNGKey(24),
    y=y[-start_idx:],
    x=x[-start_idx:],
    obs_mask=obs_mask[-start_idx:],
)

fehiepsi added the question label Aug 9, 2024
@fehiepsi
Member

fehiepsi commented Aug 9, 2024

This is expected. obs_mask introduces a local latent variable named foo_unobserved (for a sample site named foo; here, y_unobserved), whose distribution is inferred by SVI. Suppose you have a model $x_n \to z_n \to y_n$ and you use an autoguide to approximate $p(z_n | x_n, y_n)$. That approximation is tied to the data points seen during inference, so it does not let you make a prediction $p(z'_n | x'_n)$ for new inputs. Instead, you might want to construct a custom guide for $q(z | x)$.

@felipeangelimvieira
Author

Oh I see, thank you for the explanation! I think I could use the mask handler directly.
