Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issues with dimensions of unobserved components in versions 0.1.0 and 0.1.1 #338

Open
rklees opened this issue Apr 26, 2024 · 16 comments
Open

Comments

@rklees
Copy link

rklees commented Apr 26, 2024

I ran my standard test example using version 0.1.0 and got an error message from pm.sample (see below). Any advice on how to proceed?

TypeError Traceback (most recent call last)
Cell In[19], line 4
2 sampler = 'numpyro'
3 with pymc_model:
----> 4 idata = pm.sample(nuts_sampler=sampler, tune=500, draws=1000, chains=4, progressbar=True, target_accept=0.95)
6 # idate is an "inference data object", provided by the sampler. Sampling statistics are provided in idata.sample_stats. For more information
7 # about the sampling statistics, see
8 # https://www.pymc.io/projects/docs/en/v3/pymc-examples/examples/diagnostics_and_criticism/sampler-stats.html

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/mcmc.py:691, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
687 if not isinstance(step, NUTS):
688 raise ValueError(
689 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
690 )
--> 691 return _sample_external_nuts(
692 sampler=nuts_sampler,
693 draws=draws,
694 tune=tune,
695 chains=chains,
696 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
697 random_seed=random_seed,
698 initvals=initvals,
699 model=model,
700 var_names=var_names,
701 progressbar=progressbar,
702 idata_kwargs=idata_kwargs,
703 nuts_sampler_kwargs=nuts_sampler_kwargs,
704 **kwargs,
705 )
707 if isinstance(step, list):
708 step = CompoundStep(step)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/mcmc.py:351, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
348 elif sampler in ("numpyro", "blackjax"):
349 import pymc.sampling.jax as pymc_jax
--> 351 idata = pymc_jax.sample_jax_nuts(
352 draws=draws,
353 tune=tune,
354 chains=chains,
355 target_accept=target_accept,
356 random_seed=random_seed,
357 initvals=initvals,
358 model=model,
359 var_names=var_names,
360 progressbar=progressbar,
361 nuts_sampler=sampler,
362 idata_kwargs=idata_kwargs,
363 **nuts_sampler_kwargs,
364 )
365 return idata
367 else:

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
564 raise ValueError(f"{nuts_sampler=} not recognized")
566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
568 model=model,
569 target_accept=target_accept,
570 tune=tune,
571 draws=draws,
572 chains=chains,
573 chain_method=chain_method,
574 progressbar=progressbar,
575 random_seed=random_seed,
576 initial_points=initial_points,
577 nuts_kwargs=nuts_kwargs,
578 )
579 tic2 = datetime.now()
581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:458, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
454 import numpyro
456 from numpyro.infer import MCMC, NUTS
--> 458 logp_fn = get_jaxified_logp(model, negative_logp=False)
460 nuts_kwargs.setdefault("adapt_step_size", True)
461 nuts_kwargs.setdefault("adapt_mass_matrix", True)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:153, in get_jaxified_logp(model, negative_logp)
151 if not negative_logp:
152 model_logp = -model_logp
--> 153 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
155 def logp_fn_wrap(x):
156 return logp_fn(*x)[0]

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:128, in get_jaxified_graph(inputs, outputs)
122 def get_jaxified_graph(
123 inputs: list[TensorVariable] | None = None,
124 outputs: list[TensorVariable] | None = None,
125 ) -> list[TensorVariable]:
126 """Compile an PyTensor graph into an optimized JAX function"""
--> 128 graph = _replace_shared_variables(outputs) if outputs is not None else None
130 fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
131 # We need to add a Supervisor to the fgraph to be able to run the
132 # JAX sequential optimizer without warnings. We made sure there
133 # are no mutable input variables, so we only need to check for
134 # "destroyers". This should be automatically handled by PyTensor
135 # once aesara-devs/aesara#637 is fixed.

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pymc/sampling/jax.py:118, in _replace_shared_variables(graph)
111 raise ValueError(
112 "Graph contains shared variables with default_update which cannot "
113 "be safely replaced."
114 )
116 replacements = {var: pt.constant(var.get_value(borrow=True)) for var in shared_variables}
--> 118 new_graph = clone_replace(graph, replace=replacements)
119 return new_graph

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/replace.py:85, in clone_replace(output, replace, **rebuild_kwds)
82 _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
84 # TODO Explain why we call it twice ?!
---> 85 _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
87 return outs

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:313, in rebuild_collect_shared(outputs, inputs, replace, updates, rebuild_strict, copy_inputs_over, no_default_updates, clone_inner_graphs)
311 for v in outputs:
312 if isinstance(v, Variable):
--> 313 cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
314 cloned_outputs.append(cloned_v)
315 elif isinstance(v, Out):

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)

[... skipping similar frames: rebuild_collect_shared.<locals>.clone_v_get_shared_updates at line 189 (2 times)]

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:189, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
187 if owner not in clone_d:
188 for i in owner.inputs:
--> 189 clone_v_get_shared_updates(i, copy_inputs_over)
190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:190, in rebuild_collect_shared.<locals>.clone_v_get_shared_updates(v, copy_inputs_over)
188 for i in owner.inputs:
189 clone_v_get_shared_updates(i, copy_inputs_over)
--> 190 clone_node_and_cache(
191 owner,
192 clone_d,
193 strict=rebuild_strict,
194 clone_inner_graphs=clone_inner_graphs,
195 )
196 return clone_d.setdefault(v, v)
197 elif isinstance(v, SharedVariable):

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/basic.py:1201, in clone_node_and_cache(node, clone_d, clone_inner_graphs, **kwargs)
1197 new_op: "Op" | None = cast(Optional["Op"], clone_d.get(node.op))
1199 cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
-> 1201 new_node = node.clone_with_new_inputs(
1202 cloned_inputs,
1203 # Only clone inner-graph Ops when there isn't a cached clone (and
1204 # when clone_inner_graphs is enabled)
1205 clone_inner_graph=clone_inner_graphs if new_op is None else False,
1206 **kwargs,
1207 )
1209 if new_op:
1210 # If we didn't clone the inner-graph Op above, because
1211 # there was a cached version, set the cloned Apply to use
1212 # the cached clone Op
1213 new_node.op = new_op

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/basic.py:285, in Apply.clone_with_new_inputs(self, inputs, strict, clone_inner_graph)
282 if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
283 new_op = new_op.clone() # type: ignore
--> 285 new_node = new_op.make_node(*new_inputs)
286 new_node.tag = copy(self.tag).update(new_node.tag)
287 else:

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/scan/op.py:964, in Scan.make_node(self, *inputs)
960 argoffset = 0
961 for inner_seq, outer_seq in zip(
962 self.inner_seqs(self.inner_inputs), self.outer_seqs(inputs)
963 ):
--> 964 check_broadcast(outer_seq, inner_seq)
965 new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
967 argoffset += len(self.outer_seqs(inputs))

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/scan/op.py:179, in check_broadcast(v1, v2)
177 a1 = n + size - v1.type.ndim + 1
178 a2 = n + size - v2.type.ndim + 1
--> 179 raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))

TypeError: The broadcast pattern of the output of scan (Matrix(float64, shape=(144, 1))) is inconsistent with the one provided in output_info (Vector(float64, shape=(?,))). The output on axis 0 is True, but it is False on axis 1 in output_info. This can happen if one of the dimension is fixed to 1 in the input, while it is still variable in the output, or vice-verca. You have to make them consistent, e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.

@ricardoV94
Copy link
Member

You're probably specifying coords of length 1, which in the new version are mutable by default. You can use specify_broadcastable to tell PyMC that this dimension will always have length 1, or not specify the coord at all and instead use expand_dims, so PyMC also knows it will always have length 1.

@rklees
Copy link
Author

rklees commented Apr 26, 2024

I don't specify any coords. I set up the model as follows:
# Build a structural state-space model (IRW trend + two cycles + measurement
# error), declare priors for its parameters, attach the state-space likelihood
# to the observed series, and sample with NUTS.
IRW = st.LevelTrendComponent(order=2, innovations_order=[0, 1])
annual_cycle = st.CycleComponent(name='annual_cycle', cycle_length=12, innovations=True)
SA_cycle = st.CycleComponent(name='SA_cycle', cycle_length=5.347, innovations=True)
obs_noise = st.MeasurementError(name="obs")
pymc_mod = (IRW + annual_cycle + SA_cycle + obs_noise).build(name="IRW + 2 cycles + measurement error")

# NOTE(review): this unpacking assumes param_dims holds exactly these five
# entries, in this order. Confirm against pymc_mod.param_dims before relying on it.
initial_trend_dims, sigma_trend_dims, annual_cycle_dims, SA_cycle_dims, P0_dims = pymc_mod.param_dims.values()

# define the priors, build the graph and sample
# NOTE(review): `coords`, `data` and `sampler` are defined elsewhere in the notebook.
with pm.Model(coords=coords) as pymc_model:
    # priors for initial state vector comprising initial_trend (i.e., level and slope), initial annual cycle and initial SA cycle
    initial_trend = pm.Normal("initial_trend", sigma=100, dims=initial_trend_dims)
    initial_annual_cycle = pm.Normal("annual_cycle", shape=(2,), sigma=100, dims=annual_cycle_dims)
    initial_SA_cycle = pm.Normal("SA_cycle", shape=(2,), sigma=100, dims=SA_cycle_dims)
    # priors for square root of initial state covariance
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5)
    P0 = pm.Deterministic("P0", pt.eye(pymc_mod.k_states) * P0_diag, dims=P0_dims)
    # Priors for model parameters
    sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=10, dims=sigma_trend_dims)
    sigma_annual_cycle = pm.Gamma("sigma_annual_cycle", alpha=2, beta=5)
    sigma_SA_cycle = pm.Gamma("sigma_SA_cycle", alpha=2, beta=5)
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=10.)
    # NOTE(review): SA_cycle above uses a fixed cycle_length=5.347, and nothing
    # in this snippet references this variable, so this prior may be unused. Confirm.
    SA_cycle_length = pm.HalfNormal("SA_cycle_length", sigma=5)

    # Attach the Kalman-filter likelihood for the observed series to this model.
    pymc_mod.build_statespace_graph(pd.DataFrame(data['meas']), mode="JAX")
    idata = pm.sample(nuts_sampler=sampler, tune=500, draws=1000, chains=4, progressbar=True, target_accept=0.95)

Any specific help would be very much appreciated. Thanks.

@jessegrabowski
Copy link
Member

Looks like another thing broken by pymc-devs/pymc#7047. I will look at it ASAP.

Thanks for reporting all these bugs by the way, it's really important we get them fixed. I deeply appreciate your effort.

@rklees
Copy link
Author

rklees commented May 9, 2024

Any update regarding this issue?

@rklees
Copy link
Author

rklees commented Jun 27, 2024

This problem still persists in version 0.1.1. Is there any hope that someone will fix it? Otherwise, I can't use pymc-experimental for unobserved components. That would be a pity.

@rklees
Copy link
Author

rklees commented Jun 27, 2024

Here are the details:
NotImplementedError Traceback (most recent call last)
File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/link/basic.py:102, in Container.__set__(self, value)
100 try:
101 # Use in-place filtering when/if possible
--> 102 self.storage[0] = self.type.filter_inplace(
103 value, self.storage[0], **kwargs
104 )
105 except NotImplementedError:

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/graph/type.py:128, in Type.filter_inplace(self, value, storage, strict, allow_downcast)
109 """Return data or an appropriately wrapped/converted data by converting it in-place.
110
111 This method allows one to reuse old allocated memory. If this method
(...)
126 NotImplementedError
127 """
--> 128 raise NotImplementedError()

NotImplementedError:

During handling of the above exception, another exception occurred:

TypeError Traceback (most recent call last)
Cell In[5], line 26
23 nobs = 144 # we simulate a time series of length Nobs
25 # simulate data (by simulating each SSM component and sum up)
---> 26 x, y = simulate_from_numpy_model(mod, rng, param_dict, steps=nobs)
27 print(type(y))
29 # we standardize the data
30 # y = (y - np.mean(y))/np.std(y)

Cell In[3], line 17, in simulate_from_numpy_model(mod, rng, param_dict, steps)
13 def simulate_from_numpy_model(mod, rng, param_dict, steps=100):
14 """
15 Helper function to visualize the components outside of a PyMC model context
16 """
---> 17 x0, P0, c, d, T, Z, R, H, Q = unpack_symbolic_matrices_with_params(mod, param_dict)
18 Z_time_varies = Z.ndim == 3
20 k_states = mod.k_states

Cell In[3], line 9, in unpack_symbolic_matrices_with_params(mod, param_dict)
5 def unpack_symbolic_matrices_with_params(mod, param_dict):
6 f_matrices = pytensor.function(
7 list(mod._name_to_variable.values()), unpack_statespace(mod.ssm), on_unused_input="ignore"
8 )
----> 9 x0, P0, c, d, T, Z, R, H, Q = f_matrices(**param_dict)
10 return x0, P0, c, d, T, Z, R, H, Q

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/types.py:897, in Function.__call__(self, *args, **kwargs)
895 if kwargs: # for speed, skip the items for empty kwargs
896 for k, arg in kwargs.items():
--> 897 self[k] = arg
899 if (
900 not self.trust_input
901 and
(...)
904 ):
905 # Collect aliased inputs among the storage space
906 args_share_memory = []

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/types.py:549, in Function.__setitem__(self, item, value)
548 def __setitem__(self, item, value):
--> 549 self.value[item] = value

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/compile/function/types.py:505, in Function.__init__.<locals>.ValueAttribute.__setitem__(self, item, value)
499 raise TypeError(
500 f"Ambiguous name: {item} - please check the "
501 "names of the inputs of your function "
502 "for duplicates."
503 )
504 if isinstance(s, Container):
--> 505 s.value = value
506 s.provided += 1
507 else:

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/link/basic.py:106, in Container.__set__(self, value)
102 self.storage[0] = self.type.filter_inplace(
103 value, self.storage[0], **kwargs
104 )
105 except NotImplementedError:
--> 106 self.storage[0] = self.type.filter(value, **kwargs)
108 except Exception as e:
109 e.args = (*e.args, f'Container name "{self.name}"')

File ~/miniforge3/envs/BSSM/lib/python3.11/site-packages/pytensor/tensor/type.py:242, in TensorType.filter(self, data, strict, allow_downcast)
239 raise TypeError(err_msg)
241 if self.ndim != data.ndim:
--> 242 raise TypeError(
243 f"Wrong number of dimensions: expected {self.ndim},"
244 f" got {data.ndim} with shape {data.shape}."
245 )
246 if not data.flags.aligned:
247 raise TypeError(
248 "The numpy.ndarray object is not aligned."
249 " PyTensor C code does not support that.",
250 )

TypeError: ('Wrong number of dimensions: expected 0, got 1 with shape (1,).', 'Container name "sigma_annual_cycle"')

@rklees rklees changed the title Error message in version 0.1.0 from pm.sample Issues with dimensions of unobserved components in versions 0.1.0 and 0.1.1 Jun 27, 2024
@jessegrabowski
Copy link
Member

#346 should address this

@rklees
Copy link
Author

rklees commented Jun 27, 2024 via email

@rklees
Copy link
Author

rklees commented Jun 27, 2024 via email

@rklees
Copy link
Author

rklees commented Jun 27, 2024 via email

@rklees
Copy link
Author

rklees commented Jun 27, 2024 via email

@jessegrabowski
Copy link
Member

This should be fixed in main now that #346 is merged

@rklees
Copy link
Author

rklees commented Jun 29, 2024 via email

@zaxtax
Copy link
Contributor

zaxtax commented Jun 29, 2024 via email

@rklees
Copy link
Author

rklees commented Jun 29, 2024 via email

@zaxtax
Copy link
Contributor

zaxtax commented Jun 29, 2024 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants