
[torch.compile] WRONG VALUE for split+cat #99686

Closed
payphon opened this issue Apr 21, 2023 · 1 comment
Assignees
XiaobingSuper
Labels
module: inductor · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


payphon commented Apr 21, 2023

🐛 Describe the bug

torch.compile returns the WRONG VALUE for split+cat:

```python
import torch

torch.manual_seed(420)

class Model(torch.nn.Module):

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        x = torch.split(x, [3, 2, 3], dim=1)
        x = torch.cat([x[1], x[0], x[2]], dim=1)
        return x

input_tensor = torch.randn(1, 8)

func = Model().to('cpu')

print(input_tensor)
# tensor([[-1.6977,  0.6374,  0.0781, -0.4140,  1.5172,  0.0473,  0.8435, -0.2261]])

res1 = func(input_tensor)
print(res1)
# tensor([[-0.4140,  1.5172, -1.6977,  0.6374,  0.0781,  0.0473,  0.8435, -0.2261]])

jit_func = torch.compile(func)
res2 = jit_func(input_tensor)
print(res2)
# tensor([[-1.6977,  0.6374,  0.0781, -0.4140,  1.5172,  0.0473,  0.8435, -0.2261]])
```

The model should swap the (0, 1) and (2, 3, 4) element groups, but torch.compile returns the original tensor unchanged.

It may be caused by a bug in splitwithsizes_cat_replace, the pass that replaces a split followed by a cat with the original tensor: it apparently does not check that the cat consumes the split outputs in their original order.
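For intuition, here is a minimal sketch of the check such a replacement pass needs (split_cat_is_identity is a hypothetical helper written for illustration, not the actual pass): the split+cat pair is only a no-op when the sizes cover the whole dimension and the cat uses the split outputs in order.

```python
import torch

# Illustrative sketch: a split followed by a cat is only an identity, and
# therefore safe to remove, when the cat consumes every split output exactly
# once and in the original order.
def split_cat_is_identity(split_outputs, cat_inputs, split_sizes, dim_size):
    # The split sizes must cover the whole dimension ...
    if sum(split_sizes) != dim_size:
        return False
    # ... and the cat inputs must be the split outputs, in order.
    return len(cat_inputs) == len(split_outputs) and all(
        a is b for a, b in zip(cat_inputs, split_outputs)
    )

x = torch.arange(8).reshape(1, 8)
parts = torch.split(x, [3, 2, 3], dim=1)

# Reordered cat, as in the repro: not an identity, so the pass must keep it.
print(split_cat_is_identity(parts, [parts[1], parts[0], parts[2]], [3, 2, 3], x.size(1)))  # False
# In-order cat: an identity, safe to replace with x itself.
print(split_cat_is_identity(parts, list(parts), [3, 2, 3], x.size(1)))  # True
```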

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230419+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: 14.0.0-1ubuntu1
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.16 (main, Mar  8 2023, 14:00:05)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.5-051905-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 510.108.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          24
On-line CPU(s) list:             0-23
Vendor ID:                       GenuineIntel
Model name:                      12th Gen Intel(R) Core(TM) i9-12900K
CPU family:                      6
Model:                           151
Thread(s) per core:              2
Core(s) per socket:              16
Socket(s):                       1
Stepping:                        2
CPU max MHz:                     6700.0000
CPU min MHz:                     800.0000
BogoMIPS:                        6374.40
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault cat_l2 invpcid_single cdp_l2 ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdt_a rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       640 KiB (16 instances)
L1i cache:                       768 KiB (16 instances)
L2 cache:                        14 MiB (10 instances)
L3 cache:                        30 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-23
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+46672772b4
[pip3] torch==2.1.0.dev20230419+cu118
[pip3] torchaudio==2.1.0.dev20230419+cu118
[pip3] torchvision==0.16.0.dev20230419+cu118
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] pytorch-triton            2.1.0+46672772b4          pypi_0    pypi
[conda] torch                     2.1.0.dev20230419+cu118          pypi_0    pypi
[conda] torchaudio                2.1.0.dev20230419+cu118          pypi_0    pypi
[conda] torchvision               0.16.0.dev20230419+cu118          pypi_0    pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire


payphon commented Apr 21, 2023

Besides, this bug can even make an invalid split succeed:

```python
import torch

torch.manual_seed(420)

class Model(torch.nn.Module):

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input):
        split_node = torch.split(input, [2, 1, 1], dim=1)
        cat_node = torch.cat([split_node[0], split_node[1], split_node[2]], dim=1)
        return cat_node

input_tensor = torch.randn(1, 5)
print(input_tensor)
# tensor([[-1.6977,  0.6374,  0.0781, -0.4140,  1.5172]])

func = Model().to('cpu')

jit_func = torch.compile(func)
res2 = jit_func(input_tensor)
print(res2)
# tensor([[-1.6977,  0.6374,  0.0781, -0.4140,  1.5172]])

res1 = func(input_tensor)
print(res1)
# RuntimeError: split_with_sizes expects split_sizes to sum exactly to 5 (input tensor's size at dimension 1), but got split_sizes=[2, 1, 1]
```
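In eager mode the wrong sizes are rejected up front. A simplified sketch of that validation (checked_split_with_sizes is a hypothetical wrapper, not PyTorch's actual code); on the nightly above, the compiled graph skipped this check entirely:

```python
import torch

def checked_split_with_sizes(tensor, split_sizes, dim=0):
    # Simplified version of the check eager mode performs before splitting:
    # the sizes must sum exactly to the length of the split dimension.
    if sum(split_sizes) != tensor.size(dim):
        raise RuntimeError(
            f"split_with_sizes expects split_sizes to sum exactly to "
            f"{tensor.size(dim)} (input tensor's size at dimension {dim}), "
            f"but got split_sizes={split_sizes}"
        )
    return torch.split(tensor, split_sizes, dim=dim)

t = torch.randn(1, 5)
checked_split_with_sizes(t, [2, 1, 1], dim=1)  # raises, as eager mode does
```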

@XiaobingSuper XiaobingSuper self-assigned this Apr 21, 2023
@yanboliang yanboliang added the module: inductor, oncall: pt2, and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Apr 21, 2023
XiaobingSuper added a commit that referenced this issue Apr 21, 2023
…t_with_sizes
XiaobingSuper added a commit that referenced this issue Apr 21, 2023
…the split output's order"



We should make sure the cat order aligns with the split output's order before removing the cat operation (a sketch of this check follows this commit message). Fixes #99686.

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
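A hedged sketch of what such a guard could look like over a torch.fx graph (cat_preserves_split_order is illustrative, not the code from the PR): before replacing split+cat with its input, verify that every cat input is a getitem on the same split node and that the indices run 0, 1, 2, ... in order.

```python
import operator
import torch
import torch.fx as fx

def cat_preserves_split_order(cat_node: fx.Node) -> bool:
    # The first positional argument of torch.cat is the list of input nodes.
    tensors = cat_node.args[0]
    indices = []
    for inp in tensors:
        # Every cat input must be a getitem on a split output.
        if inp.op != "call_function" or inp.target is not operator.getitem:
            return False
        indices.append(inp.args[1])
    # All getitems must read the same split node, with indices 0, 1, 2, ...
    same_split = len({inp.args[0] for inp in tensors}) == 1
    return same_split and indices == list(range(len(indices)))

class Reordered(torch.nn.Module):
    def forward(self, x):
        parts = torch.split(x, [3, 2, 3], dim=1)
        return torch.cat([parts[1], parts[0], parts[2]], dim=1)

gm = fx.symbolic_trace(Reordered())
cat = next(n for n in gm.graph.nodes
           if n.op == "call_function" and n.target is torch.cat)
print(cat_preserves_split_order(cat))  # False: reordered, must not be removed
```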
pytorchmergebot pushed a commit that referenced this issue Apr 25, 2023
#99702)

Fixes #99686. In eager mode, if the given split sizes do not meet the requirements, an error is reported, but Inductor can still run. We should align Inductor's behavior with eager mode; after this PR the behavior is:

```
Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1267, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/functional.py", line 189, in split
    return tensor.split(split_size_or_sections, dim)
  File "/home/xiaobing/pytorch-offical/torch/_tensor.py", line 804, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1095, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1259, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 1102, in split_with_sizes
    raise ValueError(
ValueError: Split sizes don't add up to the tensor's size in the given dimension

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1215, in get_fake_value
    return wrap_fake_exception(
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
    return fn()
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1216, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1279, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function split at 0x7f45b8402ee0>(*(FakeTensor(..., size=(1, 5)), [2, 1, 1]), **{'dim': 1}):
Split sizes don't add up to the tensor's size in the given dimension
(scroll up for backtrace)

The above exception was the direct cause of the following exception:
```

Pull Request resolved: #99702
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/jansel
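As a quick check of the new behavior, a minimal sketch (assuming a build that includes this PR; on the nightly in the original report the call still succeeds silently):

```python
import torch

def forward(x):
    parts = torch.split(x, [2, 1, 1], dim=1)  # sizes sum to 4, not 5
    return torch.cat(parts, dim=1)

compiled = torch.compile(forward)
try:
    compiled(torch.randn(1, 5))
except Exception as e:
    # On builds with the fix, tracing raises the same size-mismatch
    # error that eager mode reports.
    print(type(e).__name__, e)
```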