[torch.compile] WRONG VALUE for split+cat
#99686
Labels
module: inductor
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Comments
Besides, this bug can even make some invalid split sizes run:

```python
import torch

torch.manual_seed(420)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input):
        split_node = torch.split(input, [2, 1, 1], dim=1)
        cat_node = torch.cat([split_node[0], split_node[1], split_node[2]], dim=1)
        return cat_node

input_tensor = torch.randn(1, 5)
print(input_tensor)
# tensor([[-1.6977, 0.6374, 0.0781, -0.4140, 1.5172]])

func = Model().to('cpu')
jit_func = torch.compile(func)
res2 = jit_func(input_tensor)
print(res2)
# tensor([[-1.6977, 0.6374, 0.0781, -0.4140, 1.5172]])

res1 = func(input_tensor)
print(res1)
# RuntimeError: split_with_sizes expects split_sizes to sum exactly to 5 (input tensor's size at dimension 1), but got split_sizes=[2, 1, 1]
```

The split sizes [2, 1, 1] sum to 4, not 5, so eager mode raises a RuntimeError, while the compiled model silently returns the input.
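The eager-mode error above comes from a simple invariant: the split sizes must sum exactly to the size of the split dimension. A plain-Python sketch of that check (hypothetical helper name, not PyTorch's actual code):

```python
def check_split_sizes(dim_size, split_sizes):
    # Eager-mode invariant: sizes must sum exactly to the dimension size.
    if sum(split_sizes) != dim_size:
        raise ValueError(
            f"split_with_sizes expects split_sizes to sum exactly to {dim_size}, "
            f"but got split_sizes={split_sizes}"
        )

check_split_sizes(5, [2, 1, 2])  # valid: 2 + 1 + 2 == 5
try:
    check_split_sizes(5, [2, 1, 1])  # the invalid sizes from this repro: sum is 4
except ValueError as e:
    print(e)
```

This is the check that eager mode enforces but the inductor split+cat pass skipped before the fix below.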
yanboliang added the triaged, module: inductor, and oncall: pt2 labels on Apr 21, 2023
This was referenced Apr 21, 2023
XiaobingSuper added a commit that referenced this issue on Apr 21, 2023

…t_with_sizes" Fix #99686: in eager mode, if the given sizes do not meet the requirements, an error is reported, but inductor can still run. I think we need to align inductor's behavior with eager mode; after this PR the behavior will be:

```
Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}): Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:
```

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
XiaobingSuper added a commit that referenced this issue on Apr 21, 2023 (same "…t_with_sizes" message and traceback as above)
XiaobingSuper added a commit that referenced this issue on Apr 21, 2023 (same message and traceback as above, titled "…er mode for split_with_sizes")
XiaobingSuper added a commit that referenced this issue on Apr 21, 2023

…the split output's order" We should make sure the cat order aligns with the split output's order before removing the cat operation. Fix #99686. cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire [ghstack-poisoned]
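The commit above describes the core condition of the fix. A tiny sketch of that condition (hypothetical names, not inductor's actual pass): the split+cat pair may only be collapsed into a no-op when cat consumes the split outputs in their original order.

```python
def can_remove_cat(split_outputs, cat_inputs):
    # The cat is a no-op only if it consumes exactly the split outputs,
    # in the same order the split produced them.
    return cat_inputs == split_outputs

splits = ["s0", "s1", "s2"]
print(can_remove_cat(splits, ["s0", "s1", "s2"]))  # True: safe to remove the cat
print(can_remove_cat(splits, ["s2", "s0", "s1"]))  # False: the cat reorders, keep it
```

Before the fix, the pass removed the cat without this order check, which is why a reordering split+cat compiled down to the identity.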
XiaobingSuper added a commit that referenced this issue on Apr 21, 2023 (same "…t_with_sizes" message and traceback as above)
pytorchmergebot pushed a commit that referenced this issue on Apr 25, 2023

(#99702) Fix #99686: align inductor's behavior with eager mode when the given split sizes do not meet the requirements (same message and traceback as the first commit above). Pull Request resolved: #99702. Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/jansel
🐛 Describe the bug
torch.compile returns a WRONG VALUE for split+cat.
The model should swap the (0, 1) and (2, 3, 4) elements, but torch.compile returns the original tensor value. It may be caused by a bug in splitwithsizes_cat_replace.
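As a plain-Python illustration of the reported behavior (not the issue's original repro), a split+cat that concatenates the two halves in swapped order should reorder the elements; under the buggy pass, the compiled model returned the input unchanged instead:

```python
def split_cat_swap(row):
    # Split the row into elements (0, 1) and (2, 3, 4), then cat in swapped order.
    a, b = row[:2], row[2:]
    return b + a

print(split_cat_swap([0, 1, 2, 3, 4]))  # expected eager result: [2, 3, 4, 0, 1]
```

The wrong value arises because the pass treated any cat of split outputs as the identity, regardless of the order of the cat inputs.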
Versions
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire