This repository was archived by the owner on Aug 1, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 129
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
[inductor] index_put - XLNetLMHeadModel #1356
Copy link
Copy link
Closed
Labels
Description
benchmarks/huggingface.py --training -dcuda --accuracy --training --inductor --only=XLNetLMHeadModel
Error
RuntimeError: Overloaded torch operator invoked from Python failed to many any schema:
aten::index_put_() Expected a value of type 'List[Optional[Tensor]]' for argument 'indices' but instead found type 'tuple'.
Position: 1
Value: (slice(None, None, None), slice(None, None, None), slice(None, None, None), FakeTensor(FakeTensor(..., device='meta', size=(512,), dtype=torch.int64), cuda:0))
Declaration: aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
Cast error details: Unable to cast slice(None, None, None) to Tensor
aten::index_put_() Expected a value of type 'List[Tensor]' for argument 'indices' but instead found type 'tuple'.
Position: 1
Value: (slice(None, None, None), slice(None, None, None), slice(None, None, None), FakeTensor(FakeTensor(..., device='meta', size=(512,), dtype=torch.int64), cuda:0))
Declaration: aten::index_put_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
Cast error details: Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)
Repro
import torch
import torchdynamo
from torch import tensor, device
import torch.fx as fx
from torchdynamo.testing import rand_strided
from math import inf
from torchdynamo.debug_utils import run_fwd_maybe_bwd
# One (shape, stride, dtype, device, requires_grad) spec per input tensor.
# Presumably these map positionally onto the module's forward() parameters
# via run_fwd_maybe_bwd -- confirm against that helper.
args = [
    ((512, 1, 16, 64), (1024, 1024, 64, 1), torch.float32, 'cuda', True),
    ((1024, 1, 16, 64), (1024, 1024, 64, 1), torch.float32, 'cuda', True),
    ((16, 64), (64, 1), torch.float32, 'cuda', True),
    ((512,), (1,), torch.int64, 'cuda', False),
]
# Replace every spec with a tensor built by rand_strided, then set its
# requires_grad flag in place as requested by the spec.
args = [
    rand_strided(shape, stride, dtype, dev_name).requires_grad_(needs_grad)
    for shape, stride, dtype, dev_name, needs_grad in args
]
from torch.nn import *
class Repro(torch.nn.Module):
    """Small module that replays the extracted graph for this repro.

    The forward pass adds a bias, contracts two tensors with einsum, then
    reshapes, slices, reshapes again and finally gathers columns with
    ``torch.index_select``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, einsum_9, einsum_12, self_self_transformer_layer_1__rel_attn_r_r_bias, arange_3):
        """Run the captured graph and return a 1-tuple with the result.

        The hard-coded reshapes fix the sizes the inputs must have: the
        einsum output ``bnij`` must hold ``1 * 16 * 1024 * 512`` elements,
        so with the inputs used by this script ``einsum_9`` is
        ``(512, 1, 16, d)`` and ``einsum_12`` is ``(1024, 1, 16, d)``.

        Args:
            einsum_9: Left einsum operand (``ibnd``), before the bias add.
            einsum_12: Right einsum operand (``jbnd``).
            self_self_transformer_layer_1__rel_attn_r_r_bias: Bias that is
                broadcast-added to ``einsum_9``.
            arange_3: Integer index tensor selecting entries along dim 3.

        Returns:
            A tuple ``(index_select_1,)``.
        """
        # Intermediates are rebound to None right after their last use,
        # exactly as in the captured graph, so they can be freed early.
        add_7 = einsum_9 + self_self_transformer_layer_1__rel_attn_r_r_bias
        einsum_9 = self_self_transformer_layer_1__rel_attn_r_r_bias = None

        einsum_14 = torch.functional.einsum('ibnd,jbnd->bnij', add_7, einsum_12)
        add_7 = einsum_12 = None

        reshape_2 = einsum_14.reshape(1, 16, 1024, 512)
        einsum_14 = None

        # Equivalent to reshape_2[:, :, 1:, :]: drop the first row of dim 2.
        getitem_4 = reshape_2[(slice(None, None, None), slice(None, None, None), slice(1, None, None), slice(None, None, None))]
        reshape_2 = None

        reshape_3 = getitem_4.reshape(1, 16, 512, 1023)
        getitem_4 = None

        index_select_1 = torch.index_select(reshape_3, 3, arange_3)
        reshape_3 = arange_3 = None

        return (index_select_1,)
# Eager module on the GPU, plus the same module wrapped by torchdynamo with
# the "aot_inductor_debug" backend.
mod = Repro().cuda()
opt_mod = torchdynamo.optimize("aot_inductor_debug")(mod)

# Run the eager and the optimized module on the same inputs with autocast
# explicitly disabled; `ref` and `res` hold whatever run_fwd_maybe_bwd returns.
with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)