torch._dynamo.exc.Unsupported: comparison AutogradFunctionVariable() <built-in function is_not> ConstantVariable(NoneType) #125140

Closed
YangQun1 opened this issue Apr 29, 2024 · 1 comment
Labels
good first issue module: autograd Related to torch.autograd, and the autograd engine in general module: dynamo oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@YangQun1
Contributor

YangQun1 commented Apr 29, 2024

🐛 Describe the bug

The following code triggers a graph break caused by "torch._dynamo.exc.Unsupported: comparison AutogradFunctionVariable() <built-in function is_not> ConstantVariable(NoneType)":

import torch

class MyExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

def not_inlinable_func(x):
    x = torch.relu(x)
    assert MyExp is not None, "failed to import MyExp"
    loss = x.pow(2.0).sum()
    return loss

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 64, 1, bias=False)

    def forward(self, x):
        y = self.conv(x)
        loss = not_inlinable_func(y)
        return loss

model = MyModule()
compiled_model = torch.compile(model, backend="inductor")

input = torch.randn([1, 64, 32, 32])
loss = compiled_model(input)

However, if the assertion uses an ordinary class instead of MyExp (for example, assert torch.nn.Module is not None, "fail"), there are no graph breaks.

Can we add support for comparing torch.autograd.Function subclasses (such as MyExp is not None)?

Error logs

FAILED INLINING <code object not_inlinable_func at 0x7fcf5de49210, file "/home/quyang/my_test/experiments/test_graph_break
break_graph_if_unsupported triggered compile
Traceback (most recent call last):
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
return inner_fn(self, inst)
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
return super().call_function(tx, args, kwargs)
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
return tx.inline_user_function_return(
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_functi
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
tracer.run()
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
and self.step()
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
getattr(self, inst.opname)(inst)
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1636, in IS_OP
self.COMPARE_OP(new_inst)
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1201, in COMPARE_OP
BuiltinVariable(supported_any[op]).call_function(
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 651, in call_function
result = handler(tx, *args, **kwargs)
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1563, in _comparison
_unimplemented()
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1453, in _unimplemented
unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
File "/home/quyang/venv/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: comparison AutogradFunctionVariable() <built-in function is_not> ConstantVariable(NoneType)

Minified repro

No response

Versions

PyTorch version: 2.2.0a0+gitc023606
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.5 (ssh://gerrit.habana-labs.com:29418/tpc_llvm10 ab097983d16bb8cedaa85e13d6123696de35047a)
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 38.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] habana-torch-dataloader==1.16.0.363
[pip3] habana-torch-plugin==1.16.0.363
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] torch==2.2.0a0+gitc023606
[pip3] torch-debug==2.2.0a0+gitb658878
[pip3] torch_tb_profiler==0.4.0
[pip3] torchaudio==2.2.0+08901ad
[pip3] torchdata==0.7.1+5e6f7b7
[pip3] torchtext==0.17.0+400da5c
[pip3] torchvision==0.17.0+b2383d4

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

jbschlosser added the labels module: autograd, triaged, module: dynamo, good first issue, and module: graph breaks, then removed module: graph breaks (Apr 29, 2024)
@YangQun1
Contributor Author

I'm closing this issue: comparisons involving torch.autograd.Function subclasses are now handled correctly by the "handle_is" handler in the latest PyTorch nightly build. Thanks!
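Why this comparison does not need to break the graph can be illustrated with a toy model. This is a hedged sketch, not PyTorch's implementation: ConstVar, UserClassVar, Unsupported, known_object, and compare are hypothetical names, and the has_identity_handler flag is an assumption used only to contrast the behaviour before and after the fix. The idea is that MyExp and None are both ordinary Python objects that already exist while tracing, so an `is` / `is not` check on them can be decided at trace time and returned as a plain bool.

```python
import operator
from dataclasses import dataclass
from typing import Any


class Unsupported(Exception):
    """Hypothetical stand-in for the exception that makes the tracer break the graph."""


@dataclass(frozen=True)
class ConstVar:
    """Hypothetical trace-time wrapper around a constant such as None."""
    value: Any


@dataclass(frozen=True)
class UserClassVar:
    """Hypothetical trace-time wrapper around a user class (e.g. an autograd.Function subclass)."""
    cls: type


def known_object(var):
    # Both wrapper kinds hold a real Python object that already exists while tracing.
    return var.value if isinstance(var, ConstVar) else var.cls


def compare(left, op, right, has_identity_handler=True):
    """Trace `left op right`; has_identity_handler toggles the assumed before/after-fix behaviour."""
    if has_identity_handler and op in (operator.is_, operator.is_not):
        # Identity of two trace-time objects can be decided immediately, so the
        # result is a plain bool and nothing needs to be recorded in the graph.
        return op(known_object(left), known_object(right))
    # No handler for this pair: give up, using the same message shape as the
    # f"comparison {typestr(left)} {op} {typestr(right)}" line in the traceback.
    raise Unsupported(f"comparison {type(left).__name__}() {op} {type(right).__name__}()")


class MyExp:
    """Placeholder for the torch.autograd.Function subclass in the repro (no torch needed)."""


print(compare(UserClassVar(MyExp), operator.is_not, ConstVar(None)))  # prints True

try:
    compare(UserClassVar(MyExp), operator.is_not, ConstVar(None), has_identity_handler=False)
except Unsupported as exc:
    print(exc)  # prints: comparison UserClassVar() <built-in function is_not> ConstVar()
```

With the identity handler enabled the check folds to True during tracing; with it disabled, the toy tracer raises an error shaped like the one reported in this issue.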
