Status: Closed
Labels: NumPy compatibility, Op implementation, bug (Something isn't working), graph rewriting
Description
Describe the issue:
Calling `pymc.draw()` on a model that contains a truncated distribution emits `UserWarning`s about shape-inference failures in PyTensor when the PyMC model provides no shape information. The draws are still returned. Passing explicit shape parameters to the PyMC distributions makes the warnings go away.
Reproducible code example:
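```python
import pymc as pm

with pm.Model() as m:
    # No shape/size arguments are passed to any distribution.
    θ = pm.Bernoulli("θ", p=0.5)
    # Binomial(n=7, p=0.5) truncated from below at 1.
    days = pm.Truncated("days", pm.Binomial.dist(n=7, p=0.5), lower=1)
    observed_days = θ * days
    # Forward sampling triggers the infer_shape warnings shown below.
    draws = pm.draw([θ, observed_days], draws=100)
```

Error message: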
```
/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py:157: UserWarning: Failed to infer_shape from Op AdvancedSubtensor.
Input shapes: [(), ()]
Exception encountered during infer_shape: <class 'ValueError'>
Exception message: Nonzero only supports non-scalar arrays.
Traceback: Traceback (most recent call last):
  File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/rewriting/shape.py", line 133, in get_node_infer_shape
    o_shapes = shape_infer(
               ^^^^^^^^^^^^
  File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/subtensor.py", line 2628, in infer_shape
    (shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
                                                       ^^^^^^^^^^^^
  File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 935, in nonzero
    res = _nonzero(a)
          ^^^^^^^^^^^
  File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/op.py", line 304, in __call__
    node = self.make_node(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aurimas/micromamba/envs/pymc5/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 884, in make_node
    raise ValueError("Nonzero only supports non-scalar arrays.")
ValueError: Nonzero only supports non-scalar arrays.
```

PyTensor version information:
pytensor==2.16.1
pymc==5.8.1
Context for the issue:
No response
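The warning's `Input shapes: [(), ()]`, together with the `nonzero(idx)` call in the traceback, suggests that the failing `AdvancedSubtensor` node indexes a 0-d array with a 0-d (most likely boolean) index. The sketch below assumes NumPy is installed and is not taken from PyTensor or PyMC; it only shows the NumPy semantics such a node has to reproduce: a 0-d boolean index adds a leading axis of length 1 when the mask is `True` and 0 when it is `False`. The helper `scalar_bool_index_shape` is hypothetical and illustrates that this output shape can be computed without calling `nonzero` on a 0-d mask.

```python
import numpy as np

# A 0-d data array and a 0-d boolean mask: the same input shapes,
# [(), ()], that the warning reports for AdvancedSubtensor.
x = np.array(5)

# NumPy treats a 0-d boolean index as adding a new leading axis whose
# length is 1 if the mask is True and 0 if it is False.
kept = x[np.array(True)]
dropped = x[np.array(False)]

print(kept.shape, kept)        # (1,) [5]
print(dropped.shape, dropped)  # (0,) []


def scalar_bool_index_shape(data_shape, mask):
    """Hypothetical helper: shape of data[mask] for a 0-d boolean mask.

    The result prepends an axis of length int(mask) to the data shape,
    so no call to nonzero on the 0-d mask is needed.
    """
    return (int(bool(mask)),) + tuple(data_shape)


# The helper agrees with NumPy for 0-d and 1-d data.
vec = np.arange(3)
assert scalar_bool_index_shape(x.shape, True) == kept.shape
assert scalar_bool_index_shape(x.shape, False) == dropped.shape
assert scalar_bool_index_shape(vec.shape, True) == vec[np.array(True)].shape
```

Because the length of the new axis depends on the mask's runtime value, a symbolic shape function could express it as a cast of the mask itself rather than going through `nonzero`. This is an observation about NumPy's indexing rules, not a description of how PyTensor resolved the issue.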