Description
🐛 Describe the bug
Hi! I found that the following snippet triggers a `Cell is empty` error in the function `match_nested_cell`:
```python
import torch


def forward():
    x = torch.zeros(torch.Size([1]), device='cuda:0')

    def subfunc():
        x[0] = backup
        if x[0] >= -1e5:
            pass

    backup = 1
    subfunc()
    return x


with torch.no_grad():
    print(forward())

fn_compiled = torch.compile(forward)
print(fn_compiled())
```
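For context, `ValueError: Cell is empty` is what CPython raises when `cell_contents` is read from a closure cell that has not been assigned yet. The repro has this shape: `subfunc` closes over `backup`, which is assigned only after `subfunc` is defined (but before it is called). Below is a minimal, torch-free sketch of that pattern; the name `forward_analog` is hypothetical, and a plain list stands in for the CUDA tensor.

```python
def forward_analog():
    x = [0]  # stands in for the tensor in the repro

    def subfunc():
        x[0] = backup  # `backup` is a free variable captured via a closure cell

    # Locate the closure cell that holds `backup`.
    names = subfunc.__code__.co_freevars
    cell = subfunc.__closure__[names.index("backup")]

    try:
        _ = cell.cell_contents  # `backup` has not been assigned yet
        before = "filled"
    except ValueError as err:
        before = str(err)

    backup = 1  # fills the cell
    subfunc()
    after = cell.cell_contents
    return before, after, x


print(forward_analog())  # → ('Cell is empty', 1, [1])
```

In eager mode this pattern is harmless because the cell is filled before `subfunc()` runs. The traceback shows `match_nested_cell` reading `cell.cell_contents` and hitting this same condition while tracing the `<resume in forward>` frame; why the cell is still empty at that point is not established by this sketch.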
Error logs
```
Traceback (most recent call last):
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/convert_frame.py", line 331, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
    transformations(instructions, code_options)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/convert_frame.py", line 318, in transform
    tracer.run()
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 1848, in run
    super().run()
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 604, in run
    and self.step()
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 564, in step
    getattr(self, inst.opname)(inst)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 347, in wrapper
    return inner_fn(self, inst)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 1000, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 495, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/variables/functions.py", line 265, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/variables/functions.py", line 98, in call_function
    return tx.inline_user_function_return(
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 531, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 1941, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 1974, in inline_call_
    sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/variables/functions.py", line 202, in bind_args
    var = tx.match_nested_cell(name, cell)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/symbolic_convert.py", line 1852, in match_nested_cell
    value = cell.cell_contents
ValueError: Cell is empty

from user code:
  File "/home/su/accdiff/test.py", line 11, in <resume in forward>
    subfunc()

Set torch._dynamo.config.verbose=True or TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/su/accdiff/test.py", line 17, in <module>
    print(fn_compiled())
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/eval_frame.py", line 235, in _fn
    return fn(*args, **kwargs)
  File "/home/su/accdiff/test.py", line 2, in forward
    def forward():
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/eval_frame.py", line 372, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/convert_frame.py", line 412, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/convert_frame.py", line 110, in _fn
    return fn(*args, **kwargs)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/convert_frame.py", line 269, in _convert_frame_assert
    return _compile(
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/utils.py", line 165, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/su/accdiff/thirdparty/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
    raise InternalTorchDynamoError() from e
torch._dynamo.exc.InternalTorchDynamoError
```
Minified repro
No response
Versions
PyTorch version: 2.1.0a0+gitfe05266
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.25.2
Libc version: glibc-2.31
Python version: 3.9.5 (default, Nov 23 2021, 15:27:38) [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070
Nvidia driver version: 510.108.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.2.4
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.2.4
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.2.4
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.2.4
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.2.4
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.2.4
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.2.4
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn.so.8.6.0
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.6.0
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.6.0
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.6.0
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.6.0
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.6.0
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.6.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.7.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.7.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.7.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.7.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.7.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.7.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 32
On-line CPU(s) list: 0-31
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 1
Model name: AMD Ryzen Threadripper 1950X 16-Core Processor
Stepping: 1
Frequency boost: enabled
CPU MHz: 2083.655
CPU max MHz: 3400.0000
CPU min MHz: 2200.0000
BogoMIPS: 6786.49
Virtualization: AMD-V
L1d cache: 512 KiB
L1i cache: 1 MiB
L2 cache: 8 MiB
L3 cache: 32 MiB
NUMA node0 CPU(s): 0-31
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid amd_dcm aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 xsaves clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sme sev
Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] torch==2.1.0a0+git4805441
[pip3] triton==2.0.0.post1
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire