Closed
Labels: `oncall: jit`, `triaged`
Description
🐛 Bug
Unable to call `__getitem__` of a user-defined class in a TorchScript function.
To Reproduce
```python
import torch

@torch.jit.script
class Foo(object):
    def __init__(self, value: torch.Tensor):
        self.value = value

    def __getitem__(self, item: torch.Tensor):
        updated_value = self.value[item]
        return Foo(updated_value)

@torch.jit.script
def bar(v: Foo, index: torch.Tensor):
    return v[index]

if __name__ == "__main__":
    a = torch.tensor([0, 1, 2, 3, 4, 5])
    slice_index = torch.tensor([0, 2, 5])
    b = Foo(a)
    b_slice = b[slice_index]  # this one works
    print(b_slice.value)
    print(bar(b, slice_index).value)  # this one does not work with @torch.jit.script
```

Expected behavior
`bar(b, slice_index)` should return the same `Foo` as `b[slice_index]`. Instead, compiling `bar` fails with:
```
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
RuntimeError:
'__torch__.Foo' object is not subscriptable:
at demo_class.py:14:11
@torch.jit.script
def bar(v: Foo, index: torch.Tensor):
    return v[index]
           ~~~~~~~ <--- HERE
```
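For reference, in plain Python the subscript `v[index]` is resolved by looking up `__getitem__` on the object's type, so it behaves like `type(v).__getitem__(v, index)`. That is the dispatch `bar` relies on, and the one the TorchScript compiler rejects in this build. The sketch below shows the behavior expected from `bar` without TorchScript; it is an illustration only, and it assumes Python lists in place of tensors so that it runs without PyTorch installed.

```python
class Foo:
    """Plain-Python stand-in for the issue's Foo (a list replaces torch.Tensor)."""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        # Gather the elements at the given positions (as tensor indexing with
        # an index tensor would) and wrap the result in a new Foo.
        return Foo([self.value[i] for i in item])


def bar(v, index):
    # Python desugars v[index] into type(v).__getitem__(v, index).
    return v[index]


b = Foo([0, 1, 2, 3, 4, 5])
slice_index = [0, 2, 5]
print(bar(b, slice_index).value)                  # [0, 2, 5]
print(type(b).__getitem__(b, slice_index).value)  # [0, 2, 5]
```

With real tensors, `self.value[item]` performs the same gather when `item` is an integer index tensor, so the list version mirrors the result the reporter expects from the scripted `bar`.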
Environment
PyTorch version: 1.3.0.dev20190902
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 1080
Nvidia driver version: 430.26
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] numpy==1.15.0
[pip3] torch==1.3.0.dev20190902
[pip3] torchvision==0.2.2.post3
[conda] Could not collect
cc @suo