
numpyro.deterministic static on infer.Predictive #1772

Closed · AkiroSR opened this issue Apr 3, 2024 · 13 comments · Fixed by #1789
Labels: question (Further information is requested)

Comments

AkiroSR commented Apr 3, 2024

For some reason, after fitting the model, the shape of a numpyro.deterministic site stays fixed at its fitted size; trying to predict with inputs of a different shape then throws a shape error.

Example in lightweight-mmm:

# extra_features.shape = (10, 3): trying to predict 10 new time periods

extra_features_effect = numpyro.deterministic(
    name="extra_features_effect",
    value=jnp.einsum(
        extra_features_einsum, extra_features, coef_extra_features
    ),
)

# extra_features_effect.shape = (30, 3): the output is stuck at the shape used during fitting (30 periods)

This throws a size error, see:
google/lightweight_mmm#309
and
google/lightweight_mmm#308

fehiepsi (Member) commented Apr 3, 2024

Sorry for the breakage! Could you try the dev branch of lightweight_mmm? I will ping a dev there for a release if it works.

AkiroSR (Author) commented Apr 3, 2024

I think it's related to numpyro; the problem function is numpyro.deterministic.
Everything else works.
I'll have a look, but I reckon it's related to the Meridian release.

fehiepsi (Member) commented Apr 3, 2024

Do you mean that pip install --upgrade git+https://github.com/google/lightweight_mmm.git does not resolve the issue?

fehiepsi added the "question" (Further information is requested) label on Apr 13, 2024
nikisix commented Apr 26, 2024

@fehiepsi I saw your fix on lightweight_mmm (Change-Id: I7c0658b0a13506c319fd3e6e00cdf2791d64e26f).

I believe the long-term fix here is two-fold:

  1. Return deterministic sites in posterior_samples (MCMC saves deterministic sites in its samples, which are accessed via mcmc.get_samples()).
  2. Have Predictive always pop deterministic sites.

If these are infeasible for deeper reasons, then at least mention the pop trick here: https://num.pyro.ai/en/v0.2.0/utilities.html, as the current behavior is a bit counterintuitive.
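For reference, a minimal sketch of the pop workaround (assuming a fitted model and mcmc object like the ones in the reproducible example below; the deterministic site names are discovered from a model trace rather than hard-coded):

from jax import random
from numpyro import handlers
from numpyro.infer import Predictive

# Drop deterministic sites from the posterior samples so that Predictive
# recomputes them for the new inputs instead of substituting stale values.
samples = dict(mcmc.get_samples())
site_trace = handlers.trace(handlers.seed(model, random.PRNGKey(0))).get_trace(X=X, y=y)
for name, site in site_trace.items():
    if site["type"] == "deterministic":
        samples.pop(name, None)
predictive = Predictive(model, posterior_samples=samples)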

kylejcaron (Contributor) commented Apr 29, 2024

I'm running into the same issue; here's a reproducible example:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax import random


X = np.random.normal(0, 1, size=1000)
y = 5 + 1.2*X + np.random.normal(size=1000)

def model(X,y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0,10))
    beta = numpyro.sample("beta", dist.Normal(0,1))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    with numpyro.plate("data", len(X)):
        eta = numpyro.deterministic("eta", alpha + beta*X)
        obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)
   
# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), X=X, y=y)

# Make predictions where X has a different shape
posterior_samples = mcmc.get_samples()
# posterior_samples.pop("eta")  # this fixes the issue
pred_func = Predictive(model, posterior_samples=posterior_samples)
preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)  # raises the ValueError below

Traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/util.py:290, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    289 else:
--> 290   return cached(config.trace_context(), *args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/util.py:283, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    281 @functools.lru_cache(max_size)
    282 def cached(_, *args, **kwargs):
--> 283   return f(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:155, in _broadcast_shapes_cached(*shapes)
    153 @cache()
    154 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 155   return _broadcast_shapes_uncached(*shapes)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[1], line 26
     24 # Make predictions where X is a different shape
     25 pred_func = Predictive(model, posterior_samples=mcmc.get_samples())
---> 26 preds = pred_func(random.PRNGKey(1), X=X[:200], y=None)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:1011, in Predictive.__call__(self, rng_key, *args, **kwargs)
   1001 """
   1002 Returns dict of samples from the predictive distribution. By default, only sample sites not
   1003 contained in `posterior_samples` are returned. This can be modified by changing the
   (...)
   1008 :param kwargs: model kwargs.
   1009 """
   1010 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1011     return self._call_with_params(rng_key, self.params, args, kwargs)
   1012 elif self.batch_ndims == 1:  # batch over parameters
   1013     batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:988, in Predictive._call_with_params(self, rng_key, params, args, kwargs)
    977     posterior_samples = _predictive(
    978         guide_rng_key,
    979         guide,
   (...)
    985         model_kwargs=kwargs,
    986     )
    987 model = substitute(self.model, self.params)
--> 988 return _predictive(
    989     rng_key,
    990     model,
    991     posterior_samples,
    992     self._batch_shape,
    993     return_sites=self.return_sites,
    994     infer_discrete=self.infer_discrete,
    995     parallel=self.parallel,
    996     model_args=args,
    997     model_kwargs=kwargs,
    998 )

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:825, in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
    823 rng_key = rng_key.reshape(batch_shape + key_shape)
    824 chunk_size = num_samples if parallel else 1
--> 825 return soft_vmap(
    826     single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    827 )

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/util.py:419, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    413     xs = tree_map(
    414         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
    415         xs,
    416     )
    417     fn = vmap(fn)
--> 419 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    420 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    421 ys = tree_map(
    422     lambda y: jnp.reshape(
    423         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
    424     )[:batch_size],
    425     ys,
    426 )

    [... skipping hidden 12 frame]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/infer/util.py:798, in _predictive.<locals>.single_prediction(val)
    789     pred_samples = _sample_posterior(
    790         config_enumerate(condition(model, samples)),
    791         first_available_dim,
   (...)
    795         **model_kwargs,
    796     )
    797 else:
--> 798     model_trace = trace(
    799         seed(substitute(masked_model, samples), rng_key)
    800     ).get_trace(*model_args, **model_kwargs)
    801     pred_samples = {name: site["value"] for name, site in model_trace.items()}
    803 if return_sites is not None:

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (2 times)]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[1], line 17, in model(X, y)
     15 with numpyro.plate("data", len(X)):
     16     eta = numpyro.deterministic("eta", alpha + beta*X)
---> 17     obs = numpyro.sample("obs", dist.Normal(eta, sigma), obs=y)

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     "type": "sample",
    209     "name": name,
   (...)
    218     "infer": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg["value"]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the "stop" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get("stop"):

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/numpyro/primitives.py:546, in plate.process_message(self, msg)
    544 overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
    545 trailing_shape = expected_shape[overlap_idx:]
--> 546 broadcast_shape = lax.broadcast_shapes(
    547     trailing_shape, tuple(dist_batch_shape)
    548 )
    549 batch_shape = expected_shape[:overlap_idx] + broadcast_shape
    550 msg["fn"] = msg["fn"].expand(batch_shape)

    [... skipping hidden 1 frame]

File ~/Library/Caches/pypoetry/virtualenvs/ao-models-QNM_MrRk-py3.11/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    169 result_shape = _try_broadcast_shapes(shape_list)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(200,), (1000,)]

I get that passing in samples for a deterministic site would lead the model to expect a certain shape, but it does seem a bit awkward that the typical prediction workflow requires extra work whenever deterministic sites are involved.

I wonder if something like this is possible? https://github.com/pyro-ppl/numpyro/blob/2f1bccdba2fc7b0a6ec235ca1bd5ce2417a0635c/numpyro/infer/mcmc.py#L714C61-L714C62

fehiepsi (Member):

Hi @nikisix and @kylejcaron, really sorry for the breakage! I think a good course of action is to introduce exclude_deterministic=True to Predictive; this rolls the behavior back to that of the pre-0.14 releases. I'm less worried that new users will want to use deterministic sites in Predictive. What do you think, @martinjankowiak?
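For illustration, a hedged sketch of how the proposed flag might be used (hypothetical at the time of this comment; it reuses the model, mcmc, and X from the reproducible example above):

from jax import random
from numpyro.infer import Predictive

# With the proposed flag, deterministic sites stored in posterior_samples
# would be ignored and recomputed from the new inputs.
predictive = Predictive(model, posterior_samples=mcmc.get_samples(), exclude_deterministic=True)
preds = predictive(random.PRNGKey(1), X=X[:200], y=None)  # "eta" would be recomputed with shape (1000, 200)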

martinjankowiak (Collaborator):

something like that sounds reasonable. the change in behavior was probably a mistake...

kylejcaron (Contributor):

@fehiepsi @martinjankowiak should AutoGuide.sample_posterior() be changed as well? It seems more difficult to fix, since many of the sample_posterior implementations are specific to individual autoguides.

For example, the following workflow has the same problem:

from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

guide = AutoNormal(model)
svi = SVI(model, guide, optim=numpyro.optim.Adam(0.01), loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 10000, X=X, y=y)

params = guide.sample_posterior(random.PRNGKey(0), params=svi_result.params)
pred_func = Predictive(model, params=params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X=X[:250], y=None)  # same shape error

The solution for this seems to be to just include the guide and use the SVI params instead, but I imagine some people may be using the pattern above:

pred_func = Predictive(model, guide=guide, params=svi_result.params, num_samples=100)
preds = pred_func(random.PRNGKey(1), X=X[:n_preds])["eta"]

kylejcaron (Contributor):

I think this pattern could also be used with an exclude_deterministic arg in the AutoGuide classes.

fehiepsi (Member):

@kylejcaron I think we can fix this in Predictive. The breakage happens because we allow substituting deterministic sites in the substitute handler. We can create a subclass of substitute which skips processing deterministic sites and use it in Predictive.
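A minimal sketch of such a handler, following the suggestion above (the class name is made up and this is not the actual implementation):

from numpyro import handlers

class SubstituteSkipDeterministic(handlers.substitute):
    """Like substitute, but leaves deterministic sites untouched so the
    model recomputes them from the current inputs."""

    def process_message(self, msg):
        if msg["type"] == "deterministic":
            return  # skip: do not substitute a stored value for this site
        super().process_message(msg)

# usage sketch: SubstituteSkipDeterministic(model, data=samples)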

kylejcaron (Contributor) commented Apr 30, 2024

> @kylejcaron I think we can fix this in Predictive. The breakage happens because we allow substituting deterministic sites in the substitute handler. We can create a subclass of substitute which skips processing deterministic sites and use it in Predictive.

Got it, that makes sense to me. It seems like it'd involve just replacing the substitute call in this line and L987, but let me know if I'm missing anything.

I'm happy to make an attempt at this, any name recommendations for the new effect handler?

fehiepsi (Member):

The substitute logic is at this line. You can change

substitute(model, data)

to something like

substitute(model, substitute_fn=lambda msg: data[msg["name"]] if msg["name"] in data and msg["type"] != "deterministic" else None)
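For illustration, a hedged sketch of what this substitute_fn does on the regression example above (it assumes the model, mcmc, and X defined earlier; returning None from substitute_fn leaves a site untouched, so deterministic sites such as "eta" are recomputed for the new input length):

from jax import random
from numpyro import handlers

# one posterior draw; assumes the fitted mcmc object from the example above
data = {k: v[0] for k, v in mcmc.get_samples().items()}

def subst_fn(msg):
    # substitute sampled latent values, but never deterministic sites
    if msg["name"] in data and msg["type"] != "deterministic":
        return data[msg["name"]]
    return None  # None means: leave this site alone

tr = handlers.trace(
    handlers.seed(handlers.substitute(model, substitute_fn=subst_fn), random.PRNGKey(1))
).get_trace(X=X[:200], y=None)
print(tr["eta"]["value"].shape)  # (200,): recomputed from the new X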

kylejcaron (Contributor):

> The substitute logic is at this line. You can change
>
> substitute(model, data)
>
> to something like
>
> substitute(model, substitute_fn=lambda msg: data[msg["name"]] if msg["name"] in data and msg["type"] != "deterministic" else None)

nice idea with the substitute_fn, just added a PR!
