
Shared variable issues when using NumPyro JAX sampler #4142

Closed
mschmidt87 opened this issue Sep 28, 2020 · 11 comments · Fixed by #4646
Comments

@mschmidt87
Contributor

If you have questions about a specific use case, or you are not sure whether this is a bug or not, please post it to our discourse channel: https://discourse.pymc.io

Description of your problem

I am trying to use the new JAX-based sampler from the pymc3jax branch, presented in this notebook: https://gist.github.com/twiecki/f0a28dd06620aa86142931c1f10b5434
I can run the notebook as-is just fine, but if I register the model's data using the pm.Data constructor, I get a MissingInputError.

Essentially, I am replacing Cell 6 in the notebook with this code:

with pm.Model() as hierarchical_model:
    county_idx = pm.Data('county_idx', data.county_code.values.astype('int32'))
    floor = pm.Data('floor', data.floor.values)
    log_radon = pm.Data('log_radon', data.log_radon)
    
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Intercept for each county, distributed around group mean mu_a
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    radon_est = a[county_idx] + b[county_idx]*floor

    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est,
                           sigma=eps, observed=log_radon)

So I am registering the two input variables and the output variable as pm.Data objects and have replaced the corresponding references in the model code.

I can then run the standard sampler without problems, but the JAX sampler (Cell 10) fails.

Full traceback:

---------------------------------------------------------------------------
MissingInputError                         Traceback (most recent call last)
<timed exec> in <module>

/path/to/pymc3/pymc3/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, model, progress_bar)
    114     seed = jax.random.PRNGKey(random_seed)
    115 
--> 116     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
    117     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
    118     logp_fn_jax = fns[0]

/path/to/Theano-PyMC/theano/gof/fg.py in __init__(self, inputs, outputs, features, clone, update_mapping)
    174 
    175         for output in outputs:
--> 176             self.__import_r__(output, reason="init")
    177         for i, output in enumerate(outputs):
    178             output.clients.append(("output", i))

/path/to/Theano-PyMC/theano/gof/fg.py in __import_r__(self, variable, reason)
    347         # Imports the owners of the variables
    348         if variable.owner and variable.owner not in self.apply_nodes:
--> 349             self.__import__(variable.owner, reason=reason)
    350         elif (
    351             variable.owner is None

/path/to/Theano-PyMC/theano/gof/fg.py in __import__(self, apply_node, check, reason)
    399                             % (node.inputs.index(r), str(node))
    400                         )
--> 401                         raise MissingInputError(error_msg, variable=r)
    402 
    403         for node in new_nodes:

MissingInputError: Input 1 of the graph (indices start from 0), used to compute AdvancedSubtensor1(b, county_idx), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.
 
Backtrace when that variable is created:

  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2866, in run_cell
    result = self._run_cell(
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2895, in _run_cell
    return runner(coro)
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3071, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3263, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/path/to/conda/env//lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-17-2577b767217d>", line 2, in <module>
    county_idx = pm.Data('county_idx', data.county_code.values.astype('int32'))
  File "/path/to/pymc3/pymc3/data.py", line 516, in __new__
    shared_object = theano.shared(pm.model.pandas_to_array(value), name)

Versions and main components

  • PyMC3 Version: checkout of pymc3jax branch
  • Theano Version: checkout of Theano-Pymc master branch
  • Python Version: 3.8
  • Operating system: Mac OS
  • How did you install PyMC3: manual installation of the branch
@mschmidt87 mschmidt87 changed the title pymc3jax: MissingInputError pymc3jax: MissingInputError when using pm.Data inside a model in combination with JAX sampler Sep 28, 2020
@junpenglao
Member

junpenglao commented Sep 28, 2020

Thanks for reporting and testing this out! I think Theano is marking the shared variables as inputs - we will need to work on the log_prob generation function to identify these.

For now, you need to replace all theano.shared variables and pm.Data calls with numpy arrays in your model.
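A minimal sketch of that workaround (a toy DataFrame stands in for the radon data from the notebook; variable names follow Cell 6 above):

```python
import numpy as np
import pandas as pd

# Toy stand-in for the radon DataFrame used in the notebook.
data = pd.DataFrame({
    "county_code": [0, 0, 1, 1, 2],
    "floor": [0.0, 1.0, 0.0, 1.0, 0.0],
    "log_radon": [0.8, 0.4, 1.1, 0.2, 0.9],
})

# Instead of wrapping these in pm.Data (which creates theano.shared
# variables), pass plain numpy arrays into the model, e.g. use them
# directly in `a[county_idx] + b[county_idx] * floor`:
county_idx = data.county_code.values.astype("int32")
floor = data.floor.values
log_radon = data.log_radon.values

print(type(county_idx).__name__, county_idx.dtype)  # ndarray int32
```

Plain arrays become constants in the theano graph, so the FunctionGraph built by the JAX sampler sees no missing inputs.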

@brandonwillard
Contributor

This issue isn't due to lack of support for shared variables in theano.sandbox.jaxify.jax_funcify; it's because a shared variable is being passed directly to theano.gof.FunctionGraph in sample_numpyro_nuts.

If we want to create a FunctionGraph, we'll have to preprocess all the shared variables like pfunc does when theano.function is called. This might be as simple as calling rebuild_collect_shared and using those inputs and outputs, like pfunc does.

@brandonwillard
Contributor

brandonwillard commented Oct 4, 2020

As implied in my comment here, one way to add shared variable support to functions like sample_numpyro_nuts is to get the external JAX samplers into a Theano graph so that everything can be compiled using theano.function. This would provide complete shared variable support.

One way to do that is to write simple wrapper Ops that represent the external JAXable sampler functions. With those, one can write jax_funcify implementations that use said external functions and be done.

@brandonwillard
Contributor

brandonwillard commented Oct 4, 2020

Here's a template for such an Op and jax_funcify implementation:

import theano.tensor as tt

from theano.gof.op import Op
from theano.gof.graph import Apply
from theano.tensor.type import TensorType
from theano.sandbox.jaxify import jax_funcify


class NumPyroNUTS(Op):
    def __init__(self, draws=1000, tune=1000, chains=4):
        self.draws = draws
        self.tune = tune
        self.chains = chains
        super().__init__()

    def make_node(self, input_rvs):
        """Construct a node for the NumPyro NUTS sampler.

        Parameters
        ----------
        input_rvs : List[TensorVariable]
            The input variables, or `init_state`s, obtained from `model.free_RVs`, for example.
        """
        inputs = [tt.as_tensor(rv) for rv in input_rvs]

        # New, potentially broadcastable dimensions added by sampling
        broadcastable_sample_dims = [self.chains == 1, self.draws == 1]

        # These are the symbolic tensors/arrays that represent the posterior
        # samples for each variable.
        outputs = [TensorType(rv.dtype, broadcastable_sample_dims + list(rv.broadcastable)) for rv in input_rvs]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        # `inputs` is a list containing the numeric initial state values.

        # Simply put, when we're in here, we have concrete numeric values for
        # the `inputs` specified in `make_node`, so we can do any
        # pure Python work we want.  The only requirement is that we return
        # numeric values that correspond to the `outputs` specified in `make_node`
        # (i.e. they must have the same `dtype`s, number of dimensions, and broadcast
        # pattern).

        # This could compile[, cache,] and evaluate the JAX-jitted function just like
        # `sample_numpyro_nuts` does and return the numeric values (as a list of
        # samples, in the way `make_node` specifies them, of course).

        # `outputs` is a list containing a one-element list for each output
        # variable.  Each inner list must be populated in-place with the
        # numeric sample array for that variable, e.g.
        # `outputs[i][0] = samples_for_variable_i` (rebinding the name
        # `outputs` would discard the results).
        ...


@jax_funcify.register(NumPyroNUTS)
def jax_funcify_NumPyroNUTS(op):
    draws = op.draws
    tune = op.tune
    chains = op.chains

    def numpyronuts(init, draws=draws, tune=tune, chains=chains):
        # Just return the JAX-jittable sampler function constructed in
        # `sample_numpyro_nuts` (e.g. `_sample` or some variant thereof).
        return ...

    return numpyronuts

@twiecki
Member

twiecki commented Oct 5, 2020

@brandonwillard This is cool. I wonder if there are any other benefits to this approach rather than getting theano.shared to work?

@brandonwillard
Contributor

One other huge benefit: this approach works with symbolic shapes...

@kc611
Contributor

kc611 commented Feb 19, 2021

This looks interesting. One question I have here is what exactly the difference will be between the contents of the Op's perform and jax_funcify_NUTS. Is perform expected to contain a full Python implementation of the NUTS sampler (that would be quite bulky), or is it supposed to use the NUTS sampler from NumPyro (which would make its contents the same as those of jax_funcify_NUTS)?

@brandonwillard
Contributor

One question I have here is what exactly the difference will be between the contents of the Op's perform and jax_funcify_NUTS. Is perform expected to contain a full Python implementation of the NUTS sampler (that would be quite bulky), or is it supposed to use the NUTS sampler from NumPyro (which would make its contents the same as those of jax_funcify_NUTS)?

Assuming that we want to create such an Op simply to get the NumPyro JAX code into an Aesara/Theano Function (i.e. the result of compiling an Aesara/Theano graph using aesara.function), Op.perform could simply raise a NotImplementedError, because we would never want/need to evaluate graphs containing this Op using Python and/or C.

In other words, NumPyroNUTS could be a dummy Op that only serves to inject external JAX code into the compiled graph via its corresponding jax_funcify implementation. Doing this solves any problems related to shared variables, because shared variables are a construct that exists only within Aesara/Theano and its Function objects.
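The dispatch pattern behind that dummy-Op idea can be illustrated in plain Python. This sketch uses functools.singledispatch as a stand-in for jax_funcify's registration mechanism; NumPyroNUTSStub and to_jax are hypothetical names, not part of the actual codebase:

```python
from functools import singledispatch


class NumPyroNUTSStub:
    """Stand-in for the dummy Op: its Python evaluation path refuses to run."""

    def perform(self, node, inputs, outputs):
        raise NotImplementedError("This Op is only evaluated via the JAX backend.")


@singledispatch
def to_jax(op):
    # Default: no JAX translation registered for this Op type.
    raise NotImplementedError(type(op).__name__)


@to_jax.register(NumPyroNUTSStub)
def _(op):
    # At compile time, the registered translator supplies the external
    # JAX sampler function; a plain lambda stands in for it here.
    return lambda init_state: ("samples", init_state)


sampler = to_jax(NumPyroNUTSStub())
print(sampler(0))  # ('samples', 0)
```

The Op itself never needs a working Python/C implementation; it exists only so the compiled-graph machinery has something to dispatch on.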

@brandonwillard brandonwillard changed the title pymc3jax: MissingInputError when using pm.Data inside a model in combination with JAX sampler Shared variable issues using NumPyro JAX sampler Mar 4, 2021
@brandonwillard brandonwillard changed the title Shared variable issues using NumPyro JAX sampler Shared variable issues when using NumPyro JAX sampler Mar 4, 2021
@brandonwillard brandonwillard linked a pull request Apr 17, 2021 that will close this issue
@ricardoV94
Member

@twiecki or @brandonwillard can this one be closed?

@twiecki
Member

twiecki commented Jun 27, 2021

Is it fixed?

@twiecki
Member

twiecki commented Jun 27, 2021

Ah yes, seems like it. Neat.

@twiecki twiecki closed this as completed Jun 27, 2021