
Errors using torch.compile() on wav2vec2 model #91719

Closed
kalakris opened this issue Jan 4, 2023 · 18 comments
Labels
ezyang's list (Stuff ezyang doesn't want to lose), module: dynamic shapes, module: dynamo, module: inductor, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


kalakris commented Jan 4, 2023

🐛 Describe the bug

I've been experimenting with enabling torch.compile() for the wav2vec2 model in torchaudio, and ran into a few issues with it. My code can be found here. Here's a summary of the issues I found, with reproduction instructions for each:

1) Errors with inference_mode()

To reproduce this, change the no_grad() to inference_mode() in examples/asr/librispeech_ctc_decoder/inference.py, and run it:

python examples/asr/librispeech_ctc_decoder/inference.py --librispeech_path /PATH/TO/LIBRISPEECH --batch_size 2 --compile

It produces the error "RuntimeError: Inference tensors do not track version counter.", but continues to run. Full stack trace is here.
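The failing combination reduces to a small standalone sketch. A tiny nn.Linear stands in for the wav2vec2 model and backend="eager" sidesteps the Inductor toolchain; both are my assumptions for illustration, not the original repro setup:

```python
# Reduced sketch of the inference_mode() + torch.compile() combination
# from repro (1). nn.Linear is a stand-in for the torchaudio wav2vec2
# model; backend="eager" avoids needing a working Inductor install.
import torch

model = torch.nn.Linear(8, 4)
compiled = torch.compile(model, backend="eager")

with torch.inference_mode():
    out = compiled(torch.randn(2, 8))

print(out.shape, out.is_inference())
```

On the nightly used in this report, the traced call raised the version-counter error; on builds where the fix has landed, a sketch like this should run cleanly and return an inference tensor.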

2) Errors with dynamic=True

Ideally, we'd run wav2vec2 with dynamic tensor sizes per call, but I wasn't able to get that working. To reproduce, uncomment line 47 of examples/asr/librispeech_ctc_decoder/inference.py, comment out line 48, and run it:

python examples/asr/librispeech_ctc_decoder/inference.py --librispeech_path /PATH/TO/LIBRISPEECH --batch_size 2 --compile

Stack trace is here.
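What the dynamic-shapes path is meant to support can be sketched like this: one compiled module, inputs whose time dimension varies per call, no recompilation guard failures. nn.Linear over the last dimension and backend="eager" are stand-ins I chose for illustration:

```python
# Sketch of the usage repro (2) aims at: a single compiled model
# serving inputs with a varying time dimension. nn.Linear is a
# stand-in for wav2vec2; backend="eager" skips Inductor.
import torch

model = torch.nn.Linear(8, 4)
compiled = torch.compile(model, backend="eager", dynamic=True)

# Three calls with different sequence lengths should reuse one trace.
outs = [compiled(torch.randn(2, t, 8)) for t in (50, 75, 100)]
print([o.shape[1] for o in outs])
```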

3) Errors when passing the lengths parameter to the model

When batching inputs to this model, we need to pass in the lengths of each sample as well. Unfortunately this performs some operations which seem to break with torch.compile(). To reproduce, uncomment line 93 of examples/asr/librispeech_ctc_decoder/inference.py, comment out line 94, and run it:

python examples/asr/librispeech_ctc_decoder/inference.py --librispeech_path /PATH/TO/LIBRISPEECH --batch_size 2 --compile

Stack trace is here.
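For context, the batching pattern behind this repro looks roughly like the following: unequal-length samples are zero-padded into one batch and their true lengths passed alongside, so the model can mask the padding. The tensors here are synthetic; the real call is model(batch, lengths) on the torchaudio wav2vec2 pipeline model:

```python
# Sketch of the padded-batch + lengths pattern from repro (3),
# using synthetic waveforms in place of real LibriSpeech audio.
import torch

waveforms = [torch.randn(16000), torch.randn(12000)]
lengths = torch.tensor([w.numel() for w in waveforms])
batch = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)

# The real script then calls: emissions, emission_lengths = model(batch, lengths)
print(batch.shape, lengths.tolist())
```

It is the lengths-dependent masking inside the model that the compiled graph trips over.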

Versions

PyTorch version: 2.0.0.dev20221216
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-124-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB
Nvidia driver version: 470.141.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==2.0.0.dev20221216
[pip3] torchaudio==0.14.0.dev20221216
[pip3] torchvision==0.15.0.dev20221216
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.23.4 py310hd5efca6_0
[conda] numpy-base 1.23.4 py310h8e6c178_0
[conda] pytorch 2.0.0.dev20221216 py3.10_cuda11.6_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-cuda 11.6 h867d48c_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 0.14.0.dev20221216 py310_cu116 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py310 pytorch-nightly
[conda] torchvision 0.15.0.dev20221216 py310_cu116 pytorch-nightly

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @mthrok


ezyang commented Jan 5, 2023

dynamic=True isn't going to work with inductor on current master, as we're still waiting on a batch of fixes from @Chillee . To check whether the dynamic shapes infra works sans inductor, you can try torch._dynamo.optimize('aot_eager', dynamic=True) and see if that errors (but I don't expect any perf improvement here).
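The suggested check can also be spelled through the torch.compile backend argument on newer builds; this is a sketch with nn.Linear standing in for the wav2vec2 model:

```python
# Exercise dynamo + AOTAutograd with dynamic shapes, without Inductor,
# via the "aot_eager" backend. nn.Linear is an illustrative stand-in.
import torch

model = torch.nn.Linear(8, 4)
compiled = torch.compile(model, backend="aot_eager", dynamic=True)

y = compiled(torch.randn(3, 8))
print(y.shape)
```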


ezyang commented Jan 5, 2023

This is currently tagged "dynamic shapes", but I want to call out that (1) and (3) are NOT dynamic-shapes related.

@ngimel added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jan 22, 2023

ezyang commented Jan 25, 2023

@kalakris inference with inductor and dynamic shapes should be substantially working on master, can you give this another try?


mthrok commented Jan 25, 2023

@ezyang I will try. @kalakris has moved to a different work stream.


mthrok commented Jan 25, 2023

Here is the result of my trial with the nightly.

env
cpurun python -m torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.0.0.dev20230125
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.25.0
Libc version: glibc-2.27

Python version: 3.8.15 (default, Nov 24 2022, 15:19:38)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-1051-aws-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20230125
[pip3] torchaudio==2.0.0a0+a9a7d84
[conda] blas                      1.0                         mkl
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0            py38h7f8727e_0
[conda] mkl_fft                   1.3.1            py38hd3c417c_0
[conda] mkl_random                1.2.2            py38h51133e4_0
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] numpy-base                1.23.5           py38h31eccc5_0
[conda] pytorch                   2.0.0.dev20230125 py3.8_cuda11.7_cudnn8.5.0_0    pytorch-nightly
[conda] pytorch-cuda              11.7                 h67b0de4_2    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                2.0.0a0+a9a7d84          pypi_0    pypi
[conda] torchtriton               2.0.0+0d7e753227            py38    pytorch-nightly

1) (no_grad + dynamic=False) seems to be resolved.

(the script still fails, but for a reason unrelated to torch.compile; see the TypeError at the end of the log below)

Downloading decoder-assets/librispeech-4-gram/lexicon.txt to /data/home/moto/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt
100%|██████████| 4.97M/4.97M [00:00<00:00, 61.3MB/s]
Downloading decoder-assets/librispeech-4-gram/tokens.txt to /data/home/moto/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt
100%|██████████| 57.0/57.0 [00:00<00:00, 51.4kB/s]
Downloading decoder-assets/librispeech-4-gram/lm.bin to /data/home/moto/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin
100%|██████████| 2.91G/2.91G [00:09<00:00, 323MB/s]
No CUDA runtime is found, using CUDA_HOME='/fsx/users/moto/conda'
[2023-01-25 16:34:59,035] torch._inductor.graph: [WARNING] Creating implicit fallback for:
  target: aten._weight_norm_interface.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg16_1', layout=FixedLayout('cpu', torch.float32, size=[768, 48, 128], stride=[6144, 128, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg15_1', layout=FixedLayout('cpu', torch.float32, size=[1, 1, 128], stride=[128, 128, 1]))
  ))
  args[2]: 2
[2023-01-25 16:34:59,051] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._weight_norm_interface.default
Model evaluation 0 took 24.151989 s
Model evaluation 0 took 1.501585 s
Model evaluation 0 took 3.185728 s
Model evaluation 0 took 1.630042 s
Model evaluation 0 took 1.585629 s
Model evaluation 0 took 1.467409 s
Model evaluation 0 took 1.532574 s
Model evaluation 0 took 1.471612 s
Model evaluation 0 took 1.515457 s
Model evaluation 0 took 2.463703 s
Model evaluation 0 took 1.978550 s
Model evaluation 0 took 1.665031 s
Model evaluation 0 took 1.427234 s
Model evaluation 0 took 1.228959 s
Model evaluation 0 took 1.545177 s
Model evaluation 0 took 1.268553 s
Model evaluation 0 took 1.590519 s
Model evaluation 0 took 1.419145 s
Model evaluation 0 took 1.226012 s
Model evaluation 0 took 1.222147 s
Average runtime = 1.627635
Traceback (most recent call last):
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 208, in <module>
    _main()
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 204, in _main
    run_inference(args)
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 108, in run_inference
    emission = emissions[i:i + 1, 0:emission_lengths[i], :]
TypeError: 'NoneType' object is not subscriptable
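The trailing TypeError is the example script's own doing: when the model is called without lengths, the returned output lengths are None, and the slice on line 108 assumes they aren't. A guard along these lines avoids it (a sketch with dummy tensors in place of the real emissions):

```python
# emission_lengths is None when the model was called without `lengths`;
# guard before slicing. Dummy tensors stand in for real model output.
import torch

emissions = torch.randn(2, 100, 29)
emission_lengths = None
i = 0

if emission_lengths is None:
    emission = emissions[i : i + 1]  # keep the full time axis
else:
    emission = emissions[i : i + 1, 0 : emission_lengths[i], :]
print(emission.shape)
```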

2) no_grad + dynamic=True fails on both CPU and GPU after a long compilation time

CPU

No CUDA runtime is found, using CUDA_HOME='/fsx/users/moto/conda'
Traceback (most recent call last):
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/__init__.py", line 1331, in __call__
    return self.compile_fn(model_, inputs_)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
    return fn(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
    return compile_fx(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 376, in compile_fx
    model_ = overrides.fuse_fx(model_, example_inputs_)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 65, in fuse_fx
    is_cpu = all(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 66, in <genexpr>
    example_input.device == torch.device("cpu") for example_input in example_inputs
AttributeError: 'SymInt' object has no attribute 'device'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 208, in <module>
    _main()
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 204, in _main
    run_inference(args)
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 94, in run_inference
    emissions, emission_lengths = model(waveforms)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 403, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 261, in _convert_frame_assert
    return _compile(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 323, in _compile
    out_code = transform_code_object(code, transform)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 310, in transform
    tracer.run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in run
    super().run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1758, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 552, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 599, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised AttributeError: 'SymInt' object has no attribute 'device'

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

GPU

Traceback (most recent call last):
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/__init__.py", line 1331, in __call__
    return self.compile_fn(model_, inputs_)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
    return fn(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
    return compile_fx(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 376, in compile_fx
    model_ = overrides.fuse_fx(model_, example_inputs_)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 65, in fuse_fx
    is_cpu = all(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 66, in <genexpr>
    example_input.device == torch.device("cpu") for example_input in example_inputs
AttributeError: 'SymInt' object has no attribute 'device'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 208, in <module>
    _main()
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 204, in _main
    run_inference(args)
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 94, in run_inference
    emissions, emission_lengths = model(waveforms)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 403, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 261, in _convert_frame_assert
    return _compile(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 323, in _compile
    out_code = transform_code_object(code, transform)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 310, in transform
    tracer.run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in run
    super().run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1758, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 552, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 599, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised AttributeError: 'SymInt' object has no attribute 'device'

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

3) Passing the optional lengths parameter failed.

[2023-01-25 16:45:47,620] torch._inductor.graph: [WARNING] Creating implicit fallback for:
  target: aten._weight_norm_interface.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg16_1', layout=FixedLayout('cpu', torch.float32, size=[768, 48, 128], stride=[6144, 128, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg15_1', layout=FixedLayout('cpu', torch.float32, size=[1, 1, 128], stride=[128, 128, 1]))
  ))
  args[2]: 2
[2023-01-25 16:45:47,635] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._weight_norm_interface.default
[2023-01-25 16:45:47,648] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/graph.py", line 314, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 226, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1224, in convolution
    ir.Convolution.create(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3067, in create
    output = torch.ops.aten.convolution(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 37, in __torch_function__
    return func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 923, in __torch_dispatch__
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 502, in conv
    out = func(**kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_ops.py", line 284, in __call__
    return self._op(*args, **kwargs or {})
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_meta_registrations.py", line 600, in meta_conv
    shape_out = calc_conv_nd_return_shape(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_meta_registrations.py", line 524, in calc_conv_nd_return_shape
    raise RuntimeError("Invalid channel dimensions")
RuntimeError: Invalid channel dimensions
Traceback (most recent call last):
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/graph.py", line 314, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 226, in wrapped
    return decomp_fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1224, in convolution
    ir.Convolution.create(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3067, in create
    output = torch.ops.aten.convolution(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 37, in __torch_function__
    return func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 923, in __torch_dispatch__
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 502, in conv
    out = func(**kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_ops.py", line 284, in __call__
    return self._op(*args, **kwargs or {})
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_meta_registrations.py", line 600, in meta_conv
    shape_out = calc_conv_nd_return_shape(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_meta_registrations.py", line 524, in calc_conv_nd_return_shape
    raise RuntimeError("Invalid channel dimensions")
RuntimeError: Invalid channel dimensions

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/__init__.py", line 1331, in __call__
    return self.compile_fn(model_, inputs_)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
    return fn(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
    return compile_fx(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 410, in compile_fx
    return aot_autograd(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/training.py", line 74, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2454, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2151, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1411, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1061, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 385, in fw_compiler
    return inner_compile(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 586, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 149, in compile_fx_inner
    graph.run(*example_inputs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/graph.py", line 182, in run
    return super().run(*args)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/graph.py", line 386, in run_node
    result = self.call_function(n.target, args, kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/graph.py", line 318, in call_function
    raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: RuntimeError: Invalid channel dimensions
  target: aten.convolution.default
  args[0]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf23', layout=FixedLayout('cpu', torch.float32, size=[2, 2, 1725], stride=[3450, 1725, 1]), data=Pointwise(
          'cpu',
          torch.float32,
          tmp0 = index_expr(i2, torch.int64)
          tmp1 = load(buf17, i1)
          tmp2 = tmp0 >= tmp1
          tmp3 = constant(0.0, torch.float32)
          tmp4 = load(buf22, i2 + 768 * i1 + 1324800 * i0)
          tmp5 = where(tmp2, tmp3, tmp4)
          return tmp5
          ,
          ranges=[2, 2, 1725],
          origins={lift_fresh_copy, arg10_1, erf_6, mul_21, var_mean_1, clone, arg11_1, _tensor_constant0, index_put_, add_14, _mkl_linear, rsqrt_1, arg12_1, convolution_6, mul_20, mul_23, sub_8, arg13_1, arg8_1, mul_22, permute, mul_24, arg9_1, add_17, add_16}
        ))
      ),
      FixedLayout('cpu', torch.float32, size=[2, 1725, 2], stride=[3450, 1, 1725]),
      no origins?
    )
  )
  args[1]: TensorBox(StorageBox(
    MultiOutput(
      name=buf25,
      layout=FixedLayout('cpu', torch.float32, size=[768, 48, 128], stride=[6144, 128, 1]),
      inputs=[FallbackKernel(name='buf24', layout=MultiOutputLayout(device=device(type='cpu')), inputs=[InputBuffer(name='arg16_1', layout=FixedLayout('cpu', torch.float32, size=[768, 48, 128], stride=[6144, 128, 1])), InputBuffer(name='arg15_1', layout=FixedLayout('cpu', torch.float32, size=[1, 1, 128], stride=[128, 128, 1]))], constant_args=(2,), kwargs={}, output_view=None)],
      constant_args=(),
      kwargs={},
      output_view=None,
      origins={arg15_1, _weight_norm_interface, arg16_1}
    )
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg14_1', layout=FixedLayout('cpu', torch.float32, size=[768], stride=[1]))
  ))
  args[3]: [1]
  args[4]: [64]
  args[5]: [1]
  args[6]: False
  args[7]: [0]
  args[8]: 16

While executing %convolution_7 : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%permute_1, %getitem_4, %arg14_1, [1], [64], [1], False, [0], 16), kwargs = {})
Original traceback:
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 234, in forward
    x = self.conv(x)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 438, in _preprocess
    x = x + self.pos_conv_embed(x)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 452, in forward
    x = self._preprocess(x)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 515, in forward
    x = self.transformer(x, attention_mask=mask)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/model.py", line 117, in forward
    x = self.encoder(x, lengths)


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 208, in <module>
    _main()
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 204, in _main
    run_inference(args)
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 93, in run_inference
    emissions, emission_lengths = model(waveforms, lengths)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 403, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 261, in _convert_frame_assert
    return _compile(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 323, in _compile
    out_code = transform_code_object(code, transform)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 310, in transform
    tracer.run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in run
    super().run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1758, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 552, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 599, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: RuntimeError: Invalid channel dimensions
  target: aten.convolution.default
  args[0]: TensorBox(
    ReinterpretView(
      StorageBox(
        ComputedBuffer(name='buf23', layout=FixedLayout('cpu', torch.float32, size=[2, 2, 1725], stride=[3450, 1725, 1]), data=Pointwise(
          'cpu',
          torch.float32,
          tmp0 = index_expr(i2, torch.int64)
          tmp1 = load(buf17, i1)
          tmp2 = tmp0 >= tmp1
          tmp3 = constant(0.0, torch.float32)
          tmp4 = load(buf22, i2 + 768 * i1 + 1324800 * i0)
          tmp5 = where(tmp2, tmp3, tmp4)
          return tmp5
          ,
          ranges=[2, 2, 1725],
          origins={lift_fresh_copy, arg10_1, erf_6, mul_21, var_mean_1, clone, arg11_1, _tensor_constant0, index_put_, add_14, _mkl_linear, rsqrt_1, arg12_1, convolution_6, mul_20, mul_23, sub_8, arg13_1, arg8_1, mul_22, permute, mul_24, arg9_1, add_17, add_16}
        ))
      ),
      FixedLayout('cpu', torch.float32, size=[2, 1725, 2], stride=[3450, 1, 1725]),
      no origins?
    )
  )
  args[1]: TensorBox(StorageBox(
    MultiOutput(
      name=buf25,
      layout=FixedLayout('cpu', torch.float32, size=[768, 48, 128], stride=[6144, 128, 1]),
      inputs=[FallbackKernel(name='buf24', layout=MultiOutputLayout(device=device(type='cpu')), inputs=[InputBuffer(name='arg16_1', layout=FixedLayout('cpu', torch.float32, size=[768, 48, 128], stride=[6144, 128, 1])), InputBuffer(name='arg15_1', layout=FixedLayout('cpu', torch.float32, size=[1, 1, 128], stride=[128, 128, 1]))], constant_args=(2,), kwargs={}, output_view=None)],
      constant_args=(),
      kwargs={},
      output_view=None,
      origins={arg15_1, _weight_norm_interface, arg16_1}
    )
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg14_1', layout=FixedLayout('cpu', torch.float32, size=[768], stride=[1]))
  ))
  args[3]: [1]
  args[4]: [64]
  args[5]: [1]
  args[6]: False
  args[7]: [0]
  args[8]: 16

While executing %convolution_7 : [#users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%permute_1, %getitem_4, %arg14_1, [1], [64], [1], False, [0], 16), kwargs = {})
Original traceback:
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 234, in forward
    x = self.conv(x)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 438, in _preprocess
    x = x + self.pos_conv_embed(x)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 452, in forward
    x = self._preprocess(x)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/components.py", line 515, in forward
    x = self.transformer(x, attention_mask=mask)
 |   File "/fsx/users/moto/conda/lib/python3.8/site-packages/torchaudio-2.0.0a0+a9a7d84-py3.8-linux-x86_64.egg/torchaudio/models/wav2vec2/model.py", line 117, in forward
    x = self.encoder(x, lengths)


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

BTW, on my M1 MacBook Pro I encountered the following compilation error.


[2023-01-25 11:15:12,212] torch._inductor.graph: [WARNING] Creating implicit fallback for:
  target: aten._weight_norm_interface.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg15_1', layout=FixedLayout('cpu', torch.float32, size=[768, 48, 128], stride=[6144, 128, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg14_1', layout=FixedLayout('cpu', torch.float32, size=[1, 1, 128], stride=[128, 128, 1]))
  ))
  args[2]: 2
[2023-01-25 11:15:12,222] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._weight_norm_interface.default
Traceback (most recent call last):
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 481, in load
    subprocess.check_output(cmd, stderr=subprocess.STDOUT)
  File "/Users/moto/miniconda3/lib/python3.9/subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/Users/moto/miniconda3/lib/python3.9/subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['g++', '/var/folders/jy/ln20q1m92vzdvycyw3ttf5p40000gn/T/torchinductor_moto/lk/clklc23ummstvbc4wih44j2tlnqzeum5mmwbxvjuonbljvi4qy36.cpp', '-shared', '-fPIC', '-Wall', '-std=c++17', '-Wno-unused-variable', '-I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include', '-I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include', '-I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/TH', '-I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/THC', '-I/Users/moto/miniconda3/include/python3.9', '-lgomp', '-march=native', '-O3', '-ffast-math', '-fno-finite-math-only', '-fopenmp', '-D', 'C10_USING_CUSTOM_GENERATED_MACROS', '-o/var/folders/jy/ln20q1m92vzdvycyw3ttf5p40000gn/T/torchinductor_moto/lk/clklc23ummstvbc4wih44j2tlnqzeum5mmwbxvjuonbljvi4qy36.so']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 676, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/__init__.py", line 1325, in __call__
    return self.compile_fn(model_, inputs_)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
    return fn(gm, example_inputs, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
    return compile_fx(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 410, in compile_fx
    return aot_autograd(
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/optimizations/training.py", line 74, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2477, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2174, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1412, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1062, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 385, in fw_compiler
    return inner_compile(
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 586, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 150, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/graph.py", line 560, in compile_to_fn
    return self.compile_to_module().call
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/graph.py", line 549, in compile_to_module
    mod = PyCodeCache.load(code)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 504, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/var/folders/jy/ln20q1m92vzdvycyw3ttf5p40000gn/T/torchinductor_moto/mu/cmuqvfap6wfhlrlji24iyctqbe5mcw5afjxep2ns6tj7nh6zllna.py", line 5185, in <module>
    async_compile.wait(globals())
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 699, in wait
    scope[key] = result.result()
  File "/Users/moto/miniconda3/lib/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/Users/moto/miniconda3/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
  File "/Users/moto/miniconda3/lib/python3.9/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 676, in task
    return CppCodeCache.load(source_code).kernel
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 483, in load
    raise exc.CppCompileError(cmd, e.output) from e
torch._inductor.exc.CppCompileError: C++ compile error

Command:
g++ /var/folders/jy/ln20q1m92vzdvycyw3ttf5p40000gn/T/torchinductor_moto/lk/clklc23ummstvbc4wih44j2tlnqzeum5mmwbxvjuonbljvi4qy36.cpp -shared -fPIC -Wall -std=c++17 -Wno-unused-variable -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/TH -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/THC -I/Users/moto/miniconda3/include/python3.9 -lgomp -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS -o/var/folders/jy/ln20q1m92vzdvycyw3ttf5p40000gn/T/torchinductor_moto/lk/clklc23ummstvbc4wih44j2tlnqzeum5mmwbxvjuonbljvi4qy36.so

Output:
clang: error: the clang compiler does not support '-march=native'
clang: error: unsupported option '-fopenmp'
clang: error: unsupported option '-fopenmp'


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/moto/Development/torchaudio/examples/asr/librispeech_ctc_decoder/inference.py", line 208, in <module>
    _main()
  File "/Users/moto/Development/torchaudio/examples/asr/librispeech_ctc_decoder/inference.py", line 204, in _main
    run_inference(args)
  File "/Users/moto/Development/torchaudio/examples/asr/librispeech_ctc_decoder/inference.py", line 94, in run_inference
    emissions, emission_lengths = model(waveforms)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 403, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 261, in _convert_frame_assert
    return _compile(
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 323, in _compile
    out_code = transform_code_object(code, transform)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 310, in transform
    tracer.run()
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in run
    super().run()
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 538, in run
    and self.step()
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 501, in step
    getattr(self, inst.opname)(inst)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1758, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 552, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 599, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 153, in time_wrapper
    r = func(*args, **kwargs)
  File "/Users/moto/miniconda3/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 681, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised CppCompileError: C++ compile error

Command:
g++ /var/folders/jy/ln20q1m92vzdvycyw3ttf5p40000gn/T/torchinductor_moto/lk/clklc23ummstvbc4wih44j2tlnqzeum5mmwbxvjuonbljvi4qy36.cpp -shared -fPIC -Wall -std=c++17 -Wno-unused-variable -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/TH -I/Users/moto/miniconda3/lib/python3.9/site-packages/torch/include/THC -I/Users/moto/miniconda3/include/python3.9 -lgomp -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS -o/var/folders/jy/ln20q1m92vzdvycyw3ttf5p40000gn/T/torchinductor_moto/lk/clklc23ummstvbc4wih44j2tlnqzeum5mmwbxvjuonbljvi4qy36.so

Output:
clang: error: the clang compiler does not support '-march=native'
clang: error: unsupported option '-fopenmp'
clang: error: unsupported option '-fopenmp'


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
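
As the error message suggests, compile failures can be demoted to an eager-mode fallback while triaging. A minimal sketch (note this hides the failure rather than fixing it, so keep it off when trying to reproduce the bug):

```python
import torch._dynamo

# Fall back to eager execution instead of raising when the backend compiler
# fails -- useful to keep the rest of an inference script running while
# the underlying inductor/clang issue is being tracked down.
torch._dynamo.config.suppress_errors = True
```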

@ezyang

ezyang commented Jan 25, 2023

Thanks a lot!

ezyang added a commit that referenced this issue Jan 25, 2023
Should probably figure out how to get type checking going, would have
caught these cases.

Discovered in pursuit of #91719
though this is not enough.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this issue Jan 25, 2023
Should probably figure out how to get type checking going, would have
caught these cases.

Discovered in pursuit of #91719
though this is not enough.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 118a0990137fb9368f2a8b21fd3a5a22ae1cb298
Pull Request resolved: #92997
ezyang added a commit that referenced this issue Jan 25, 2023
Should probably figure out how to get type checking going, would have
caught these cases.

Discovered in pursuit of #91719
though this is not enough.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
ezyang added a commit that referenced this issue Jan 25, 2023
Should probably figure out how to get type checking going, would have
caught these cases.

Discovered in pursuit of #91719
though this is not enough.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 20d6b522680eae6ab2c20b3f273a1f08d610bd8b
Pull Request resolved: #92997
pytorchmergebot pushed a commit that referenced this issue Jan 26, 2023
Should probably figure out how to get type checking going, would have
caught these cases.

Discovered in pursuit of #91719
though this is not enough.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #92997
Approved by: https://github.com/Chillee
@ezyang

ezyang commented Jan 26, 2023

So far I got only (2) working: with #93071 on master (in particular you need #92997, which has landed) and with this hotpatch for an error that I need to track down elsewhere:

diff --git a/torchaudio/models/wav2vec2/components.py b/torchaudio/models/wav2vec2/components.py
index 822fed42..65028d01 100644
--- a/torchaudio/models/wav2vec2/components.py
+++ b/torchaudio/models/wav2vec2/components.py
@@ -308,7 +308,7 @@ class SelfAttention(Module):
         v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
 
         # scale down q to avoid value overflow.
-        weights = (self.scaling * q) @ k  # B, nH, L, L
+        weights = (q * self.scaling) @ k  # B, nH, L, L
         if attention_mask is not None:
             weights += attention_mask
         # subtracting a constant value from the tensor won't change the output of softmax.
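
The reorder in the hotpatch is numerically a no-op, since scalar-tensor multiplication commutes; it presumably only side-stepped a tracing/lowering path that mishandled the `float * Tensor` ordering. A quick sanity check of the equivalence (shapes here are illustrative stand-ins for the attention query tensor):

```python
import torch

q = torch.randn(2, 4, 8, 16)  # stand-in for q: B, nH, L, Hd
scaling = 16 ** -0.5          # head_dim ** -0.5, the usual attention scale

# Both orderings compute the same values; the patch only changed which
# multiplication overload the compiler had to trace.
assert torch.equal(scaling * q, q * scaling)
```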

dynamic=True runs successfully with CUDA.

I get

[2023-01-26 08:08:06,253] torch._inductor.graph: [WARNING] Using FallbackKernel: torch.ops.aten._weight_norm_interface.default
Model evaluation 0 took 44.195989 s
Model evaluation 0 took 0.105661 s
Model evaluation 0 took 0.088884 s
Model evaluation 0 took 0.321908 s
Model evaluation 0 took 0.326599 s
Model evaluation 0 took 0.322213 s
Model evaluation 0 took 0.322474 s
Model evaluation 0 took 0.322530 s
Model evaluation 0 took 0.321882 s
Model evaluation 0 took 0.326978 s
Model evaluation 0 took 0.322435 s
Model evaluation 0 took 0.322489 s
Model evaluation 0 took 0.322241 s
Model evaluation 0 took 0.326942 s
Model evaluation 0 took 0.332610 s
Model evaluation 0 took 0.335572 s
Model evaluation 0 took 0.330488 s
Model evaluation 0 took 0.330524 s
Model evaluation 0 took 0.330936 s
Model evaluation 0 took 0.335064 s
Average runtime = 0.302549

If there is a way to make the benchmark script actually vary sequence length, I think that would be the real test.
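
One way to make the benchmark actually exercise dynamic shapes is to feed the compiled model inputs whose time dimension changes on every call. A hedged sketch with a toy module standing in for wav2vec2 (the real script would instead draw batches of different lengths from LibriSpeech):

```python
import time
import torch

model = torch.nn.Linear(16, 4)            # toy stand-in for the wav2vec2 model
compiled = torch.compile(model, dynamic=True)

with torch.no_grad():
    for seq_len in (100, 250, 137, 512):  # vary the sequence length per call
        x = torch.randn(2, seq_len, 16)
        start = time.perf_counter()
        out = compiled(x)
        print(f"seq_len={seq_len} -> {tuple(out.shape)} "
              f"in {time.perf_counter() - start:.4f}s")
```

If `dynamic=True` is doing its job, only the first call should pay the compile cost; later calls with new lengths should not trigger a recompile.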

@ezyang

ezyang commented Jan 26, 2023

hotpatch no longer needed with #93073

@ezyang

ezyang commented Jan 26, 2023

With all of these patches, all of (1)/(2)/(3) work with CUDA. CPU is still failing, though.

ezyang added a commit that referenced this issue Jan 26, 2023
These errors were found by looking at wav2vec2

See #91719

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
ezyang added a commit that referenced this issue Jan 26, 2023
These errors were found by looking at wav2vec2

See #91719

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 51b6d644b4e30b9137354a4fb5eece5534d3412e
Pull Request resolved: #93077
@ezyang

ezyang commented Jan 26, 2023

And with #93077 it works on CPU too. So once all these PRs land, I think you're good to go!

@ezyang ezyang self-assigned this Jan 26, 2023
ezyang added a commit that referenced this issue Jan 27, 2023
These errors were found by looking at wav2vec2

See #91719

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc mlazos soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen chunyuan-w XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
ezyang added a commit that referenced this issue Jan 27, 2023
These errors were found by looking at wav2vec2

See #91719

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: f31fe08b8e91426f962c3ce6cc149b4bb4827a73
Pull Request resolved: #93077
pytorchmergebot pushed a commit that referenced this issue Jan 27, 2023
These errors were found by looking at wav2vec2

See #91719

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #93077
Approved by: https://github.com/voznesenskym, https://github.com/ngimel
@mthrok

mthrok commented Feb 2, 2023

@ezyang I tried again with the latest nightly. With dynamic=True, I am now getting a new error.

torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised TypeError: unhashable type: 'SymInt'

BTW, this was tested on a V100. Should I test on an A100?

Stacktrace
[2023-02-02 23:42:44,083] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-02-02 23:42:47,832] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2023-02-02 23:42:47,862] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
Traceback (most recent call last):
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 692, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1054, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/__init__.py", line 1368, in __call__
    return self.compile_fn(model_, inputs_, config_patches=self.config)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 378, in compile_fx
    return compile_fx(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 426, in compile_fx
    return aot_autograd(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/optimizations/training.py", line 66, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2483, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 162, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2180, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1401, in aot_wrapper_dedupe
    if a not in args_set:
TypeError: unhashable type: 'SymInt'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 209, in <module>
    _main()
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 205, in _main
    run_inference(args)
  File "examples/asr/librispeech_ctc_decoder/inference.py", line 95, in run_inference
    emissions, emission_lengths = model(waveforms)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 330, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 403, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 261, in _convert_frame_assert
    return _compile(
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 162, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 323, in _compile
    out_code = transform_code_object(code, transform)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 339, in transform_code_object
    transformations(instructions, code_options)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 310, in transform
    tracer.run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1715, in run
    super().run()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 564, in run
    and self.step()
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 527, in step
    getattr(self, inst.opname)(inst)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1781, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 563, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 610, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 162, in time_wrapper
    r = func(*args, **kwargs)
  File "/fsx/users/moto/conda/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 697, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised TypeError: unhashable type: 'SymInt'
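For context, a torch-free sketch of why this kind of error arises: Python makes a class unhashable when it defines `__eq__` without `__hash__`, so any cache keyed on such values raises exactly this `TypeError`. `SymIntLike` below is a hypothetical stand-in for illustration, not the real `torch.SymInt`.

```python
class SymIntLike:
    """Hypothetical stand-in for a symbolic integer (not the real torch.SymInt)."""

    def __init__(self, name):
        self.name = name

    # Defining __eq__ without __hash__ sets __hash__ to None,
    # which makes instances unusable as dict/set keys.
    def __eq__(self, other):
        return isinstance(other, SymIntLike) and self.name == other.name


try:
    cache = {SymIntLike("s0"): "compiled kernel"}  # e.g. a cache keyed on shapes
except TypeError as e:
    print(e)  # unhashable type: 'SymIntLike'
```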

ezyang commented Feb 3, 2023

Not expected, and GPU shouldn't matter.

ezyang commented Feb 3, 2023

Cc @voznesenskym: this feels dedupe-related.

@fmocking

I'm getting the same error on a V100 using a Hugging Face transformer model.

ezyang commented Mar 22, 2023

Unhashable SymInt should have been fixed by #95533. @fmocking, please try a nightly build.

ezyang commented Mar 22, 2023

@mthrok I tried (2) again with origin/wav2vec2-pytorch2 and I get this, which seems like it's working:

(/home/ezyang/local/a/pytorch-env) [ezyang@devgpu020.ftw1 ~/local/a/torchaudio (wav2vec2-pytorch2)]$ python examples/asr/librispeech_ctc_decoder/inference.py --librispeech_path ~/local/ --batch_size 2 --compile
The local file (/home/ezyang/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt) exists. Skipping the download.
The local file (/home/ezyang/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt) exists. Skipping the download.
The local file (/home/ezyang/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin) exists. Skipping the download.
[2023-03-21 19:41:26,128] torch._inductor.utils: [WARNING] make_fallback(aten.addmv): a decomposition exists, we should switch to it
Model evaluation 0 took 56.855116 s
Model evaluation 0 took 1.263582 s
Model evaluation 0 took 1.266036 s
Model evaluation 0 took 1.234831 s
Model evaluation 0 took 1.239384 s
Model evaluation 0 took 1.247958 s
Model evaluation 0 took 1.212082 s
Model evaluation 0 took 1.224670 s
Model evaluation 0 took 1.236169 s
Model evaluation 0 took 1.224591 s
Model evaluation 0 took 1.215647 s
Model evaluation 0 took 1.250675 s
Model evaluation 0 took 1.235021 s
Model evaluation 0 took 1.223134 s
Model evaluation 0 took 1.235374 s
Model evaluation 0 took 1.242361 s
Model evaluation 0 took 1.206097 s
Model evaluation 0 took 1.208330 s
Model evaluation 0 took 1.210267 s
Model evaluation 0 took 1.223777 s
Average runtime = 1.231578
Traceback (most recent call last):
  File "/data/users/ezyang/a/torchaudio/examples/asr/librispeech_ctc_decoder/inference.py", line 208, in <module>
    _main()
  File "/data/users/ezyang/a/torchaudio/examples/asr/librispeech_ctc_decoder/inference.py", line 204, in _main
    run_inference(args)
  File "/data/users/ezyang/a/torchaudio/examples/asr/librispeech_ctc_decoder/inference.py", line 108, in run_inference
    emission = emissions[i:i + 1, 0:emission_lengths[i], :]
TypeError: 'NoneType' object is not subscriptable
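For anyone hitting the final `TypeError`: when the model is called without a `lengths` argument, the returned output lengths are `None`, so the slice on `emission_lengths[i]` needs a fallback. A minimal torch-free sketch of that guard (the function name, shapes, and nested-list stand-in for the tensor are illustrative, not from the torchaudio script):

```python
def slice_emission(emissions, emission_lengths, i):
    """Return emissions[i] truncated to its valid length, tolerating None lengths."""
    frames = emissions[i]
    if emission_lengths is None:          # lengths were never computed
        return frames                     # keep every frame
    return frames[: emission_lengths[i]]  # drop padded frames


# 2 utterances, 5 frames each, 4 "tokens" per frame (stand-in for a tensor)
batch = [[[0.1] * 4] * 5, [[0.2] * 4] * 5]
print(len(slice_emission(batch, None, 0)))    # 5: full length when lengths is None
print(len(slice_emission(batch, [3, 5], 0)))  # 3: truncated to valid frames
```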

@eellison

Is this fixed?

@williamwen42

Closing due to inactivity - reopen if this is still an issue.
