Description
🐛 Describe the bug
When I use a custom op with default arguments, `torch.compile` succeeds if the values I pass differ from the defaults, but fails if the values I pass are the same as the defaults.
```python
import torch
from typing import Optional, Sequence

@torch.library.custom_op("debug::fwd", mutates_args=())
def _fwd(t0: torch.Tensor,
         a1: Optional[float] = None,
         a2: bool = False,
         a3: bool = True) -> list[torch.Tensor]:
    return [t0.clone()]

@_fwd.register_fake
def _(t0, a1=None, a2=False, a3=True):
    return [torch.empty_like(t0)]

def _bwd_standalone(ctx, dgrads: list[torch.Tensor]):
    do = dgrads[0]
    return (do, None, None, None)

def _bwd_setup_context(ctx, inputs, output):
    pass

_fwd.register_autograd(_bwd_standalone, setup_context=_bwd_setup_context)

@torch.compile(fullgraph=True)
def fn1(t0):
    # passed value not equal to default value
    return _fwd(t0, None, True, False)

@torch.compile(fullgraph=True)
def fn2(t0):
    # passed value equal to default value, error
    return _fwd(t0, None, False, True)

t0 = torch.randn(2, 3, requires_grad=True)
fn1(t0)
print("FN1 PASSED!!!!!")
fn2(t0)
print("FN2 PASSED!!!!!")
```
Error message:

```
FN1 PASSED!!!!!
torch/autograd/graph.py:841: UserWarning: Error detected in GeneratedBackwardFor_debug_fwd_defaultBackward. Traceback of forward call that caused the error:
  File "test.py", line 32, in fn2
    return _fwd(t0, None, False, True)
  File torch/_library/custom_ops.py", line 676, in __call__
    return self._opoverload(*args, **kwargs)
 (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
...
...
...
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Expected the return from backward to be of the same structure as the inputs. Got: TreeSpec(tuple, None, [*,
  *,
  *,
  *,
  *]) (return from backward), TreeSpec(tuple, None, [*,
  *]) (inputs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Workaround: never use default arguments in a custom op.
Versions
2.10.0.dev20250910+cu126
Collecting environment information...
PyTorch version: 2.10.0.dev20250910+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39
Python version: 3.12.0 | packaged by Anaconda, Inc. | (main, Oct 2 2023, 17:29:18) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-64-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080
Nvidia driver version: 575.64.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i7-10700K CPU @ 3.80GHz
CPU family: 6
Model: 165
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 5
CPU(s) scaling MHz: 59%
CPU max MHz: 5100.0000
CPU min MHz: 800.0000
BogoMIPS: 7599.80
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp pku ospke md_clear flush_l1d arch_capabilities
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 2 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Mitigation; Microcode
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==2.3.2
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-triton==3.5.0+gitfccfc522
[pip3] torch==2.10.0.dev20250910+cu126
[pip3] torchvision==0.24.0.dev20250910+cu126
[conda] numpy 2.3.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] pytorch-triton 3.5.0+gitfccfc522 pypi_0 pypi
[conda] torch 2.10.0.dev20250910+cu126 pypi_0 pypi
[conda] torchvision 0.24.0.dev20250910+cu126 pypi_0 pypi
cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @xmfan @chauhang @penguinwu @zou3519 @bdhirsh