BUG: AdvancedSubTensor with None and integer indices raises a logprob error instead of silently failing #7762

@lucianopaz

Description

@lucianopaz
Member

Describe the issue:

I ran into a logprob rewrite error with an AdvancedSubtensor op that mixes None entries with int32 indices. The model was not a mixture model, but logprob found the op, tried to apply its mixture rewrite, and raised an error instead of failing silently. The problem seems to come from this line, which guards against a slice constant but not against a None constant.

Reproducible code example:

import numpy as np
import pymc as pm


obs = np.random.default_rng().normal(size=(7, 4))
with pm.Model():
    inds = np.arange(obs.shape[1])
    a = pm.Normal("a", shape=10)
    b = pm.Deterministic("b", a[None, inds])
    c = pm.Normal("c", mu=b, sigma=1, observed=obs)
    pm.sample()

Error message:

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: find_measurable_index_mixture
ERROR (pytensor.graph.rewriting.basic): node: AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "pytensor/graph/rewriting/basic.py", line 1913, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "pytensor/graph/rewriting/basic.py", line 1085, in transform
    return self.fn(fgraph, node)
  File "pymc/logprob/mixture.py", line 291, in find_measurable_index_mixture
    if any(
  File "pymc/logprob/mixture.py", line 292, in <genexpr>
    indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
AttributeError: 'Constant' object has no attribute 'dtype'. Did you mean: 'type'?

Sampling still works fine, because the rewrite was supposed to fail and return None anyway.
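To make the failure mode concrete, here is a minimal sketch that mimics the check from the traceback using plain Python stand-ins. The classes `FakeType`, `FakeTensorIndex`, and `FakeNoneConst` and the function `has_non_scalar_int_index` are hypothetical names invented for illustration; they are not PyTensor or PyMC objects. The only assumption is that the None constant exposes a `type` but no `dtype`, as the AttributeError suggests.

```python
# Stand-ins for illustration only; these are NOT PyTensor classes.
class FakeType:
    def __init__(self, broadcastable):
        self.broadcastable = broadcastable


class FakeTensorIndex:
    """Mimics an integer array index: has both .dtype and .type."""

    def __init__(self, dtype, broadcastable):
        self.dtype = dtype
        self.type = FakeType(broadcastable)


class FakeNoneConst:
    """Mimics the None (new axis) constant: has .type but no .dtype."""

    def __init__(self):
        self.type = FakeType(())


def has_non_scalar_int_index(indices_list):
    # Same expression as the line shown in the traceback, with no guard
    # against index objects that lack a .dtype attribute.
    return any(
        idx.dtype.startswith("int") and sum(1 - b for b in idx.type.broadcastable) > 0
        for idx in indices_list
    )


# Analogous to a[None, inds]: a new-axis entry followed by an integer vector.
mixed = [FakeNoneConst(), FakeTensorIndex("int64", (False,))]

try:
    has_non_scalar_int_index(mixed)
    error = None
except AttributeError as exc:
    # The first element has no .dtype, so the generator raises before
    # the integer index is ever inspected.
    error = exc
```

With only tensor-like indices the check runs cleanly; the exception appears as soon as a dtype-less entry is evaluated.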

PyMC version information:

Github main

Context for the issue:

This doesn't really affect anything. It just confuses regular users, who see the error traceback from the rewrite and get alarmed. It would be more elegant to handle this extra indexer type the same way slice constants are handled.

Activity

changed the title [-]BUG: <Please write a comprehensive title after the 'BUG: ' prefix>[/-] [+]BUG: `AdvancedSubTensor` with `None` and integer indices raises a `logprob` error instead of silently failing[/+] on Apr 23, 2025
Hashcode-Ankit commented on Apr 26, 2025

Hi @lucianopaz, just to understand the expected outcome: should the check be skipped when the index is a constant or a None value?

In your case the index is a constant, and the condition tries to read its dtype, which is not possible here.

Thanks

lucianopaz commented on Apr 26, 2025

@lucianopaz (Member, Author)

Hi @Hashcode-Ankit. The line that I quoted above also needs to check whether the indices are None or NoneConst. That way, the rewrite will return None when the indexing mixes integer indices with slices or new axes.
If you look through the code base, you'll see that rewrites have a bunch of conditions that check whether the rewrite can be applied to the inputs. When a condition fails, the rewrite returns None. When the rewrite succeeds, it returns the modified graph or node. By adding the extra check to that condition, we explicitly tell pytensor that the rewrite can't work if the indexing operation mixes integers with other basic indexing elements.
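The "check preconditions, otherwise return None" contract described above could look roughly like the sketch below. It is written against plain Python stand-ins, not the real PyTensor classes: `FakeSliceConst`, `FakeNoneConst`, `FakeTensorIndex`, and `sketch_index_rewrite` are all hypothetical names, and the returned tuple is only a placeholder for whatever replacement a real rewrite would build. It is one possible shape for the guard, not the actual fix.

```python
class FakeSliceConst:
    """Stand-in for a slice constant index (hypothetical, not a PyTensor class)."""


class FakeNoneConst:
    """Stand-in for the None / new-axis constant (hypothetical)."""


class FakeTensorIndex:
    """Stand-in for a tensor index that carries a dtype and broadcast pattern."""

    def __init__(self, dtype, broadcastable):
        self.dtype = dtype
        self.broadcastable = broadcastable


def sketch_index_rewrite(indices):
    """Return None to decline the rewrite, or a placeholder replacement."""
    # Guard: basic-indexing entries (slices, new axes) have no dtype,
    # so they are filtered out before any dtype-based check runs.
    tensor_indices = [
        idx
        for idx in indices
        if not isinstance(idx, (FakeSliceConst, FakeNoneConst))
    ]

    # Precondition: decline when any remaining index is a non-scalar
    # integer array (at least one non-broadcastable dimension).
    if any(
        idx.dtype.startswith("int") and not all(idx.broadcastable)
        for idx in tensor_indices
    ):
        return None

    # Placeholder for the modified node a real rewrite would return.
    return ("rewritten", len(tensor_indices))


# Analogous to a[None, inds]: declined quietly instead of raising.
declined = sketch_index_rewrite(
    [FakeNoneConst(), FakeTensorIndex("int64", (False,))]
)

# A boolean index next to a slice passes the precondition in this sketch.
applied = sketch_index_rewrite(
    [FakeSliceConst(), FakeTensorIndex("bool", (False,))]
)
```

Filtering the dtype-less entries first keeps the dtype check itself unchanged, which mirrors the idea of treating None the same way slice constants are already treated.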

Hashcode-Ankit commented on Apr 27, 2025

Hi, I have opened a PR for this. Could you take a look?

Thanks
