torch.compile(model.generate) cannot run under torch.inference_mode() with dynamic input shape #103132
Labels
inference mode
Everything related to InferenceMode guard
module: dynamic shapes
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Comments
jiqing-feng changed the title from "torch.compile(model.generate) cannot run under torch.inference_mode()" to "torch.compile(model.generate) cannot run under torch.inference_mode() with dynamic input shape" on Jun 7, 2023
Yep, looks like inference mode is still partially broken. I can repro locally; the failure comes from pytorch/aten/src/ATen/native/LinearAlgebra.cpp, line 1945 (at commit 3c896a5). That shouldn't be happening, because we have a python decomp for matmul that's supposed to run when the python dispatcher is enabled.
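As a side note (not from the original thread): one way to check whether an op has a python decomposition registered is to look it up in the private `torch._decomp.decomposition_table`; whether `matmul` appears there depends on the PyTorch version, so treat this as a hedged sketch:

```python
import torch
from torch._decomp import decomposition_table

# Whether aten.matmul has a python decomposition registered; the exact
# contents of this table vary across PyTorch versions.
op = torch.ops.aten.matmul.default
print(op, "has python decomp:", op in decomposition_table)
```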
bdhirsh added a commit that referenced this issue on Jun 9, 2023
…ion" Fixes #103132

This is kind of annoying: Functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps, and directly registers them to the Functionalize key. This is a problem for the PyDispatcher: we normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a python decomp registered to `CompositeImplicitAutograd`, and a C++ decomp registered *directly* to the `Functionalize` key, so the C++ decomp gets precedence over the python decomp.

The way this showed up was that a model was running `matmul()` under inference mode, so we never hit the autograd dispatch key and go straight to the functionalize dispatch key. Matmul has both a python decomp and a C++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we were failing with the "tried to call `.sizes()` on a tensor with dynamic shapes" error.

For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a python decomp to the `CompositeImplicitAutograd` key, this PR just automatically registers that decomp to the `Functionalize` key at the same time. I'm trying to remember why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I couldn't (zou3519, any chance you remember?).

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy aakhundov

[ghstack-poisoned]
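To make the precedence problem concrete, here is a toy sketch (plain Python, not PyTorch internals; all table and key names are illustrative) of why a C++ kernel registered directly to `Functionalize` shadows a python decomp registered to the alias key `CompositeImplicitAutograd` once the autograd key is skipped, and why mirroring the python registration fixes it:

```python
PRECEDENCE = ["Autograd", "Functionalize", "CPU"]  # highest priority first

cpp_table = {
    ("matmul", "Functionalize"): "cpp_matmul_decomp",  # registered by codegen
    ("matmul", "CPU"): "cpp_matmul_kernel",
}
py_table = {
    ("matmul", "CompositeImplicitAutograd"): "py_matmul_decomp",
}
# The alias key normally covers the Autograd key.
ALIAS_COVERS = {"CompositeImplicitAutograd": {"Autograd"}}

def dispatch(op, active_keys):
    for key in PRECEDENCE:
        if key not in active_keys:
            continue
        # Python-dispatcher entries win for any key they cover...
        for (py_op, py_key), fn in py_table.items():
            if py_op == op and key in ALIAS_COVERS.get(py_key, {py_key}):
                return fn
        # ...otherwise a direct C++ registration wins for this key.
        if (op, key) in cpp_table:
            return cpp_table[(op, key)]
    raise RuntimeError("no kernel found")

# Normal mode: Autograd is active, the alias covers it -> python decomp runs.
print(dispatch("matmul", {"Autograd", "Functionalize", "CPU"}))  # py_matmul_decomp

# inference_mode: the Autograd key is skipped, so dispatch lands on the C++
# decomp registered directly to Functionalize.
print(dispatch("matmul", {"Functionalize", "CPU"}))  # cpp_matmul_decomp

# The fix described above: mirror python CompositeImplicitAutograd decomps
# onto the Functionalize key as well.
py_table[("matmul", "Functionalize")] = "py_matmul_decomp"
print(dispatch("matmul", {"Functionalize", "CPU"}))  # py_matmul_decomp
```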
zou3519 added and removed the triaged label on Jun 12, 2023
pytorchmergebot pushed commits that referenced this issue on Jun 20, 2023 (with the same commit message as above)
🐛 Describe the bug
`torch.compile(model.generate)` cannot run under `torch.inference_mode()` with dynamic input shape, but it can run if I change `torch.inference_mode()` to `torch.no_grad()`. Users use `torch.inference_mode()` in most cases, would you please help to check it? Thanks! The error is the "tried to call `.sizes()` on a tensor with dynamic shapes" failure quoted above.
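A minimal repro sketch (the model name, prompts, and generation arguments below are illustrative assumptions, not taken from the original report):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

compiled_generate = torch.compile(model.generate, dynamic=True)

# Two prompts of different lengths force a dynamic input shape.
prompts = ["hello", "a noticeably longer prompt that changes the input shape"]

# Per the report: this fails under torch.inference_mode() but works if the
# context manager is swapped for torch.no_grad().
with torch.inference_mode():
    for p in prompts:
        inputs = tokenizer(p, return_tensors="pt")
        out = compiled_generate(**inputs, max_new_tokens=8)
        print(tokenizer.decode(out[0]))
```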
Versions
PyTorch version: 2.1.0.dev20230604+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: CentOS Stream 8 (x86_64)
GCC version: (GCC) 11.4.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.28
Python version: 3.9.16 (main, Mar 8 2023, 14:00:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.16.0-rc8-intel-next-01534-g53cb5f883cf7-x86_64-with-glibc2.28
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Genuine Intel(R) CPU 0000%@
Stepping: 3
CPU MHz: 1900.000
CPU max MHz: 1900.0000
CPU min MHz: 800.0000
BogoMIPS: 3800.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 2048K
L3 cache: 107520K
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req hfi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm uintr avx512_vp2intersect md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.0+git1f1ee89
[pip3] numpy==1.24.1
[pip3] torch==2.1.0.dev20230604+cpu
[pip3] torchaudio==2.1.0.dev20230604+cpu
[pip3] torchvision==0.16.0.dev20230604+cpu
[conda] intel-extension-for-pytorch 2.1.0+git1f1ee89 pypi_0 pypi
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 pypi_0 pypi
[conda] mkl-static 2023.1.0 pypi_0 pypi
[conda] numpy 1.24.1 pypi_0 pypi
[conda] torch 2.1.0.dev20230604+cpu pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230604+cpu pypi_0 pypi
[conda] torchvision 0.16.0.dev20230604+cpu pypi_0 pypi
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305