Missing JAX Implementation for Split Op #145

Closed
sp-leonardo-alcivar opened this issue Dec 20, 2022 · 4 comments · Fixed by #209

sp-leonardo-alcivar commented Dec 20, 2022

Description

Hi everyone, since the update to PyMC 5.0, code that used to run under PyMC 4.0 started giving me a lot of problems when using pm.sampling.jax.sample_numpyro_nuts. The model is as follows:

import numpy as np
import pymc as pm
import pytensor.tensor as at


# carryover effect: weighted sum of the last l_max lags of x, with geometric weights alpha**i
def func1(x, alpha: float = 0.0, l_max: int = 12):
    cycles = [
        at.concatenate(
            [at.zeros(i), x[: x.shape[0] - i]]
        )
        for i in range(l_max)
    ]
    x_cycle = at.stack(cycles)
    w = at.as_tensor_variable([at.power(alpha, i) for i in range(l_max)])
    return at.dot(w, x_cycle)

# saturation curve: (1 - exp(-lam * x)) / (1 + exp(-lam * x)), i.e. tanh(lam * x / 2)
def func2(x, lam: float = 0.5):
    return (1 - at.exp(-lam * x)) / (1 + at.exp(-lam * x))

with pm.Model() as asdr_model:
            asdr_model.add_coord('d', self.d, mutable=False)
            asdr_model.add_coord('ps', self.ps, mutable=False)
            asdr_model.add_coord('fourier_mode', np.arange(2 * self.n_order), mutable=False)

            if add_trend:
                t_ = pm.Data(name="t", value=self.t, mutable=True)
            if add_seasonality:
                fourier_features_ = pm.Data(name="fourier_features", value=self.fourier_features, mutable=True)

            # --- priors ---

            ## intercept (same prior with or without a trend)
            a = pm.Normal(name="a",
                          mu=priors_config['intercept']['mu'],
                          sigma=priors_config['intercept']['sigma'])

            ## trend
            if add_trend:
                b_trend = pm.Normal(name="b_trend",
                                    mu=priors_config['b_trend']['mu'],
                                    sigma=priors_config['b_trend']['sigma'])
            ## seasonality
            if add_seasonality:
                b_fourier = pm.Laplace(name="b_fourier",
                                       mu=priors_config['b_seasonality']['mu'],
                                       b=priors_config['b_seasonality']['b'],
                                       dims="fourier_mode")

            ## standard deviation of the likelihood
            sigma = pm.HalfNormal(name='sigma',
                                  sigma=priors_config['sigma_likelihood']['sigma'])
            # degrees of freedom of the likelihood
            nu = pm.Gamma(name='nu',
                          alpha=priors_config['nu_likelihood']['alpha'],
                          beta=priors_config['nu_likelihood']['beta'])

            # Trend and seasonality determination
            if add_trend:
                trend = pm.Deterministic(name='trend', var=a + b_trend * t_, dims='d')
            if add_seasonality:
                seasonality = pm.Deterministic(
                    name='seasonality', var=pm.math.dot(fourier_features_, b_fourier), dims='d'
                )

            # --- data containers ---

            m_effects = []

            for ii, p in enumerate(self.ps):

                media_scaled_ = pm.Data(name=f"m_s_{p}", value=self.m_vs_s[:, ii], dims='d', mutable=True)

                ## func1 effect
                alpha = pm.Beta(name=f"alpha_{p}",
                                alpha=priors_config['alpha_adstock']['alpha'],
                                beta=priors_config['alpha_adstock']['beta'])
                ## func2 effect
                lam = pm.Gamma(name=f"lam_{p}",
                               alpha=priors_config['lambda_saturation']['alpha'],
                               beta=priors_config['lambda_saturation']['beta'])
                ## gaussian random walks standard deviation
                sigma_slope = pm.HalfNormal(name=f'sigma_slope_{p}',
                                            sigma=priors_config['sigma_slope']['sigma'])

                # --- model parametrization ---

                # Random walks for the betas of the media values
                slopes = pm.GaussianRandomWalk(
                    name=f'slopes_{p}',
                    sigma=sigma_slope,
                    init_dist=pm.Normal.dist(
                        name=f'init_dist_{p}',
                        mu=priors_config['rw_initial_dist']['mu'],
                        sigma=priors_config['rw_initial_dist']['sigma']
                    ),
                    dims=['d'],
                )

                media_part1 = pm.Deterministic(
                    name=f"m_part1_{p}",
                    var=func1(x=media_scaled_, alpha=alpha, l_max=9),
                    dims=['d']
                )

                m_part1_part2 = pm.Deterministic(
                    name=f"m_part1_{p}",
                    var=func2(x=media_part1, lam=lam),
                    dims=['d'],
                )

                m_effect = pm.Deterministic(
                    name=f"m_part2_{p}",
                    var=pm.math.exp(slopes) * m_part1_part2,
                    dims=['d']
                )

                m_effects.append(m_effect)


            # Final model
            obj_var = sum(m_effects)

            if add_trend:
                obj_var += trend
            else:
                obj_var += a
            if add_seasonality:
                obj_var += seasonality

            mu = pm.Deterministic(name="mu", var=obj_var, dims="d")


            # --- likelihood ---
            pm.StudentT(name="likelihood", nu=nu, mu=mu, sigma=sigma, observed=self.y_scaled, dims="d")

The output is as follows. First, a lot of warnings like:

/Users/leonardoalcivar/opt/anaconda3/envs/mmm_v3/lib/python3.11/site-packages/pymc/logprob/joint_logprob.py:167: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: [normal_rv{0, (0, 0), floatX, False}.0, init_dist_p]

And the error:

Cell In[30], line 280, in BayesianModel.train(self, model, iterations, chains)
    277 self.execution_time = now
    279 with model:
--> 280     self.model_trace = pm.sampling.jax.sample_numpyro_nuts(
    281         draws=iterations,
    282         chains=chains,
    283     )
    284     self.model_posterior_predictive = pm.sample_posterior_predictive(
    285         trace=self.model_trace
    286     )
    288 self.posterior_predictive_likelihood = (
    289     self.model_posterior_predictive.posterior_predictive[
    290         "likelihood"
    291     ].stack(sample=("chain", "draw"))
    292 )

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/site-packages/pymc/sampling/jax.py:576, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    567 print("Compiling...", file=sys.stdout)
    569 init_params = _get_batched_jittered_initial_points(
    570     model=model,
    571     chains=chains,
    572     initvals=initvals,
    573     random_seed=random_seed,
    574 )
--> 576 logp_fn = get_jaxified_logp(model, negative_logp=False)
    578 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
    579 nuts_kernel = NUTS(
    580     potential_fn=logp_fn,
    581     target_accept_prob=target_accept,
    582     **nuts_kwargs,
    583 )

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/site-packages/pymc/sampling/jax.py:116, in get_jaxified_logp(model, negative_logp)
    114 if not negative_logp:
    115     model_logp = -model_logp
--> 116 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    118 def logp_fn_wrap(x):
    119     return logp_fn(*x)[0]

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/site-packages/pymc/sampling/jax.py:109, in get_jaxified_graph(inputs, outputs)
    106 mode.JAX.optimizer.rewrite(fgraph)
    108 # We now jaxify the optimized fgraph
--> 109 return jax_funcify(fgraph)

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:49, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     42 @jax_funcify.register(FunctionGraph)
     43 def jax_funcify_FunctionGraph(
     44     fgraph,
   (...)
     47     **kwargs,
     48 ):
---> 49     return fgraph_to_python(
     50         fgraph,
     51         jax_funcify,
     52         type_conversion_fn=jax_typify,
     53         fgraph_name=fgraph_name,
     54         **kwargs,
     55     )

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/site-packages/pytensor/link/utils.py:740, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    738 body_assigns = []
    739 for node in order:
--> 740     compiled_func = op_conversion_fn(
    741         node.op, node=node, storage_map=storage_map, **kwargs
    742     )
    744     # Create a local alias with a unique name
    745     local_compiled_func_name = unique_name(compiled_func)

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/opt/anaconda3/envs/mmm_v3/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:39, in jax_funcify(op, node, storage_map, **kwargs)
     36 @singledispatch
     37 def jax_funcify(op, node=None, storage_map=None, **kwargs):
     38     """Create a JAX compatible function from an PyTensor `Op`."""
---> 39     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: Split{2}
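
For reference, PyTensor turns each Op into JAX code by dispatching on jax_funcify, and the error above just means that no conversion is registered for Split. Below is a minimal sketch of what such a registration could look like; it is hypothetical (it assumes the axis and split sizes end up as concrete values at trace time, since jnp.split needs static cut points) and not necessarily how the eventual fix in #209 works:

```python
import jax.numpy as jnp
import numpy as np

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.basic import Split


@jax_funcify.register(Split)
def jax_funcify_Split(op, **kwargs):
    def split(x, axis, splits):
        # `splits` holds the size of each output piece, while jnp.split
        # expects cut points, so convert the sizes to cumulative offsets.
        # This only works when `axis` and `splits` are concrete
        # (non-traced) values, e.g. after constant folding.
        cut_points = np.cumsum(np.asarray(splits))[:-1]
        return jnp.split(x, cut_points, axis=int(axis))

    return split
```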
twiecki added the bug and jax labels Dec 20, 2022

twiecki (Member) commented Dec 20, 2022

Thanks for the report @sp-leonardo-alcivar. Your post is a bit hard to read; can you format it using "```python" tags?

sp-leonardo-alcivar (Author) commented

I think it should be better now :)

ricardoV94 removed the bug label Dec 21, 2022
ricardoV94 transferred this issue from pymc-devs/pymc Dec 21, 2022
Hoernchen commented

I am facing the same issue; downgrading to pymc==4.4.0 didn't help, but pymc==4.1.0 works.


Bodisatva commented Feb 4, 2023

Hi there,

  • 1st: I am very grateful for and appreciative of the very important work you are all doing in maintaining and updating this jewel that is PyMC, so a big thank you!!!

  • The issue:

Running a dummy model that uses GaussianRandomWalk with 5.0.2 works perfectly with the regular sampling method, but I hit the very same error during the compilation phase when trying pm.sampling.jax.sample_numpyro_nuts (a minimal sketch of such a dummy model is included after the full traceback below):

No JAX conversion for the given `Op`: Split{2}

  • Here are my system specs:
Python implementation: CPython
Python version       : 3.10.9
IPython version      : 8.9.0

numpy     : 1.23.5
arviz     : 0.14.0
pandas    : 1.5.3
pytensor  : 2.9.1
pymc      : 5.0.2

sys       : 3.10.9 (main, Dec 15 2022, 17:11:09) [Clang 14.0.0 (clang-1400.0.29.202)]

Watermark: 2.3.1
  • Full error:
---> 11 trace = pmsj.sample_numpyro_nuts(100, tune=2000, target_accept=0.9) 

File ~/.local/share/virtualenvs/python310-zlKQFs2r/lib/python3.10/site-packages/pymc/sampling/jax.py:614, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs, nuts_kwargs)
    605 print("Compiling...", file=sys.stdout)
    607 init_params = _get_batched_jittered_initial_points(
    608     model=model,
    609     chains=chains,
    610     initvals=initvals,
    611     random_seed=random_seed,
    612 )
--> 614 logp_fn = get_jaxified_logp(model, negative_logp=False)
    616 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
    617 nuts_kernel = NUTS(
    618     potential_fn=logp_fn,
    619     target_accept_prob=target_accept,
    620     **nuts_kwargs,
    621 )

File ~/.local/share/virtualenvs/python310-zlKQFs2r/lib/python3.10/site-packages/pymc/sampling/jax.py:118, in get_jaxified_logp(model, negative_logp)
    116 if not negative_logp:
    117     model_logp = -model_logp
--> 118 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    120 def logp_fn_wrap(x):
    121     return logp_fn(*x)[0]

File ~/.local/share/virtualenvs/python310-zlKQFs2r/lib/python3.10/site-packages/pymc/sampling/jax.py:111, in get_jaxified_graph(inputs, outputs)
    108 mode.JAX.optimizer.rewrite(fgraph)
    110 # We now jaxify the optimized fgraph
--> 111 return jax_funcify(fgraph)

File /opt/homebrew/Cellar/python@3.10/3.10.9/Frameworks/Python.framework/Versions/3.10/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/.local/share/virtualenvs/python310-zlKQFs2r/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:49, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     42 @jax_funcify.register(FunctionGraph)
     43 def jax_funcify_FunctionGraph(
     44     fgraph,
   (...)
     47     **kwargs,
     48 ):
---> 49     return fgraph_to_python(
     50         fgraph,
     51         jax_funcify,
     52         type_conversion_fn=jax_typify,
     53         fgraph_name=fgraph_name,
     54         **kwargs,
     55     )

File ~/.local/share/virtualenvs/python310-zlKQFs2r/lib/python3.10/site-packages/pytensor/link/utils.py:740, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    738 body_assigns = []
    739 for node in order:
--> 740     compiled_func = op_conversion_fn(
    741         node.op, node=node, storage_map=storage_map, **kwargs
    742     )
    744     # Create a local alias with a unique name
    745     local_compiled_func_name = unique_name(compiled_func)

File /opt/homebrew/Cellar/python@3.10/3.10.9/Frameworks/Python.framework/Versions/3.10/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/.local/share/virtualenvs/python310-zlKQFs2r/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:39, in jax_funcify(op, node, storage_map, **kwargs)
     36 @singledispatch
     37 def jax_funcify(op, node=None, storage_map=None, **kwargs):
     38     """Create a JAX compatible function from an PyTensor `Op`."""
---> 39     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: Split{2}
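
For completeness, here is a minimal sketch of the kind of dummy model I mean (names, sizes, and priors are arbitrary); pm.sample works fine on it, while the JAX path fails as above:

```python
import numpy as np
import pymc as pm
import pymc.sampling.jax as pmsj

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    # A GaussianRandomWalk appears to be all it takes to put a Split op
    # into the logp graph (its logp separates the initial point of the
    # walk from the subsequent innovations).
    grw = pm.GaussianRandomWalk(
        "grw",
        sigma=sigma,
        init_dist=pm.Normal.dist(0.0, 1.0),
        steps=99,  # a walk of length 100 (initial point + 99 steps)
    )
    pm.Normal("obs", mu=grw, sigma=0.1, observed=np.random.normal(size=100))

    trace = pmsj.sample_numpyro_nuts(100, tune=2000, target_accept=0.9)
    # -> NotImplementedError: No JAX conversion for the given `Op`: Split{2}
```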

ricardoV94 self-assigned this Feb 6, 2023
ricardoV94 changed the title from "NotImplementedError: No JAX conversion for the given Op: Split{2}" to "Missing JAX Implementation for Split Op" Feb 6, 2023