Describe the issue:
I just ran into a logprob rewrite error with an `AdvancedSubtensor` Op that mixed `None` entries and `int32` indices. This wasn't actually a mixture model, but logprob found the Op, tried to apply its rewrite rules, and raised an error instead of just failing silently. The problem seems to come from this line (the `indices.dtype.startswith("int")` check in `find_measurable_index_mixture`, `pymc/logprob/mixture.py`), which guards against a slice constant but not against a `None` constant.
Reproducible code example:
```python
import numpy as np
import pymc as pm

obs = np.random.default_rng().normal(size=(7, 4))
with pm.Model():
    inds = np.arange(obs.shape[1])
    a = pm.Normal("a", shape=10)
    b = pm.Deterministic("b", a[None, inds])
    c = pm.Normal("c", mu=b, sigma=1, observed=obs)
    pm.sample()
```
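In plain NumPy terms, `a[None, inds]` mixes basic indexing (a new axis) with advanced integer-array indexing. The sketch below assumes ordinary NumPy arrays stand in for the PyMC variables, just to show the resulting shape (1, 4), which broadcasts against the (7, 4) observations. Per the error log, PyTensor represents the same expression as `AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])`.

```python
import numpy as np

a = np.arange(10.0)   # stand-in for the length-10 random variable "a"
inds = np.arange(4)   # integer array index, as in the reproducer
b = a[None, inds]     # a new axis (None) mixed with an integer array index

print(b.shape)                                # (1, 4)
print(np.broadcast_shapes(b.shape, (7, 4)))   # (7, 4)
```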
Error message:
```
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: find_measurable_index_mixture
ERROR (pytensor.graph.rewriting.basic): node: AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "pytensor/graph/rewriting/basic.py", line 1913, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "pytensor/graph/rewriting/basic.py", line 1085, in transform
    return self.fn(fgraph, node)
  File "pymc/logprob/mixture.py", line 291, in find_measurable_index_mixture
    if any(
  File "pymc/logprob/mixture.py", line 292, in <genexpr>
    indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
AttributeError: 'Constant' object has no attribute 'dtype'. Did you mean: 'type'?
```
Sampling still works fine, though, because the rewrite was supposed to fail and return `None` anyway.
PyMC version information:
GitHub main
Context for the issue:
This doesn't really affect anything. It just confuses regular users who see the error traceback from rewriting and get alarmed. It would be more elegant to handle this extra indexer type the same way slice constants are handled.
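A minimal, library-free sketch of the guard, assuming hypothetical stand-in classes (`TensorIndex`, `SliceIndex`, `NoneIndex`) in place of PyTensor variables and constants, and flattening `indices.type.broadcastable` into a plain `broadcastable` field. It contrasts a check that only skips slice entries with one that also skips `None` entries. On the index list from the reproducer, the first raises `AttributeError`, while the second just reports that a non-scalar integer index is present; in this sketch a `True` result stands for the rewrite bailing out with `None`.

```python
from dataclasses import dataclass
from typing import Any, Sequence, Tuple


@dataclass
class TensorIndex:
    """Hypothetical stand-in for a tensor index with a dtype and broadcastable pattern."""
    dtype: str
    broadcastable: Tuple[bool, ...]


@dataclass
class SliceIndex:
    """Hypothetical stand-in for a slice constant (it has no dtype)."""
    value: slice


@dataclass
class NoneIndex:
    """Hypothetical stand-in for a None / new-axis constant (it has no dtype)."""


def has_non_scalar_int_index_slice_only(indices: Sequence[Any]) -> bool:
    # Only slices are skipped, so a NoneIndex reaches `.dtype` and fails.
    return any(
        idx.dtype.startswith("int") and sum(1 - b for b in idx.broadcastable) > 0
        for idx in indices
        if not isinstance(idx, SliceIndex)
    )


def has_non_scalar_int_index(indices: Sequence[Any]) -> bool:
    # Slices and None entries are both skipped before any dtype check.
    return any(
        idx.dtype.startswith("int") and sum(1 - b for b in idx.broadcastable) > 0
        for idx in indices
        if not isinstance(idx, (SliceIndex, NoneIndex))
    )


example = [NoneIndex(), TensorIndex(dtype="int64", broadcastable=(False,))]
print(has_non_scalar_int_index(example))  # True

try:
    has_non_scalar_int_index_slice_only(example)
except AttributeError as err:
    print(type(err).__name__)  # AttributeError
```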
Title changed to "BUG: `AdvancedSubTensor` with `None` and integer indices raises a `logprob` error instead of silently failing"

Hashcode-Ankit commented on Apr 26, 2025
Hi @lucianopaz, just to understand the expected outcome: should it skip the check if the index is a constant or a `None` value?
In your case it is a constant, and the condition tries to read its `dtype`, which isn't possible here.
Thanks
lucianopaz commented on Apr 26, 2025
Hi @Hashcode-Ankit. The line that I quoted above also needs to check whether the indices are `None` or `NoneConst`. That way, the rewrite will return `None` when it has a mixture of integer indices and slices or new axes.

If you look through the code base, you'll see that rewrites have a bunch of conditions that check whether the rewrite can be applied to the inputs. When the conditions fail, the rewrite returns `None`. When it succeeds, it returns the modified graph or node. By adding the extra check to that condition, we are explicitly telling pytensor that the rewrite can't work if the indexing operation mixes integers with other basic indexing components.

Hashcode-Ankit commented on Apr 27, 2025
Hi, I have opened a PR for this, could you take a look?
Thanks
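The rewrite contract lucianopaz describes (check conditions, return `None` when any fails, otherwise return a replacement) can be illustrated with a toy, library-free sketch. The names `apply_rewrite` and `toy_index_rewrite` are hypothetical and do not reflect PyTensor's actual rewriter API; the "node" here is just a tuple of indices.

```python
from typing import Any, Callable, List, Optional, Tuple

Rewrite = Callable[[Tuple[Any, ...]], Optional[List[str]]]


def apply_rewrite(rewrite: Rewrite, node: Tuple[Any, ...]) -> Any:
    """Hypothetical driver: keep the node unless the rewrite returns replacements."""
    replacements = rewrite(node)
    if replacements is None:
        # None means "this rewrite does not apply here", not an error.
        return node
    return replacements[0]


def toy_index_rewrite(indices: Tuple[Any, ...]) -> Optional[List[str]]:
    """Toy rewrite that refuses any index tuple containing new axes or slices."""
    if not indices:
        return None  # condition 1 failed: nothing to rewrite
    if any(idx is None or isinstance(idx, slice) for idx in indices):
        return None  # condition 2 failed: mixes in new axes or slices
    return [f"rewritten{indices}"]  # all conditions passed: return a replacement


print(apply_rewrite(toy_index_rewrite, (0, 1)))     # rewritten(0, 1)
print(apply_rewrite(toy_index_rewrite, (None, 0)))  # (None, 0)
```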