vmap over getitem indexing raises a RuntimeError #124291

Open
ordabayevy opened this issue Apr 17, 2024 · 0 comments
Labels: module: functorch (Pertaining to torch.func or pytorch/functorch), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


ordabayevy commented Apr 17, 2024

🐛 Describe the bug

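Calling torch.vmap over indexing (x[y]) raises a RuntimeError when each per-sample index is a 0-dim tensor, while the same function works when each per-sample index is 1-D: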

```python
import torch


def index(x, y):
    return x[y]

# in_dims=(None, 0): x is shared across the batch, y is batched along dim 0
batched_index = torch.vmap(index, (None, 0))

x = torch.arange(3)

# this works as expected (each per-sample index y1[i] is 1-D)
# tensor([[0, 2],
#         [1, 2]])
y1 = torch.tensor([[0, 2], [1, 2]])
result1 = batched_index(x, y1)
print(result1)

# this raises RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item()
# on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch
# internals, please file a bug report.
# (here each per-sample index y2[i] is a 0-dim tensor)
y2 = torch.tensor([0, 2])
result2 = batched_index(x, y2)
print(result2)
```

Expected result2 to be tensor([0, 2]), the same as x[y2] without vmap.
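For reference, the sketch below spells out what `torch.vmap(index, (None, 0))` is expected to compute: apply `index` to each slice of `y` along dim 0, with `x` shared, and stack the per-sample results. The helper name `loop_reference` is hypothetical and not part of PyTorch; the loop uses no vmap at all, so it runs for both `y1` and `y2`.

```python
import torch


def index(x, y):
    return x[y]


def loop_reference(x, y):
    # Hypothetical helper: the per-sample loop that vmap(index, in_dims=(None, 0))
    # is meant to be equivalent to. x is not batched; y is sliced along dim 0,
    # and the per-sample results are stacked along a new dim 0.
    return torch.stack([index(x, y[i]) for i in range(y.shape[0])])


x = torch.arange(3)
y1 = torch.tensor([[0, 2], [1, 2]])
y2 = torch.tensor([0, 2])

print(loop_reference(x, y1).tolist())  # [[0, 2], [1, 2]]
print(loop_reference(x, y2).tolist())  # [0, 2]
```

Comparing the vmap output against a loop like this is one way to check a candidate fix for both index shapes.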

This is related to #115347 but not exactly the same.
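Until this is resolved, one possible workaround (a sketch, not something proposed in this thread) is to flatten the index so the per-sample index is never 0-dim, which sends the 0-dim case through the same 1-D indexing that already works for `y1`, and then reshape the result back. `index_flat` is a hypothetical name, and the final reshape assumes indexing along dim 0 of `x`, as `x[y]` does.

```python
import torch


def index_flat(x, y):
    # Hypothetical workaround: make the per-sample index at least 1-D so that
    # x[...] never receives a 0-dim index tensor, then restore the shape that
    # x[y] would have had (index shape followed by the remaining dims of x).
    out = x[y.reshape(-1)]
    return out.reshape(tuple(y.shape) + tuple(x.shape[1:]))


batched_index = torch.vmap(index_flat, (None, 0))

x = torch.arange(3)
y1 = torch.tensor([[0, 2], [1, 2]])
y2 = torch.tensor([0, 2])

result1 = batched_index(x, y1)
result2 = batched_index(x, y2)
print(result1.tolist())  # [[0, 2], [1, 2]]
print(result2.tolist())  # [0, 2]
```

The extra reshape keeps the per-sample output shape identical to plain `x[y]`, so the workaround can be dropped in without changing callers.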

Versions

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.11.0
[pip3] torch==2.4.0a0+git9c4fc5f
[conda] magma-cuda121             2.6.1                         1    pytorch
[conda] magma-cuda124             2.6.1                         1    pytorch
[conda] mkl-include               2024.1.0              intel_691    intel
[conda] mkl-static                2024.1.0              intel_691    intel
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] optree                    0.11.0                   pypi_0    pypi
[conda] torch                     2.4.0a0+git9c4fc5f           dev_0    <develop>

cc @zou3519 @Chillee @samdow @kshitij12345 @janeyx99

janeyx99 added the triaged and module: functorch labels on Apr 18, 2024