
AdvancedSubtensor doesn't manage to infer the static shape #1565

@lucianopaz

Description

I was a bit surprised that AdvancedSubtensor was not able to determine the static shape of this operation:

```python
import pytensor.tensor as pt

a = pt.dtensor3(shape=(2, 3, 4))
ind1 = pt.ivector(shape=(10,))
ind2 = pt.ivector(shape=(10,))
ind3 = pt.ivector(shape=(10,))
b = a[ind1, ind2, ind3]
```
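For reference, NumPy's rule for this case only needs the indices' shapes: when consecutive leading axes are indexed by integer arrays, those axes collapse into the broadcast of the index shapes, and the remaining axes are carried over. A minimal pure-Python sketch of that rule, assuming only integer-array indices on the leading axes (no slices, `None`, or boolean masks) and using `None` for a length that is unknown at graph-construction time. `broadcast_static` and `static_advanced_index_shape` are hypothetical helpers, not PyTensor functions:

```python
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None means "length unknown at graph-construction time"


def broadcast_static(*shapes: Sequence[Dim]) -> Tuple[Dim, ...]:
    """Broadcast static shapes NumPy-style, keeping None for unknown lengths."""
    ndim = max((len(s) for s in shapes), default=0)
    # Left-pad every shape with 1s so all shapes have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        known = {d for d in dims if d is not None and d != 1}
        if len(known) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        if known:
            # A known non-1 length wins: any unknown must be 1 or equal to it.
            out.append(known.pop())
        elif any(d is None for d in dims):
            # Only 1s and unknowns: the result length cannot be determined.
            out.append(None)
        else:
            out.append(1)
    return tuple(out)


def static_advanced_index_shape(
    x_shape: Sequence[Dim], index_shapes: Sequence[Sequence[Dim]]
) -> Tuple[Dim, ...]:
    """Hypothetical helper: static shape of x[i0, ..., ik-1] where the k
    indices are integer arrays applied to the leading axes of x."""
    k = len(index_shapes)
    if k > len(x_shape):
        raise IndexError("too many indices")
    # Indexed axes collapse into the broadcast of the index shapes;
    # the un-indexed trailing axes of x are kept unchanged.
    return broadcast_static(*index_shapes) + tuple(x_shape[k:])


# The case from this issue: x has shape (2, 3, 4), three indices of shape (10,).
print(static_advanced_index_shape((2, 3, 4), [(10,), (10,), (10,)]))  # (10,)
```

Everything needed to get `(10,)` is already in the static types of the inputs, which is why one would expect it to show up in `b.type.shape`.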

PyTensor should be able to tell that `b.type.shape == (10,)`. Looking at the source code, I think the issue is that `indexed_result_shape` returns a symbolic graph that could be constant-folded but is left as is. I wrote a `make_node` function that mirrors `AdvancedSubtensor.make_node` but rewrites and constant-folds that graph first:

```python
import itertools

import pytensor.tensor as pt
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.subtensor import AdvancedSubtensor, indexed_result_shape


def make_node(x, *index):
    # Stand-in shape tuple: broadcastable dims are 1, static dims keep their
    # length, and unknown dims become fresh int64 scalars.
    fake_shape = tuple(
        1 if bcast else (
            pt.tensor(dtype="int64", shape=()) if s is None else s
        )
        for s, bcast in zip(x.type.shape, x.broadcastable)
    )

    # Expand boolean masks into their integer `nonzero` indices.
    fake_index = tuple(
        itertools.chain.from_iterable(
            pt.basic.nonzero(idx)
            if getattr(idx, "ndim", 0) > 0
            and getattr(idx, "dtype", None) == "bool"
            else (idx,)
            for idx in index
        )
    )
    indexed = indexed_result_shape(fake_shape, fake_index)

    # Unlike AdvancedSubtensor.make_node, rewrite and constant-fold the
    # symbolic result shape so that static lengths become Constants.
    fg = FunctionGraph(outputs=indexed, features=[ShapeFeature()], copy_inputs=False, clone=True)
    rewrite_graph(fg)
    topo_unconditional_constant_folding(fg)
    indexed = fg.outputs

    out_shape = tuple(
        int(i.value) if isinstance(i, Constant) else None
        for i in indexed
    )
    return Apply(
        AdvancedSubtensor(),  # Apply needs an Op instance, not the Op class
        (x, *index),
        [pt.tensor(dtype=x.type.dtype, shape=out_shape)],
    )
```

With this, the static shape `(10,)` is inferred:

```python
>>> make_node(a, ind1, ind2, ind3).outputs[0].type.shape
(10,)
```
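Conceptually, the workaround seems to succeed because once the index lengths are known, every node in the result-shape graph has constant inputs and can be evaluated ahead of time; the `isinstance(i, Constant)` check then turns folded outputs into static dimensions and leaves everything else as `None`. A toy, pure-Python illustration of that bottom-up folding, assuming a simplified expression tree. The `Const`, `Unknown`, `Node`, `fold`, and `static_dim` names are hypothetical stand-ins, not PyTensor's graph classes or its constant-folding rewriter:

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Unknown:
    name: str  # a symbolic input whose value is only known at runtime


@dataclass(frozen=True)
class Node:
    fn: Callable[..., int]
    args: Tuple["Expr", ...]


Expr = Union[Const, Unknown, Node]


def fold(e: Expr) -> Expr:
    """Bottom-up constant folding: evaluate any node whose inputs all fold to constants."""
    if isinstance(e, Node):
        args = tuple(fold(a) for a in e.args)
        if all(isinstance(a, Const) for a in args):
            return Const(e.fn(*(a.value for a in args)))
        return Node(e.fn, args)
    return e


def static_dim(e: Expr) -> Optional[int]:
    """Toy analogue of `int(i.value) if isinstance(i, Constant) else None`."""
    folded = fold(e)
    return folded.value if isinstance(folded, Const) else None


def bcast(a: int, b: int) -> int:
    # Toy broadcast of two lengths, assuming they are compatible.
    return b if a == 1 else a


# The static lengths of ind1, ind2, ind3 enter the shape graph as constants.
n = Node(bcast, (Node(bcast, (Const(10), Const(10))), Const(10)))
print(static_dim(n))  # 10

# With genuinely unknown lengths, nothing folds and the dimension stays None.
m = Node(bcast, (Unknown("len(ind1)"), Unknown("len(ind2)")))
print(static_dim(m))  # None
```

The cost the issue is concerned about is visible even in the toy: folding walks the whole shape graph, which in the real `make_node` means building a `FunctionGraph` and running rewrites for every node.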

I'm hesitant to use this as the actual `make_node`, because it runs graph rewrites every time a node is built, so I'm opening this issue to ask whether there's a better way to propagate the static shape information.
