Output mismatch of torch.gt with torch.compile when swapping output sequence #113014

@Azyka

🐛 Describe the bug

When a model has multiple outputs, changing the order in which they are returned should not affect their values. However, after swapping the output order, the result of torch.gt differs for the same input under torch.compile. Eager execution returns matching results for both orderings.

Error logs

Output of numpy.testing.assert_allclose()

=========================
torch_complie triggers assertion

Not equal to tolerance rtol=1, atol=0

Mismatched elements: 2576 / 2632 (97.9%)
 x: array([[ True,  True,  True, ...,  True,  True, False],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True, False],...
 y: array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False,  True],
       [False, False, False, ..., False, False, False],...
=========================
=========================
torch_eager does not trigger assertion
=========================
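
For context on the numbers in this log, here is a rough NumPy sketch of the same pad / mean / gt computation that the repro below performs. It is not the repro itself: `x` is a synthetic stand-in for the 56×1 float16 input, and `reference_gt` is a hypothetical helper name. Under these assumptions the 46 padded columns compare 0.5 against 0.5, so `gt` should be False there. That region holds 56 × 46 = 2576 of the 56 × 47 = 2632 elements, the same as the mismatch count reported above, which suggests the divergence is confined to the padded columns.

```python
import numpy as np

def reference_gt(x: np.ndarray, pad_left: int = 46, fill: float = 0.5) -> np.ndarray:
    """NumPy stand-in for pad -> mean(0) -> gt from the repro (hypothetical helper)."""
    # Left-pad the last dimension with a constant, like F.pad(x, (46, 0), value=0.5).
    padded = np.pad(x, ((0, 0), (pad_left, 0)), mode="constant", constant_values=fill)
    # Column-wise mean; broadcasting compares it against every row.
    mean = padded.mean(axis=0)
    return mean > padded

# Synthetic stand-in for the real input: same shape and dtype, values in a similar range.
x = np.linspace(3.0, 7.0, 56, dtype=np.float16).reshape(56, 1)
gt_ref = reference_gt(x)

print(gt_ref.shape)           # (56, 47)
print(gt_ref[:, :46].any())   # False
print(gt_ref[:, :46].size)    # 2576
```

The last column depends on the actual input values, so this sketch only says something about the padded region.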

Minified repro

import numpy as np
from numpy import testing
import torch

class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0];  _args = None
        pad = torch.nn.functional.pad(getitem, (46, 0), 'constant', value = 0.5)
        mean = pad.mean(0)
        to = pad.to(dtype = torch.bool)
        gt = torch.gt(mean, pad)
        return (to, gt)

model_0 = Model0()
output_names_0 = ['v4_0', 'v3_0']

class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        _args = args
        getitem = _args[0];  _args = None
        pad = torch.nn.functional.pad(getitem, (46, 0), 'constant', value = 0.5)
        mean = pad.mean(0)
        to = pad.to(dtype = torch.bool)
        gt = torch.gt(mean, pad)
        return (gt, to)

model_1 = Model1()
output_names_1 = ['v3_0', 'v2_0']

input_data_0 = np.array([[6.46 ], [3.229], [5.785], [5.145], [6.676], [3.998], [5.77 ], [3.09 ], [4.66 ], [6.348], [4.273], [5.133], [3.367], [5.7  ], [3.979], [5.836], [4.543], [6.586], [3.504], [3.416], [3.117], [4.58 ], [5.793], [5.24 ], [5.566], [6.688], [3.459], [6.13 ], [3.07 ], [6.824], [4.91 ], [6.938], [4.38 ], [3.69 ], [5.324], [4.957], [4.12 ], [3.271], [5.375], [4.223], [3.71 ], [3.252], [6.504], [3.713], [5.285], [4.145], [3.746], [5.414], [3.84 ], [6.08 ], [6.457], [3.57 ], [5.805], [3.318], [4.215], [4.473]], dtype=np.float16)
input_data = [input_data_0]

optmodel_0 = torch.compile(model_0, fullgraph=True, backend='inductor', mode=None)
model_out_0 = optmodel_0(*[torch.from_numpy(v).to('cpu') for v in input_data])
model_out_0 = [v.cpu().detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.cpu().detach()]
model_out_0 = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

optmodel_1 = torch.compile(model_1, fullgraph=True, backend='inductor', mode=None)
model_out_1 = optmodel_1(*[torch.from_numpy(v).to('cpu') for v in input_data])
model_out_1 = [v.cpu().detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.cpu().detach()]
model_out_1 = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in model_out_1]
output_1 = dict(zip(output_names_1, model_out_1))
output_name_dict = {'v4_0': 'v2_0', 'v3_0': 'v3_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1)
    print("torch_complie does not trigger assertion")
except AssertionError as e:
    print("torch_complie triggers assertion")
    print(e)
print('=========================')

model_out_0 = model_0(*[torch.from_numpy(v).to('cpu') for v in input_data])
model_out_0 = [v.cpu().detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.cpu().detach()]
model_out_0 = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

model_out_1 = model_1(*[torch.from_numpy(v).to('cpu') for v in input_data])
model_out_1 = [v.cpu().detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.cpu().detach()]
model_out_1 = [v.resolve_conj().numpy() if v.is_conj() else v.numpy() for v in model_out_1]
output_1 = dict(zip(output_names_1, model_out_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1)
    print("torch_eager does not trigger assertion")
except AssertionError as e:
    print("torch_eager triggers assertion")
    print(e)
print('=========================')
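
A side note on the comparison used above, not a change to the repro: NumPy documents `assert_allclose` as comparing `|actual - desired|` against `atol + rtol * |desired|`. With `rtol=1` and boolean outputs, a `False` in `actual` against a `True` in `desired` therefore falls within tolerance and is not reported, so the mismatch count may understate the real difference. A minimal sketch with made-up arrays (the helper `raises_assertion` is hypothetical) showing that `assert_array_equal` is the stricter check for boolean outputs:

```python
import numpy as np
from numpy import testing

# Made-up boolean outputs that differ in exactly one position.
actual = np.array([False, True, True])
desired = np.array([True, True, True])

def raises_assertion(check, *args, **kwargs) -> bool:
    """Return True if the given check raises AssertionError (hypothetical helper)."""
    try:
        check(*args, **kwargs)
    except AssertionError:
        return True
    return False

# |False - True| = 1 <= 0 + 1 * |True|, so the difference is tolerated.
loose_fails = raises_assertion(testing.assert_allclose, actual, desired, rtol=1)
# Swapping the arguments makes the tolerance 1 * |False| = 0, so it is caught.
swapped_fails = raises_assertion(testing.assert_allclose, desired, actual, rtol=1)
# Exact comparison flags the mismatch regardless of argument order.
strict_fails = raises_assertion(testing.assert_array_equal, actual, desired)

print(loose_fails, swapped_fails, strict_fails)  # False True True
```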

Versions

PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.10-1-pve-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7742 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4500.19
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 256 MiB (16 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] onnx==1.14.1
[pip3] onnxruntime==1.15.1
[pip3] torch==2.1.0+cu118
[pip3] torchaudio==2.1.0+cu118
[pip3] torchvision==0.16.0+cu118
[pip3] triton==2.1.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @wconstab @bdhirsh @anijain2305

Labels

high priority
module: correctness (silent): issue that returns an incorrect result silently
oncall: pt2
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
