
Advanced Indexing does not trace correctly for tensor shape that has leading 1s #49852

Open
ppwwyyxx opened this issue Dec 25, 2020 · 0 comments
Labels: days, module: advanced indexing, oncall: jit


ppwwyyxx (Contributor) commented Dec 25, 2020

🐛 Bug

A statement like a[b] = c, where a, b, and c are all tensors and a[b].shape == c.shape, should trace correctly, because no reshape is needed. However, it still uses a reshape under the hood, and therefore does not trace correctly (the trace does not generalize to inputs of other shapes) when c.shape has leading 1s.
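To make the triggering condition concrete: when b is a 1-D integer index tensor applied to the first dimension of a, a[b] has shape (len(b),) + a.shape[1:], so any single-element index produces a value c with a leading 1. The sketch below models only this shape rule in plain Python; the helper names are hypothetical and this is not PyTorch's implementation.

```python
def advanced_index_shape(a_shape, b_len):
    """Shape of a[b] for a 1-D integer index tensor b of length b_len
    applied to the first dimension of a tensor with shape a_shape.
    (Hypothetical helper that models the indexing shape rule.)"""
    return (b_len,) + tuple(a_shape[1:])


def has_leading_ones(shape):
    """True if the shape starts with a dimension of size 1."""
    return len(shape) > 0 and shape[0] == 1


# The two indices from the repro, applied to x of shape (3, 4):
for b_len in (2, 1):
    c_shape = advanced_index_shape((3, 4), b_len)
    print(b_len, c_shape, has_leading_ones(c_shape))
# 2 (2, 4) False
# 1 (1, 4) True
```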

To Reproduce

Steps to reproduce the behavior:

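import torch

def tensor_setitem(x, idx, y):
    x[idx] = y + 1
    return x

x = torch.randn(3, 4)
# Use one of the two indices below; the second assignment overrides the first.
idx = torch.tensor([1, 2])  # works: x[idx] has shape (2, 4)
idx = torch.tensor([1])     # fails: x[idx] has shape (1, 4), i.e. a leading 1
traced = torch.jit.trace(tensor_setitem, (x, idx, x[idx]))

# Re-run on an input of a different shape to check that the trace generalizes.
x = torch.randn(10, 5)
assert torch.allclose(traced(x.clone(), idx, x[idx]),
                      tensor_setitem(x.clone(), idx, x[idx]))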

fails with:

RuntimeError: shape '[4]' is invalid for input of size 5
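Assuming the trace records the reshape of y + 1 with a constant target of [4] (the trace-time value's shape (1, 4) with its leading 1 removed), the error is plain element-count arithmetic: at replay the value has shape (1, 5), i.e. 5 elements, which cannot be reshaped to 4. The toy function below models only that check; `replay_view` is a hypothetical name and this is not PyTorch's code.

```python
from math import prod


def replay_view(value_shape, recorded_target):
    """Toy model of replaying a reshape whose target shape was frozen
    when the function was traced (hypothetical, not torch internals)."""
    numel = prod(value_shape)
    if numel != prod(recorded_target):
        raise RuntimeError(
            f"shape '{list(recorded_target)}' is invalid for input of size {numel}"
        )
    return tuple(recorded_target)


recorded = (4,)  # frozen from the trace-time value of shape (1, 4)
print(replay_view((1, 4), recorded))  # same shape as at trace time: (4,)
try:
    replay_view((1, 5), recorded)  # value derived from the new x of shape (10, 5)
except RuntimeError as err:
    print(err)  # shape '[4]' is invalid for input of size 5
```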

Environment

PyTorch master

Additional context

An older commit, #9424, aimed to skip the reshape so that the op becomes traceable. However, it only skips the unnecessary reshape when c.shape does not have leading 1s, which is what causes this issue. This logic was later mirrored into ATen in #32841.
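The skip rule described above can be sketched as follows, assuming the value is reshaped to its own shape with leading 1s stripped, and the reshape is omitted only when stripping changes nothing. Both helper names are hypothetical; the real logic lives in C++ (#9424, #32841).

```python
def strip_leading_ones(shape):
    """Drop leading size-1 dimensions, e.g. (1, 1, 4) -> (4,)."""
    i = 0
    while i < len(shape) and shape[i] == 1:
        i += 1
    return tuple(shape[i:])


def prepare_value(value_shape):
    """Return (shape handed to the index-put step, whether a reshape is emitted)."""
    target = strip_leading_ones(value_shape)
    if target == tuple(value_shape):
        # No leading 1s: the reshape is skipped, so no trace-time shape is baked in.
        return target, False
    # Leading 1s: a reshape to `target` is emitted, and a trace records
    # `target` as a constant, which is what fails to generalize.
    return target, True


print(prepare_value((2, 4)))  # ((2, 4), False)
print(prepare_value((1, 4)))  # ((4,), True)
```

Under this rule the two cases in the repro diverge: idx = [1, 2] takes the no-reshape branch, while idx = [1] takes the reshape branch, whose target (4,) is tied to the example input.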

#45828 fixes a similar problem for basic indexing. Advanced indexing seems trickier.

cc @gmagogsfm

@ngimel added the oncall: jit label Dec 26, 2020
github-actions bot added this to Need triage in JIT Triage Dec 26, 2020
@ngimel added the module: advanced indexing label Dec 26, 2020
facebook-github-bot pushed a commit to facebookresearch/detectron2 that referenced this issue Dec 26, 2020
Summary:
Previously, the trace used advanced indexing. Due to pytorch/pytorch#49852,
this does not generalize if a certain level has only one box.

Added a unit test for this case.

Reviewed By: alexander-kirillov

Differential Revision: D25706070

fbshipit-source-id: 5110fbba0794ca0d75000bbb95ed3edd7d12cc2a
@wanchaol added the days label Jan 8, 2021
@wanchaol moved this from Need triage to Pending in JIT Triage Jan 8, 2021
@ppwwyyxx changed the title from "Advanced Indexing is not traceable for tensor shape that has leading 1s" to "Advanced Indexing does not trace correctly for tensor shape that has leading 1s" May 22, 2022