
module 'torch._tensor' has no attribute 'split' when using torch.compile under a torch.device context #160077

@kshitij12345

Description


🐛 Describe the bug

Repro

```python
import torch

def f(xs):
    return xs.split(1, dim=0)

# Custom backend: print the captured FX graph, then run it unchanged.
def backend(gm, inps):
    gm.print_readable()

    # Output of gm.print_readable():
    # class GraphModule(torch.nn.Module):
    #     def forward(self, L_xs_: "f32[2, 2]"):
    #         l_xs_ = L_xs_
    #
    #         # File: /usr/local/lib/python3.12/dist-packages/torch/utils/_device.py:100 in __torch_function__, code: return func(*args, **kwargs)
    #         split = torch._tensor.split(l_xs_, 1, dim = 0);  l_xs_ = None
    #         getitem: "f32[1, 2]" = split[0]
    #         getitem_1: "f32[1, 2]" = split[1];  split = None
    #         return (getitem, getitem_1)
    return gm

with torch.device("cuda"):
    xs = torch.randn(2, 2, device="cuda")

    # Eager works
    f(xs)

    # This fails with `module 'torch._tensor' has no attribute 'split'`
    torch.compile(f, backend=backend)(xs)

# Outside the device context, the same call works:
# torch.compile(f, backend=backend)(xs)
```

Error

```
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 414, in __call__
    raise e.with_traceback(None)  # noqa: B904
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch._tensor' has no attribute 'split'
```
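One possible explanation, offered as a hypothesis rather than a confirmed diagnosis: `Tensor.split` is a Python method defined inside `torch/_tensor.py`, so its `__module__` is `torch._tensor`, but it is only reachable as `torch._tensor.Tensor.split`. If the generated graph code spells the target as `<__module__>.<__name__>`, the class prefix is dropped and the lookup fails exactly as in the traceback. The torch-free sketch below models that naming issue; the module `fake_tensor` and the helpers `name_from_dunder_name` / `name_from_qualname` are hypothetical stand-ins, not FX internals.

```python
import sys
import types
from functools import reduce

# Hypothetical stand-in for a module like `torch._tensor`: `split` is defined
# as a method on a class, not as a module-level function.
SOURCE = '''
class Tensor:
    def split(self, split_size, dim=0):
        return ("split", split_size, dim)
'''

fake_mod = types.ModuleType("fake_tensor")
exec(SOURCE, fake_mod.__dict__)  # functions pick up __module__ == "fake_tensor"
sys.modules["fake_tensor"] = fake_mod

method = fake_mod.Tensor.split  # __name__ == "split", __qualname__ == "Tensor.split"


def name_from_dunder_name(fn):
    # Naming scheme under suspicion: module + __name__ (drops "Tensor.").
    return f"{fn.__module__}.{fn.__name__}"


def name_from_qualname(fn):
    # Alternative scheme: module + __qualname__ (keeps the class prefix).
    return f"{fn.__module__}.{fn.__qualname__}"


def resolve(dotted, module_name):
    # Walk the attribute path after the module prefix, as generated code would.
    attr_path = dotted[len(module_name) + 1:]
    return reduce(getattr, attr_path.split("."), sys.modules[module_name])


results = {}
for namer in (name_from_dunder_name, name_from_qualname):
    dotted = namer(method)
    try:
        results[dotted] = resolve(dotted, method.__module__) is method
    except AttributeError as exc:
        results[dotted] = f"AttributeError: {exc}"
    print(dotted, "->", results[dotted])
```

Under this model, only the `__qualname__`-based spelling resolves back to the original function, which would also explain why eager mode is unaffected: eager calls the function object directly and never round-trips through a dotted name.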

Versions

torch version - 2.8

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @Lucaskabela @zou3519
