Closed
Description
I was a bit surprised that `AdvancedSubtensor` was not able to determine the static shape of this operation:

    a = pt.dtensor3(shape=(2, 3, 4))
    ind1 = pt.ivector(shape=(10,))
    ind2 = pt.ivector(shape=(10,))
    ind3 = pt.ivector(shape=(10,))
    b = a[ind1, ind2, ind3]

PyTensor should be able to tell that `b.type.shape == (10,)`. Looking at the source code, I think the issue is that the output of `indexed_result_shape` is a computational graph that could be constant folded but is left as is. I can write a `make_node` function that acts like `AdvancedSubtensor.make_node`:

    import itertools
    from pytensor import tensor as pt
    from pytensor.graph.basic import Apply, Constant
    from pytensor.tensor.subtensor import indexed_result_shape
    from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
    import pytensor
    from pytensor.graph.fg import FunctionGraph
    from pytensor.tensor.rewriting.shape import ShapeFeature
    from pytensor.graph.rewriting.utils import rewrite_graph


    def make_node(x, *index):
        # Stand-in for the shape of x: 1 for broadcastable dims, a symbolic
        # int64 scalar for unknown dims, and the static length otherwise.
        fake_shape = tuple(
            1 if bcast else (
                pt.tensor(dtype="int64", shape=()) if s is None else s
            )
            for s, bcast in zip(x.type.shape, x.broadcastable)
        )
        # Expand boolean mask indices into their nonzero() integer indices.
        fake_index = tuple(
            itertools.chain.from_iterable(
                pt.basic.nonzero(idx)
                if getattr(idx, "ndim", 0) > 0
                and getattr(idx, "dtype", None) == "bool"
                else (idx,)
                for idx in index
            )
        )
        # Build the symbolic result shape, then rewrite and constant fold it.
        indexed = indexed_result_shape(fake_shape, fake_index)
        fg = FunctionGraph(outputs=indexed, features=[ShapeFeature()], copy_inputs=False, clone=True)
        rewrite_graph(fg)
        topo_unconditional_constant_folding(fg)
        indexed = fg.outputs
        # Entries that folded to constants become static dims; the rest stay unknown.
        out_shape = tuple(
            int(i.value) if isinstance(i, Constant) else None
            for i in indexed
        )
        return Apply(
            pt.subtensor.AdvancedSubtensor,
            (x, *index),
            # The original called a bare `tensor`, which is never imported.
            [pt.tensor(dtype=x.type.dtype, shape=out_shape)],
        )

and this infers the static shape `(10,)`:

    >>> make_node(a, ind1, ind2, ind3).outputs[0].type.shape
    (10,)

I'm a bit hesitant about whether this should be used as the `make_node` method because of the graph rewrites involved, so I decided to just open this issue to get input on whether there's a better way to get the static shape information to propagate.
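
For comparison, one way to avoid the rewrites might be to broadcast the static shapes of the index variables directly. The sketch below is plain Python, not PyTensor code: `broadcast_static_shapes` and `advanced_index_static_shape` are hypothetical helper names, and it assumes every index is an integer tensor applied to consecutive leading dimensions (no slices, `newaxis`, or boolean masks). Unknown dimensions are written as `None`, as in `type.shape`.

```python
from itertools import zip_longest


def broadcast_static_shapes(*shapes):
    """Broadcast static shapes whose entries are ints or None (unknown)."""
    result = []
    # Walk the dimensions right-aligned, padding shorter shapes with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        known = {d for d in dims if d is not None and d != 1}
        if len(known) > 1:
            raise ValueError(f"incompatible static dimensions: {sorted(known)}")
        if known:
            # A known non-1 length fixes the result, even next to unknowns.
            result.append(known.pop())
        elif any(d is None for d in dims):
            # Only 1s and unknowns: the result length cannot be determined.
            result.append(None)
        else:
            result.append(1)
    return tuple(reversed(result))


def advanced_index_static_shape(x_shape, index_shapes):
    """Static result shape of x[i0, i1, ...] under the stated assumptions."""
    if len(index_shapes) > len(x_shape):
        raise IndexError("too many indices for tensor")
    # Indexed leading dims collapse to the broadcast index shape;
    # the remaining trailing dims of x are kept unchanged.
    return broadcast_static_shapes(*index_shapes) + tuple(x_shape[len(index_shapes):])


# The example from this issue: three length-10 integer vectors into (2, 3, 4).
print(advanced_index_static_shape((2, 3, 4), [(10,), (10,), (10,)]))
```

This only covers the simplest indexing pattern; mixed slices or boolean masks would still need the more general logic in `indexed_result_shape`.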