
torch.compile(model.generate) cannot run under torch.inference_mode() with dynamic input shape #103132

Closed
jiqing-feng opened this issue Jun 7, 2023 · 1 comment
Labels: inference mode, module: dynamic shapes, oncall: pt2, triaged


jiqing-feng commented Jun 7, 2023

🐛 Describe the bug

torch.compile(model.generate) fails under torch.inference_mode() when compiled with dynamic input shapes (dynamic=True), but it runs if I replace torch.inference_mode() with torch.no_grad(). Since most users run inference under torch.inference_mode(), could you please take a look? Thanks!

```python
import time
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

kwargs = dict(torch_dtype=torch.float32, use_cache=True)
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", **kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = model.eval()

input_sentence = "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?"
inputs = tokenizer(input_sentence, return_tensors='pt')

# Eager baseline: runs fine under inference_mode.
with torch.inference_mode():
    for i in range(10):
        pre = time.time()
        output = model.generate(**inputs)
        print(f"eager eval time {i}: {time.time()-pre}")

# Compiled with dynamic shapes: fails under inference_mode,
# but runs if torch.inference_mode() is replaced with torch.no_grad().
model.generate = torch.compile(model.generate, backend='inductor', dynamic=True)
with torch.inference_mode():
    for i in range(10):
        pre = time.time()
        output_compile = model.generate(**inputs)
        print(f"compile eval time {i}: {time.time()-pre}")
```

The error is as follows:

Traceback (most recent call last):
  File "/home/jiqingfe/optimum-script/issue.py", line 27, in <module>
    output_compile = model.generate(**inputs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 286, in _fn
    return fn(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/generation/utils.py", line 1246, in generate
    self._validate_model_class()
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/generation/utils.py", line 1253, in <resume in generate>
    new_generation_config = GenerationConfig.from_model_config(self.config)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/generation/utils.py", line 1515, in <resume in generate>
    return self.greedy_search(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/generation/utils.py", line 2332, in greedy_search
    outputs = self(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 439, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 527, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 430, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 415, in transform
    tracer.run()
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2023, in run
    super().run()
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 707, in run
    and self.step()
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 667, in step
    getattr(self, inst.opname)(inst)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2111, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 784, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 857, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 913, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 909, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/__init__.py", line 1536, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 645, in compile_fx
    return compile_fx(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 802, in compile_fx
    return aot_autograd(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3707, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3183, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 723, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 3316, in functional_call
    out = Interpreter(mod).run(*args[params_len:], **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/fx/interpreter.py", line 195, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/fx/interpreter.py", line 267, in call_function
    return target(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot call sizes() on tensor with symbolic sizes/strides

While executing %matmul : [#users=1] = call_function[target=torch.matmul](args = (%transpose, %transpose_3), kwargs = {})
Original traceback:
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1718, in forward
    decoder_outputs = self.decoder(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1088, in forward
    layer_outputs = layer_module(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 695, in forward
    self_attention_outputs = self.layer[0](
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 602, in forward
    attention_output = self.SelfAttention(
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jiqingfe/miniconda3/envs/trace/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 532, in forward
    scores = torch.matmul(



You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

PyTorch version: 2.1.0.dev20230604+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: CentOS Stream 8 (x86_64)
GCC version: (GCC) 11.4.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.28

Python version: 3.9.16 (main, Mar 8 2023, 14:00:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.16.0-rc8-intel-next-01534-g53cb5f883cf7-x86_64-with-glibc2.28
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Genuine Intel(R) CPU 0000%@
Stepping: 3
CPU MHz: 1900.000
CPU max MHz: 1900.0000
CPU min MHz: 800.0000
BogoMIPS: 3800.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 2048K
L3 cache: 107520K
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req hfi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm uintr avx512_vp2intersect md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.0+git1f1ee89
[pip3] numpy==1.24.1
[pip3] torch==2.1.0.dev20230604+cpu
[pip3] torchaudio==2.1.0.dev20230604+cpu
[pip3] torchvision==0.16.0.dev20230604+cpu
[conda] intel-extension-for-pytorch 2.1.0+git1f1ee89 pypi_0 pypi
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 pypi_0 pypi
[conda] mkl-static 2023.1.0 pypi_0 pypi
[conda] numpy 1.24.1 pypi_0 pypi
[conda] torch 2.1.0.dev20230604+cpu pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230604+cpu pypi_0 pypi
[conda] torchvision 0.16.0.dev20230604+cpu pypi_0 pypi

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305

jiqing-feng changed the title from "torch.compile(model.generate) cannot run under torch.inference_mode()" to "torch.compile(model.generate) cannot run under torch.inference_mode() with dynamic input shape" Jun 7, 2023
zou3519 added the "inference mode" label Jun 8, 2023
bdhirsh (Contributor) commented Jun 8, 2023

Yep, looks like inference mode is still partially broken.

I can reproduce this locally. Running with TORCH_SHOW_CPP_STACKTRACES=1 shows that the sizes() call comes from at::matmul:

const auto sizes_1 = t1->sizes();

That shouldn't be happening, because we have a python decomp for matmul that's supposed to run when the python dispatcher is enabled.
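To make the failure mode concrete, here is a toy model in plain Python (not PyTorch code) of why a kernel that reads concrete sizes breaks once shapes are symbolic. `SymInt`, `ToyTensor`, `cpp_style_matmul_shape` and `python_style_matmul_shape` are hypothetical names chosen for illustration; the real check lives in PyTorch's C++ `sizes()` accessor, and the shape rule below is simplified (no broadcasting, no inner-dimension check).

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class SymInt:
    """Hypothetical stand-in for a symbolic dimension such as s0."""
    name: str


Dim = Union[int, SymInt]


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical tensor that only carries a shape."""
    shape: Tuple[Dim, ...]

    def sizes(self) -> Tuple[int, ...]:
        # Like the C++ sizes() accessor: only valid for concrete shapes.
        if any(isinstance(d, SymInt) for d in self.shape):
            raise RuntimeError(
                "Cannot call sizes() on tensor with symbolic sizes/strides")
        return tuple(int(d) for d in self.shape)

    def sym_sizes(self) -> Tuple[Dim, ...]:
        # Symbolic-aware accessor: returns each dim as-is.
        return self.shape


def cpp_style_matmul_shape(a: ToyTensor, b: ToyTensor) -> Tuple[Dim, ...]:
    # Reads concrete sizes up front, like `t1->sizes()` in at::matmul.
    sa, sb = a.sizes(), b.sizes()
    # Simplified rule: same leading dims, result is (..., n, m).
    return sa[:-1] + sb[-1:]


def python_style_matmul_shape(a: ToyTensor, b: ToyTensor) -> Tuple[Dim, ...]:
    # Same shape rule, but written against symbolic sizes.
    sa, sb = a.sym_sizes(), b.sym_sizes()
    return sa[:-1] + sb[-1:]


# Example: attention scores with a symbolic sequence length s0.
s0 = SymInt("s0")
query = ToyTensor((1, 12, s0, 64))
key_t = ToyTensor((1, 12, 64, s0))
print(python_style_matmul_shape(query, key_t))
# (1, 12, SymInt(name='s0'), SymInt(name='s0'))
```

Calling `cpp_style_matmul_shape(query, key_t)` instead raises the same kind of `RuntimeError` seen in the traceback, while both functions agree when every dimension is a plain integer.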

bdhirsh added a commit that referenced this issue Jun 9, 2023
…ion"

Fixes #103132

This is kind of annoying: functionalization (and also vmap, I think?) manually figures out which ops have C++ CompositeImplicit decomps and registers them directly to the Functionalize key. This is a problem for the PyDispatcher: we normally want the PyDispatcher to take precedence over the regular dispatcher. But in this case, we have a Python decomp registered to `CompositeImplicitAutograd` and a C++ decomp registered *directly* to the `Functionalize` key, so the C++ decomp takes precedence over the Python decomp.

The way this showed up was that a model was running `matmul()` under inference mode, so we never hit the Autograd dispatch key and went straight to the Functionalize dispatch key. Matmul has both a Python decomp and a C++ decomp, but we were running the C++ decomp. That C++ decomp isn't meant to be used with dynamic shapes, so we failed with the "Cannot call sizes() on tensor with symbolic sizes/strides" error.

For now, I had the PyDispatcher mimic the behavior of functionalization codegen: when you register a Python decomp to the `CompositeImplicitAutograd` key, this PR automatically registers that decomp to the `Functionalize` key at the same time.

I'm trying to remember why we didn't just add `Functionalize` (and all of the other functorch transform keys) directly to the `CompositeImplicitAutograd` alias keyset, but I couldn't (zou3519, any chance you remember?).

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy aakhundov

[ghstack-poisoned]
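The precedence problem described in the commit message can be modeled with a toy dispatcher. This is a simplified sketch under stated assumptions, not PyTorch's dispatcher: `ToyDispatcher`, `KEY_ORDER`, `build` and the three-key ordering are hypothetical; `torch.no_grad()` is modeled as leaving the Autograd key in play, and inference mode as skipping it. The `mirror_py_cia_to_functionalize` flag stands in for the fix (registering the Python `CompositeImplicitAutograd` decomp to `Functionalize` as well).

```python
from typing import Callable, Dict, Iterable

# Simplified priority order, highest first (an assumption for this toy;
# PyTorch's real dispatch key set is much larger).
KEY_ORDER = ("Autograd", "Functionalize", "CPU")

CPP_DECOMP = "C++ matmul decomp (calls sizes())"
PY_DECOMP = "Python matmul decomp (symbolic-safe)"
CPU_KERNEL = "CPU kernel"


class ToyDispatcher:
    """Hypothetical dispatch table for a single op such as matmul."""

    def __init__(self, mirror_py_cia_to_functionalize: bool) -> None:
        self.cpp: Dict[str, Callable[[], str]] = {}
        self.py: Dict[str, Callable[[], str]] = {}
        self.mirror = mirror_py_cia_to_functionalize

    def register_cpp_composite_implicit(self, fn: Callable[[], str]) -> None:
        # The C++ CompositeImplicitAutograd kernel serves the Autograd key,
        # and functionalization registers it *directly* to Functionalize too.
        self.cpp["Autograd"] = fn
        self.cpp["Functionalize"] = fn

    def register_cpp(self, key: str, fn: Callable[[], str]) -> None:
        self.cpp[key] = fn

    def register_py_composite_implicit(self, fn: Callable[[], str]) -> None:
        # Python decomp registered to CompositeImplicitAutograd.
        self.py["Autograd"] = fn
        if self.mirror:
            # Behavior after the fix: also register it to Functionalize.
            self.py["Functionalize"] = fn

    def dispatch(self, excluded: Iterable[str] = ()) -> str:
        skipped = set(excluded)
        for key in KEY_ORDER:
            if key in skipped:
                continue
            # For a given key, a Python kernel wins over the C++ one.
            if key in self.py:
                return self.py[key]()
            if key in self.cpp:
                return self.cpp[key]()
        raise RuntimeError("no kernel found")


def build(fixed: bool) -> ToyDispatcher:
    d = ToyDispatcher(mirror_py_cia_to_functionalize=fixed)
    d.register_cpp_composite_implicit(lambda: CPP_DECOMP)
    d.register_cpp("CPU", lambda: CPU_KERNEL)
    d.register_py_composite_implicit(lambda: PY_DECOMP)
    return d


NO_GRAD = ()                    # Autograd key still participates
INFERENCE_MODE = ("Autograd",)  # Autograd key is skipped

for fixed in (False, True):
    for mode, excluded in (("no_grad", NO_GRAD),
                           ("inference_mode", INFERENCE_MODE)):
        print(f"fixed={fixed} {mode}: {build(fixed).dispatch(excluded)}")
```

Running it selects the Python decomp in every case except `fixed=False` under inference mode, where the C++ decomp registered directly to Functionalize wins. That mirrors the report: `torch.no_grad()` worked, `torch.inference_mode()` failed, and mirroring the Python decomp onto Functionalize restores the symbolic-safe path.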
bdhirsh added a commit that referenced this issue Jun 9, 2023
…onalize interaction"

zou3519 added the "triaged" label and removed the "triaged" label Jun 12, 2023
pytorchmergebot pushed a commit that referenced this issue Jun 20, 2023
…onalize interaction"

pytorchmergebot pushed a commit that referenced this issue Jun 20, 2023
…ion"


3 participants