Advanced Indexing does not trace correctly for tensor shape that has leading 1s #49852
Labels
module: advanced indexing
oncall: jit
🐛 Bug
A statement like
a[b] = c
where a, b, and c are all tensors and a[b].shape == c.shape, should trace correctly, because no reshape is needed. However, it still uses a reshape under the hood, and therefore does not trace correctly (the trace does not generalize to other shapes) when c.shape has leading 1s.
To Reproduce
Steps to reproduce the behavior:
prints:
Environment
PyTorch master
Additional context
An old commit, #9424, aimed to skip the reshape so that the op becomes traceable. However, it only skips the unnecessary reshape when c.shape does not have leading 1s, and that condition is what causes this issue. This logic was later mirrored into ATen in #32841.
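To make the failure mode concrete, here is a simplified Python model of the logic described in the previous paragraph. It is not PyTorch code, and every function name in it is hypothetical. It assumes that leading size-1 dims are stripped from c's shape, that a view to the stripped shape is emitted only when stripping changes the shape, that the view's target sizes are frozen as constants from the example input at trace time, and that a view can only be replayed on a tensor with the same number of elements.

```python
import math
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def strip_leading_ones(shape: Shape) -> Shape:
    """Drop leading size-1 dimensions, e.g. (1, 1, 4) -> (4,). Hypothetical helper."""
    i = 0
    while i < len(shape) and shape[i] == 1:
        i += 1
    return shape[i:]


def recorded_view_target(example_value_shape: Shape) -> Optional[Shape]:
    """Model of what a trace records for `c` before the index_put.

    A view is emitted only when stripping leading 1s changes the shape, and its
    target sizes come from the example input, so they are frozen as constants.
    Returns None when no view is recorded.
    """
    stripped = strip_leading_ones(example_value_shape)
    return stripped if stripped != example_value_shape else None


def replay_ok(view_target: Optional[Shape], new_value_shape: Shape) -> bool:
    """Can the recorded view be replayed on a `c` of a different shape?

    A view needs the same number of elements, so a frozen target only fits
    inputs whose element count matches the example's.
    """
    if view_target is None:
        return True  # no view recorded, so no frozen view sizes to violate
    return math.prod(view_target) == math.prod(new_value_shape)


# Example c.shape == (1, 4) has a leading 1, so a view to [4] is baked in.
target = recorded_view_target((1, 4))
print(target)                     # (4,)
print(replay_ok(target, (2, 4)))  # False: 8 elements cannot be viewed as [4]

# Example c.shape == (2, 4) has no leading 1s, so no view is recorded.
target = recorded_view_target((2, 4))
print(target)                     # None
print(replay_ok(target, (3, 4)))  # True
```

In this model the condition "a[b].shape == c.shape" never enters the decision; only the presence of leading 1s does, which is why a trace taken with a leading-1 example does not generalize.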
#45828 fixes a similar problem for basic indexing. Advanced indexing seems trickier.
cc @gmagogsfm