Closed
Labels
actionable, module: dynamo, module: vllm, oncall: pt2, triaged, vllm-compile
Description
🐛 Describe the bug
Repro
```python
import torch

def f(xs):
    return xs.split(1, dim=0)

def backend(gm, inps):
    gm.print_readable()
    # class GraphModule(torch.nn.Module):
    #     def forward(self, L_xs_: "f32[2, 2]"):
    #         l_xs_ = L_xs_
    #
    #         # File: /usr/local/lib/python3.12/dist-packages/torch/utils/_device.py:100 in __torch_function__, code: return func(*args, **kwargs)
    #         split = torch._tensor.split(l_xs_, 1, dim = 0); l_xs_ = None
    #         getitem: "f32[1, 2]" = split[0]
    #         getitem_1: "f32[1, 2]" = split[1]; split = None
    #         return (getitem, getitem_1)
    return gm

with torch.device("cuda"):
    xs = torch.randn(2, 2, device="cuda")
    # Eager works
    f(xs)
    # This fails with `module 'torch._tensor' has no attribute 'split'`
    torch.compile(f, backend=backend)(xs)

# Outside of device context, this works
# torch.compile(f, backend=backend)(xs)
```

Error
```
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 414, in __call__
    raise e.with_traceback(None)  # noqa: B904
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'torch._tensor' has no attribute 'split'
```

Versions
torch version - 2.8
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @Lucaskabela @zou3519
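The traceback is consistent with the generated graph code naming the call target as `torch._tensor.split`, i.e. the defining module plus the bare function name. `Tensor.split` is presumably defined as a method of class `Tensor` inside `torch/_tensor.py`, so its `__module__` is `torch._tensor` while its `__qualname__` is `Tensor.split`, and the module itself has no top-level `split`. The standard-library-only sketch below reproduces that failure mode under this assumption. The module `fake_tensor_mod` and the helpers `resolve` / `try_resolve` are hypothetical; this is not how Dynamo or FX actually emit call targets.

```python
import importlib
import sys
import types

# Hypothetical module standing in for torch._tensor: it defines a class
# Tensor, but has no module-level attribute named "split".
fake_mod = types.ModuleType("fake_tensor_mod")


class Tensor:
    def split(self, split_size, dim=0):
        return f"split({split_size}, dim={dim})"


# Pretend Tensor (and its method) were defined inside fake_tensor_mod.
Tensor.__module__ = "fake_tensor_mod"
Tensor.split.__module__ = "fake_tensor_mod"
fake_mod.Tensor = Tensor
sys.modules["fake_tensor_mod"] = fake_mod


def resolve(module_name, attr_path):
    """Import module_name, then follow the dotted attr_path on it."""
    obj = importlib.import_module(module_name)
    for attr in attr_path.split("."):
        obj = getattr(obj, attr)
    return obj


def try_resolve(module_name, attr_path):
    """Like resolve(), but return the AttributeError instead of raising it."""
    try:
        return resolve(module_name, attr_path)
    except AttributeError as exc:
        return exc


fn = Tensor.split
by_name = (fn.__module__, fn.__name__)          # module + bare name
by_qualname = (fn.__module__, fn.__qualname__)  # module + class-qualified name

broken = try_resolve(*by_name)     # looks up fake_tensor_mod.split
fixed = try_resolve(*by_qualname)  # looks up fake_tensor_mod.Tensor.split

print(broken)                 # module 'fake_tensor_mod' has no attribute 'split'
print(fixed is Tensor.split)  # True
```

In this toy setup, the path built from `__name__` drops the class and fails with the same kind of `AttributeError` as the report, while the path built from `__qualname__` resolves to the method. This only illustrates the suspected mechanism; it does not show where in the compile stack the name is produced.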