
[Inductor][miscompile] Input of torch.add abnormally changed when processing complex64 data #115071

@Azyka


🐛 Describe the bug

When torch.add is run through the CPU inductor backend, its input should not be modified. However, with complex64 input, the compiled graph overwrites the input tensor, so repeated calls on the same data produce different outputs. For example, after torch.add(a, a) runs, a holds twice its initial values.
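One way to detect this kind of miscompile is to snapshot each input, call the function, and compare the inputs against their snapshots. The sketch below does that with a hypothetical helper, `mutates_inputs`, which is not a PyTorch API. It is demonstrated here on eager functions only: an out-of-place `neg`/`add` pair, which should leave its input intact, and an explicit in-place `add_`, which doubles it.

```python
import torch

def mutates_inputs(fn, *tensors):
    """Hypothetical helper: True if calling fn changes any tensor argument.

    Each input is snapshotted with clone() before the call and compared
    element-wise with torch.equal afterwards.
    """
    snapshots = [t.clone() for t in tensors]
    fn(*tensors)
    return any(not torch.equal(t, s) for t, s in zip(tensors, snapshots))

def pure(x):
    # Out-of-place ops: results go to freshly allocated tensors.
    return torch.neg(x), torch.add(x, x)

def in_place(x):
    # In-place op: writes x + x back into x's own storage.
    return x.add_(x)

a = torch.ones(41, dtype=torch.complex64)
print(mutates_inputs(pure, a))      # eager out-of-place ops leave a untouched
print(mutates_inputs(in_place, a))  # a now holds 2 * a
```

To check a compiled model, pass something like `torch.compile(pure, backend='inductor')` as `fn`; based on this report, one would expect it to return True on the affected nightly.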

Error logs

```
=========================
torch_compile triggers assertion

Not equal to tolerance rtol=1e-07, atol=1e-06
at v5_0, v5_0
Mismatched elements: 41 / 41 (100%)
Max absolute difference: 6.4343762
Max relative difference: 0.5
 x: array([-5.149941+0.j, -4.457067+0.j, -3.683667+0.j, -5.871741+0.j,
       -3.181744+0.j, -4.034095+0.j, -4.833083+0.j, -5.038692+0.j,
       -6.156036+0.j, -5.223591+0.j, -4.831304+0.j, -4.913375+0.j,...
 y: array([-10.299882+0.j,  -8.914133+0.j,  -7.367335+0.j, -11.743482+0.j,
        -6.363487+0.j,  -8.068191+0.j,  -9.666165+0.j, -10.077383+0.j,
       -12.312072+0.j, -10.447183+0.j,  -9.662608+0.j,  -9.826751+0.j,...
=========================
=========================
torch_eager does not trigger assertion
=========================
```
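The numbers in this log are consistent with the input having been doubled between the two compiled runs. If a is the original input, v5_0 is neg(a) = -a in the first run and neg(2a) = -2a in the second, so the absolute difference is |a| (its maximum, 6.434, would be the largest input magnitude) and the relative difference is |a| / |2a| = 0.5 for every element. A quick NumPy check using the first four values printed in the x row:

```python
import numpy as np

# First four values of the original input, recovered from the "x" row of
# the log (x is neg(a), so a = -x). Only the real parts are nonzero.
a = np.array([5.149941, 4.457067, 3.683667, 5.871741], dtype=np.float32)

x = -a        # v5_0 from the first compiled run: neg(a)
y = -2 * a    # v5_0 from the second run, if the input was overwritten with a + a

abs_diff = np.abs(x - y)         # equals |a|
rel_diff = abs_diff / np.abs(y)  # equals |a| / |2a|

print("max abs diff:", abs_diff.max())
print("max rel diff:", rel_diff.max())
```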

Minified repro

```python
import numpy as np
from numpy import testing
import torch

DEVICE='cpu'

class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        neg = torch.neg(args[0])
        add = torch.add(args[0], args[0])
        return (neg,  add)

model_0 = Model0()
# Output names in return order: 'v5_0' is neg, 'v3_0' is add.
output_names_0 = ['v5_0', 'v3_0']

data_0 = np.random.normal(5, 1, size=(41,)).astype(np.complex64)
input_data_0 = [data_0,]

# First compiled run. torch.from_numpy does not copy, so the input tensor
# shares its buffer with data_0.
optmodel_0 = torch.compile(model_0, fullgraph=True, backend='inductor', mode=None)
model_out_0 = optmodel_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

# Second compiled run on the same list, i.e. the same NumPy array data_0.
input_data_1 = input_data_0

optmodel_1 = torch.compile(model_0, fullgraph=True, backend='inductor', mode=None)
model_out_1 = optmodel_1(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
output_1 = dict(zip(output_names_0, model_out_1))

output_name_dict = {'v3_0': 'v3_0', 'v5_0': 'v5_0'}

# The two compiled runs should agree, since neither should modify its input.
print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], atol=1e-6, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("torch_compile does not trigger assertion")
except AssertionError as e:
    print("torch_compile triggers assertion")
    print(e)
print('=========================')

# Reference: repeat the same two runs in eager mode.
model_out_0 = model_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

model_out_1 = model_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
output_1 = dict(zip(output_names_0, model_out_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], atol=1e-6, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("torch_eager does not trigger assertion")
except AssertionError as e:
    print("torch_eager triggers assertion")
    print(e)
print('=========================')
```
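The repro wraps data_0 with torch.from_numpy, which does not copy: the tensor and the NumPy array share one buffer, and input_data_1 is the same list as input_data_0. If the compiled graph writes the add result into its input buffer, data_0 would already hold 2a when the second compiled run starts, which matches the doubled v5_0 values in the log. The sketch below isolates that aliasing, with an explicit in-place add_ standing in for the compiled graph's write; the array size and values are illustrative.

```python
import numpy as np
import torch

data = np.full(4, 5.0, dtype=np.complex64)

shared = torch.from_numpy(data)  # zero-copy: shares data's buffer
copied = torch.tensor(data)      # torch.tensor always copies the values

# In-place write through the shared tensor, standing in for a compiled
# kernel that stores its output into the input buffer.
shared.add_(shared)

print(data)    # the NumPy array now holds the doubled values
print(copied)  # the independent copy still holds the original values
```

Passing an independent copy (for example `torch.tensor(v)` instead of `torch.from_numpy(v)`) to each run would keep one run's write from leaking into the next, though it would hide rather than fix the miscompile.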

Versions

PyTorch version: 2.2.0.dev20231129+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100S-PCIE-32GB
GPU 1: Tesla V100S-PCIE-32GB
GPU 2: Tesla V100S-PCIE-32GB
GPU 3: Tesla V100S-PCIE-32GB
GPU 4: Tesla V100S-PCIE-32GB
GPU 5: Tesla V100S-PCIE-32GB
GPU 6: Tesla V100S-PCIE-32GB
GPU 7: Tesla V100S-PCIE-32GB

Nvidia driver version: 525.147.05
cuDNN version: Probably one of the following:
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7763 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3529.0520
CPU min MHz: 1500.0000
BogoMIPS: 4890.78
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] onnx==1.14.1
[pip3] onnxruntime==1.15.1
[pip3] onnxsim==0.4.35
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20231129+cu118
[pip3] torchaudio==2.2.0.dev20231129+cu118
[pip3] torchvision==0.17.0.dev20231129+cu118
[pip3] triton==2.1.0
[conda] Could not collect

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

Labels: module: inductor, oncall: cpu inductor, oncall: pt2, triaged