Unable to subscript a user-defined class in a TorchScript function #25637

@Godricly


🐛 Bug

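Subscripting an instance of a user-defined TorchScript class (v[index], which should dispatch to the class's __getitem__ method) works in eager Python but fails to compile inside a @torch.jit.script function.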
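For context: in eager Python, the subscript v[index] is resolved by looking up __getitem__ on the object's type, which is why the eager call in the reproducer works. The error suggests that the TorchScript compiler in this build does not map a subscript on a class type to that class's __getitem__. A minimal torch-free sketch of the eager behaviour, using a hypothetical list-backed Foo as a stand-in for the tensor-backed class in the reproducer:

```python
class Foo:
    """Hypothetical list-backed stand-in for the tensor-backed Foo in the report."""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        # Gather the elements at the given indices, like self.value[item]
        # does for an index tensor in the original class.
        return Foo([self.value[i] for i in item])


b = Foo([0, 1, 2, 3, 4, 5])
idx = [0, 2, 5]

# The subscript operator looks up __getitem__ on the object's type,
# so these two expressions do the same thing.
via_subscript = b[idx]
via_dunder = type(b).__getitem__(b, idx)

print(via_subscript.value)  # [0, 2, 5]
print(via_dunder.value)     # [0, 2, 5]
```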

To Reproduce

import torch

@torch.jit.script
class Foo(object):
    def __init__(self, value: torch.Tensor):
        self.value = value

    def __getitem__(self, item: torch.Tensor):
        updated_value = self.value[item]
        return Foo(updated_value)

@torch.jit.script 
def bar(v: Foo, index: torch.Tensor):
    return v[index]

if __name__ == "__main__":
    a = torch.tensor([0,1,2,3,4,5])
    slice_index = torch.tensor([0,2,5])
    b = Foo(a)
    b_slice = b[slice_index]  # works: eager Python dispatches to Foo.__getitem__
    print(b_slice.value)
    print(bar(b, slice_index).value)  # fails: compiling bar with @torch.jit.script raises before this line runs

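Expected behavior

bar should compile and return the same result as the eager call b[slice_index]. Instead, decorating bar with @torch.jit.script fails with: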

    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
RuntimeError: 
'__torch__.Foo' object is not subscriptable:
at demo_class.py:14:11
@torch.jit.script
def bar(v: Foo, index: torch.Tensor):
    return v[index]
           ~~~~~~~ <--- HERE
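A possible workaround, sketched under the assumption that calls to ordinary named methods on TorchScript classes compile even where the subscript does not: move the indexing logic into a regular method and call that from the scripted function. The method name select is a hypothetical choice, and this sketch has not been checked against 1.3.0.dev20190902.

```python
import torch


@torch.jit.script
class Foo(object):
    def __init__(self, value: torch.Tensor):
        self.value = value

    # Hypothetical named method holding the indexing logic, so scripted
    # code can call it without relying on subscript support for classes.
    def select(self, item: torch.Tensor):
        return Foo(self.value[item])

    # Keep __getitem__ so eager code can still write b[index].
    def __getitem__(self, item: torch.Tensor):
        return self.select(item)


@torch.jit.script
def bar(v: Foo, index: torch.Tensor):
    # Call the named method instead of writing v[index].
    return v.select(index)


a = torch.tensor([0, 1, 2, 3, 4, 5])
slice_index = torch.tensor([0, 2, 5])
print(bar(Foo(a), slice_index).value)
```

Only the body of bar changes relative to the reproducer; eager code can keep using b[slice_index].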

Environment

PyTorch version: 1.3.0.dev20190902
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 1080
Nvidia driver version: 430.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.15.0
[pip3] torch==1.3.0.dev20190902
[pip3] torchvision==0.2.2.post3
[conda] Could not collect

cc @suo

Labels

oncall: jit, triaged
