Closed
Labels
module: dynamic shapes, oncall: pt2, triaged
Description
🐛 Describe the bug
import torch

@torch.compile(dynamic=True)
def fn(x, sections):
    return torch.dsplit(x, sections)

fn(torch.randn(4, 4, 4), [1, 2, 3])
  File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 890, in wrap_fake_exception
    return fn()
  File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 1301, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 1366, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 1353, in run_node
    return node.target(*args, **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method dsplit of type object at 0x7fc96edb3fa0>(*(FakeTensor(..., size=(s0, s0, s0)), [1, s1, s2]), **{}):
dsplit() received an invalid combination of arguments - got (FakeTensor, immutable_list), but expected one of:
 * (Tensor input, int sections)
      didn't match because some of the arguments have invalid types: (FakeTensor, immutable_list of [int, SymInt, SymInt])
 * (Tensor input, tuple of ints indices)
      didn't match because some of the arguments have invalid types: (FakeTensor, immutable_list of [int, SymInt, SymInt])

from user code:
  File "/scratch/anijain/work/pytorch/examples/dsplit.py", line 5, in fn
    return torch.dsplit(x, sections)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
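
For reference, here is a minimal eager-mode sketch of what the failing call is expected to return. It deliberately does not use torch.compile, so it says nothing about whether the failure above is fixed. The comparison against torch.tensor_split(..., dim=2) relies on the documented equivalence for index lists and is included only as a sanity check; the variable names are illustrative.

```python
import torch

x = torch.randn(4, 4, 4)
sections = [1, 2, 3]

# Eager mode: split along dim 2 at indices 1, 2 and 3.
eager_out = torch.dsplit(x, sections)

# Documented equivalent when the second argument is a list of indices.
ref_out = torch.tensor_split(x, sections, dim=2)

shapes = [tuple(t.shape) for t in eager_out]
matches = all(torch.equal(a, b) for a, b in zip(eager_out, ref_out))
print(shapes, matches)
```

In eager mode the list is accepted as plain Python ints. In the traceback above, the same list reaches dsplit as [1, s1, s2], that is, a mix of int and SymInt, which neither overload accepts.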
cc @ezyang @msaroufim @wconstab @bdhirsh
Versions
N/A