
ZeroDivisionError in torch.compile with torch.unfold_copy and Zero-Dimension Tensor #113026

Closed
@zoux1a

🐛 Describe the bug

Calling torch.unfold_copy with size=0 on a tensor that has a zero-length dimension raises a ZeroDivisionError under torch.compile. The same call runs without error in eager mode.

import torch

def forward(x):
    return torch.unfold_copy(dimension=1, input=x, size=0, step=7)

x = torch.rand([1, 0], dtype=torch.float32)  # input with a zero-length dimension 1
forward(x)  # eager mode: succeeds
print("build succeeded")
torch.compile(forward, mode='max-autotune', fullgraph=True)(x)  # torch.compile: raises ZeroDivisionError

Error trace (abbreviated):

build succeeded
......
 raise ZeroDivisionError("division by zero")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: ZeroDivisionError: division by zero
  target: aten.unfold.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.float32, size=[1, 0], stride=[1, 1]))
  ))
  args[1]: 1
  args[2]: 0
  args[3]: 7
......

Full error trace:

build succeeded
Traceback (most recent call last):
  File "/home/guihuan/LLM/results/torch-2/2023-11-03-15-06/repros/repro25.py", line 9, in <module>
    torch.compile(forward, mode='max-autotune',fullgraph=True)(x)# encountered a ZeroDivisionError on torch.compile mode
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 558, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 148, in _fn
    return fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 402, in _convert_frame_assert
    return _compile(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 610, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 527, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 497, in transform
    tracer.run()
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2117, in run
    super().run()
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 742, in run
    and self.step()
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 705, in step
    getattr(self, inst.opname)(inst)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2227, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 865, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 993, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1064, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1045, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/__init__.py", line 1604, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 975, in compile_fx
    return compile_fx(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1189, in compile_fx
    return aot_autograd(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 4755, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 4293, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2622, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2809, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1951, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1126, in fw_compiler_base
    return inner_compile(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/repro/after_aot.py", line 80, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/debug.py", line 300, in inner
    return fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 351, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 553, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/graph.py", line 460, in run
    return super().run(*args)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/graph.py", line 762, in run_node
    result = super().run_node(n)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/fx/interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/graph.py", line 632, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/graph.py", line 629, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/lowering.py", line 289, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/_inductor/lowering.py", line 1191, in unfold
    x.mark_reuse(sizevars.size_hint(CeilDiv(new_dim_size * size, sizes[dim])))
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/utils/_sympy/functions.py", line 236, in __new__
    return CleanDiv(base, divisor)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/sympy/core/cache.py", line 70, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/sympy/core/function.py", line 469, in __new__
    result = super().__new__(cls, *args, **options)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/sympy/core/cache.py", line 70, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/sympy/core/function.py", line 309, in __new__
    evaluated = cls.eval(*args)
  File "/home/guihuan/.conda/envs/nightly/lib/python3.9/site-packages/torch/utils/_sympy/functions.py", line 58, in eval
    raise ZeroDivisionError("division by zero")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: ZeroDivisionError: division by zero
  target: aten.unfold.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.float32, size=[1, 0], stride=[1, 1]))
  ))
  args[1]: 1
  args[2]: 0
  args[3]: 7

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
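
The traceback above ends in the unfold lowering, at the expression CeilDiv(new_dim_size * size, sizes[dim]), where sizes[dim] is 0 for the [1, 0] input. The following is only a minimal sketch of that arithmetic using plain Python integers, not the inductor code itself. The function names are hypothetical, the window-count formula is assumed to be the usual unfold one, and the guard shown is one possible way to avoid the division, not necessarily the fix that was applied upstream.

```python
def unfold_window_count(dim_size: int, size: int, step: int) -> int:
    """Number of windows along the unfolded dimension (assumed formula)."""
    return (dim_size - size) // step + 1


def reuse_hint(dim_size: int, size: int, step: int) -> int:
    """Integer ceil-division mirroring CeilDiv(new_dim_size * size, sizes[dim]).

    Raises ZeroDivisionError when dim_size == 0, as in the report.
    """
    new_dim_size = unfold_window_count(dim_size, size, step)
    return -(-(new_dim_size * size) // dim_size)


def guarded_reuse_hint(dim_size: int, size: int, step: int) -> int:
    """Hypothetical guard: skip the division for a zero-length dimension."""
    if dim_size == 0:
        return 0  # nothing to reuse in an empty dimension (assumption)
    return reuse_hint(dim_size, size, step)


if __name__ == "__main__":
    # Values from the report: sizes[dim] = 0, size = 0, step = 7
    try:
        reuse_hint(0, 0, 7)
    except ZeroDivisionError as exc:
        print("unguarded:", type(exc).__name__)
    print("guarded:", guarded_reuse_hint(0, 0, 7))
```

With the values from the report, the window count is 1 and the numerator is 0, so the only problem is the zero divisor; the guarded variant sidesteps it without changing the result for non-empty dimensions.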

Versions

PyTorch version: 2.2.0.dev20231023+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2070
GPU 1: NVIDIA GeForce RTX 2070
GPU 2: NVIDIA GeForce RTX 2070
GPU 3: NVIDIA GeForce RTX 2070

Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU E5-2630 v3 @ 2.40GHz
CPU family: 6
Model: 63
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 2
Stepping: 2
CPU max MHz: 3200.0000
CPU min MHz: 1200.0000
BogoMIPS: 4794.64
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc dtherm ida arat pln pts md_clear flush_l1d
Virtualization: VT-x
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 4 MiB (16 instances)
L3 cache: 40 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.1
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20231023+cu118
[pip3] torchaudio==2.2.0.dev20231023+cu118
[pip3] torchvision==0.17.0.dev20231023+cu118
[conda] cudatoolkit 11.8.0 h6a678d5_0 defaults
[conda] numpy 1.26.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+6e4932cda8 pypi_0 pypi
[conda] torch 2.2.0.dev20231023+cu118 pypi_0 pypi
[conda] torchaudio 2.2.0.dev20231023+cu118 pypi_0 pypi
[conda] torchvision 0.17.0.dev20231023+cu118 pypi_0 pypi

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler
