Skip to content

[Dynamo] Bad accuracy detected for torch.flip on CUDA #131805

@bohnstingl

Description

@bohnstingl

🐛 Describe the bug

I was working on the generic_associative_scan feature and noticed that, in one unit test, a result involving torch.flip was causing problems for a specific input-sequence length (n=9). With the help of @peterbell10, I created a script that reproduces this behavior. In particular, running the attached script with this checkpoint as `python minifier_launcher.py run --strict-accuracy` raises the error below. The problem only shows up when CUDA is used; with the CPU device, the error does not appear. Could somebody please help me out?

Error logs

100.0%
E0725 22:12:38.847000 1245977 torch/_dynamo/utils.py:1595] Accuracy failed: allclose not within tol=0.001
Traceback (most recent call last):
  File "minifier_launcher.py", line 93, in <module>
    run_repro(mod, load_args, accuracy=True, command='minify', save_dir=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'checkpoints'), tracing_mode='real', check_str=None)
  File "/data_malta3_ssd/pytorch_git/torch/_dynamo/repro/after_aot.py", line 966, in run_repro
    return COMMAND_FNS[options.command](options, mod, load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/repro/after_aot.py", line 729, in repro_run
    raise AccuracyError("Bad accuracy detected")
torch._dynamo.debug_utils.AccuracyError: Bad accuracy detected

Minified repro

import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config






# NOTE(review): not referenced anywhere else in this script; presumably part of
# the standard minifier-launcher template and read by the repro tooling. Left
# as None here.
isolate_fails_code_str = None



# torch version: 2.5.0a0+gitd526743
# torch cuda version: 12.1
# torch git version: d526743dd97820f70e37ac23f3d0cd6aaad73e0b


# CUDA Info: 
# nvcc: NVIDIA (R) Cuda compiler driver 
# Copyright (c) 2005-2023 NVIDIA Corporation 
# Built on Mon_Apr__3_17:16:06_PDT_2023 
# Cuda compilation tools, release 12.1, V12.1.105 
# Build cuda_12.1.r12.1/compiler.32688072_0 

# GPU Hardware Info: 
# NVIDIA GeForce RTX 3090 : 1 


from torch.nn import *
class Repro(torch.nn.Module):
    """Minified graph reported to produce a CUDA accuracy mismatch under inductor.

    ``forward`` computes a reverse (suffix) inclusive cumulative sum of a 1-D
    tensor: ``out[i] == arg0_1[i:].sum()``. For the integer input used here,
    that equals ``torch.flip(torch.cumsum(torch.flip(arg0_1, [0]), 0), [0])``.
    It does this with a recursive even/odd pairwise scan placed between two
    ``prims.rev`` calls.

    The view sizes ([2], [4], [10]) and the final ``[0:9]`` slice hard-code an
    input length of 9, which matches the (9,) int64 tensor declared in
    ``load_args``.

    Do not hand-edit the op sequence: this exact graph is what reproduces the
    reported mismatch.
    """

    def __init__(self):
        super().__init__()

    
    
    def forward(self, arg0_1):
        """Return ``(out,)`` where ``out[i]`` is the sum of ``arg0_1[i:]``.

        The trailing ``x = None`` assignments drop references to intermediates
        as soon as they are dead. ``9223372036854775807`` (2**63 - 1) is used
        as an open-ended slice stop, i.e. "to the end".
        """
        # Reverse the input so the prefix scan below yields suffix sums.
        rev = torch.ops.prims.rev.default(arg0_1, [0]);  arg0_1 = None
        # Up-sweep level 1: add_1[m] = rev[2m] + rev[2m+1] (4 elements; rev[8] is left over).
        slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2)
        slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2)
        add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2);  slice_1 = slice_2 = None
        # Up-sweep level 2: add_2[p] = add_1[2p] + add_1[2p+1] (2 elements).
        slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2)
        slice_4 = torch.ops.aten.slice.Tensor(add_1, 0, 1, 9223372036854775807, 2)
        add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4);  slice_3 = slice_4 = None
        # Up-sweep level 3: add_3 = add_2[0] + add_2[1] (1 element).
        slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2)
        slice_6 = torch.ops.aten.slice.Tensor(add_2, 0, 1, 9223372036854775807, 2)
        add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6);  slice_5 = slice_6 = None
        # Down-sweep: interleave [add_2[0]] with add_3 (unsqueeze + cat on dim 1 + view),
        # so `view` is the inclusive scan of add_2 (2 elements).
        slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1);  add_2 = None
        unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
        unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1);  add_3 = None
        cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1);  unsqueeze = unsqueeze_1 = None
        view = torch.ops.aten.view.default(cat, [2]);  cat = None
        # Even positions of the scan of add_1: add_1[0], then view[:-1] + add_1[2::2].
        slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1)
        slice_11 = torch.ops.aten.slice.Tensor(add_1, 0, 2, 9223372036854775807, 2)
        add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11);  slice_10 = slice_11 = None
        slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1);  add_1 = None
        cat_1 = torch.ops.aten.cat.default([slice_12, add_5]);  slice_12 = add_5 = None
        # Interleave even (cat_1) and odd (view) positions, so view_1 is the
        # inclusive scan of add_1 (4 elements).
        unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1);  cat_1 = None
        unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1);  view = None
        cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1);  unsqueeze_2 = unsqueeze_3 = None
        view_1 = torch.ops.aten.view.default(cat_2, [4]);  cat_2 = None
        # Even positions of the scan of rev: rev[0], then view_1 + rev[2::2] (5 elements).
        slice_13 = torch.ops.aten.slice.Tensor(rev, 0, 2, 9223372036854775807, 2)
        add_6 = torch.ops.aten.add.Tensor(view_1, slice_13);  slice_13 = None
        slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1);  rev = None
        cat_3 = torch.ops.aten.cat.default([slice_14, add_6]);  slice_14 = add_6 = None
        # Odd positions: view_1 padded with one trailing 0 so both halves have 5 elements.
        constant_pad_nd = torch.ops.aten.constant_pad_nd.default(view_1, [0, 1], 0.0);  view_1 = None
        # Interleave to 10 elements and drop the padding element, giving the
        # inclusive scan of rev (9 elements).
        unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1);  cat_3 = None
        unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1);  constant_pad_nd = None
        cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1);  unsqueeze_4 = unsqueeze_5 = None
        view_2 = torch.ops.aten.view.default(cat_4, [10]);  cat_4 = None
        slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9);  view_2 = None
        # Undo the initial reversal: out[i] = arg0_1[i:].sum().
        rev_1 = torch.ops.prims.rev.default(slice_15, [0]);  slice_15 = None
        return (rev_1,)
        
def load_args(reader):
    """Declare the single input of ``Repro.forward`` (``arg0_1``) on *reader*.

    Registers one storage under a fixed 40-hex-digit key with a size argument
    of 72, consistent with 9 int64 elements of 8 bytes each. The storage is
    placed on ``cuda:0`` and then exposed as a 1-D int64 leaf tensor of
    shape (9,).
    """
    cuda0 = device(type='cuda', index=0)
    buf0 = reader.storage(
        'daea2d1a60a0939fea1275603b97befe81962b03',
        72,
        device=cuda0,
        dtype_hint=torch.int64,
    )
    # The tensor built from buf0 becomes arg0_1.
    reader.tensor(buf0, (9,), dtype=torch.int64, is_leaf=True)


# Version tag attached to the loader (presumably checked by the repro tooling).
load_args._version = 0
mod = Repro()

if __name__ == '__main__':
    # Imported lazily so that importing this module does not pull in the repro tooling.
    import os
    from torch._dynamo.repro.after_aot import run_repro

    # save_dir is the directory that contains this script.
    repro_dir = os.path.dirname(os.path.realpath(__file__))
    with torch.no_grad():
        run_repro(
            mod,
            load_args,
            accuracy=True,
            command='minify',
            save_dir=repro_dir,
            tracing_mode='real',
            check_str=None,
        )
        # To run it separately, do 
        # mod, args = run_repro(mod, load_args, accuracy=True, command='get_args', save_dir=os.path.dirname(os.path.realpath(__file__)), tracing_mode='real', check_str=None)
        # mod(*args)

Versions

Collecting environment information...
PyTorch version: 2.5.0a0+gitd526743
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (conda-forge gcc 10.4.0-19) 10.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090

Nvidia driver version: 550.90.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD Ryzen Threadripper 3960X 24-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 2200.000
CPU max MHz: 3800.0000
CPU min MHz: 2200.0000
BogoMIPS: 7600.39
Virtualization: AMD-V
L1d cache: 768 KiB
L1i cache: 768 KiB
L2 cache: 12 MiB
L3 cache: 128 MiB
NUMA node0 CPU(s): 0-47
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] optree==0.12.1
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0a0+gitd526743
[conda] magma-cuda121 2.6.1 1 pytorch
[conda] mkl-include 2024.1.0 intel_691 intel
[conda] mkl-static 2024.1.0 intel_691 intel
[conda] numpy 1.26.0 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0a0+gitd526743 dev_0
[conda] torchfix 0.4.0 pypi_0 pypi

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions