Errors using torch.compile() on wav2vec2 model #91719
Comments
dynamic=True isn't going to work with inductor on current master, as we're still waiting on a batch of fixes from @Chillee. To preview whether the dynamic shapes infra sans inductor works, you can try
This is currently tagged "dynamic shapes", but I want to call out that (1) and (3) are NOT dynamic shapes related.
@kalakris inference with inductor and dynamic shapes should be substantially working on master, can you give this another try?
Here is the result of my trial with the nightly build:
1) (no_grad + dynamic=False) seems to be resolved (the code still fails, but for reasons other than
2) no_grad + dynamic=True fails both on CPU and GPU after a long compilation time.
3) Passing the optional `lengths` parameter failed.
BTW, on my M1 MacBook Pro I encountered the following compilation error.
Thanks a lot!
Should probably figure out how to get type checking going, would have caught these cases. Discovered in pursuit of #91719 though this is not enough. Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
Should probably figure out how to get type checking going, would have caught these cases. Discovered in pursuit of #91719 though this is not enough. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #92997 Approved by: https://github.com/Chillee
So far I have only (2) working. With #93071 on master (in particular, you need #92997, which has landed), and with a hotpatch for an error that I need to track down elsewhere, dynamic=True runs successfully with CUDA. I get
If there is a way to make the benchmark script actually vary sequence length, I think that would be the real test.
The hotpatch is no longer needed with #93073.
With all of these patches, all of (1)/(2)/(3) work with CUDA. CPU is still failing, though.
These errors were found by looking at wav2vec2 See #91719 Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
And with #93077 it works with CPU. So once all these PRs land, I think you're good to go!
These errors were found by looking at wav2vec2 See #91719 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: #93077 Approved by: https://github.com/voznesenskym, https://github.com/ngimel
@ezyang I tried again with the latest nightly. With the
BTW this was tested with V100. Should I tests with A100? Stacktrace
|
Not expected, and the GPU model shouldn't matter.
Cc @voznesenskym, this feels dedupe-related.
I'm getting the same error on a V100 using a Hugging Face transformer model.
@mthrok I tried (2) again with origin/wav2vec2-pytorch2 and I get this, which seems like it's working:
Is this fixed?
Closing due to inactivity - reopen if this is still an issue.
🐛 Describe the bug
I've been experimenting with enabling torch.compile() for the wav2vec2 model in torchaudio, and ran into a few issues with it. My code can be found here. Here's a summary of the issues I found, with reproduction instructions for each:
1) Errors with `inference_mode()`
To reproduce this, change `no_grad()` to `inference_mode()` in `examples/asr/librispeech_ctc_decoder/inference.py`, and run it. It produces the error "RuntimeError: Inference tensors do not track version counter.", but continues to run. Full stack trace is here.
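The torchaudio example isn't reproduced here, but the pattern at issue can be sketched with a minimal, hypothetical function (not the wav2vec2 model); `backend="eager"` is an assumption made so the sketch exercises TorchDynamo tracing without needing a working Inductor toolchain:

```python
import torch

def f(x):
    return torch.relu(x) + 1

# backend="eager" traces the function with TorchDynamo but skips
# Inductor code generation; the report used the default backend.
compiled = torch.compile(f, backend="eager")

with torch.no_grad():
    y = compiled(torch.randn(4))  # the no_grad() path worked in the report

with torch.inference_mode():
    z = compiled(torch.randn(4))  # this path raised the version-counter error
```

On a nightly where the bug is fixed, both paths should produce a length-4 tensor; on the affected builds, the `inference_mode()` call is where the error surfaced.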
2) Errors with `dynamic=True`
Ideally, we'd want to run wav2vec2 with dynamic tensor sizes per call, but I wasn't able to do that. To reproduce, uncomment line 47 of `examples/asr/librispeech_ctc_decoder/inference.py`, comment out line 48, and run it. Stack trace is here.
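For context, `dynamic=True` asks the compiler to trace with symbolic sizes so that varying input lengths don't each trigger a fresh compile. A rough sketch of the intended usage, with a toy function standing in for the model (again using `backend="eager"` as an assumption to keep the sketch toolchain-free):

```python
import torch

def f(x):
    return x.mean(dim=-1)

# With dynamic=True, the trace generalizes over the sequence dimension
# instead of specializing (and recompiling) for each concrete length.
compiled = torch.compile(f, dynamic=True, backend="eager")

# Feed varying "sequence lengths", as batched audio inference would.
outputs = [compiled(torch.randn(2, length)) for length in (8, 16, 32)]
```

This is the behavior the report was after: one compiled artifact serving calls with different per-call tensor sizes.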
3) Errors when passing the `lengths` parameter to the model
When batching inputs to this model, we need to pass in the lengths of each sample as well. Unfortunately this performs some operations which seem to break with `torch.compile()`. To reproduce, uncomment line 93 of `examples/asr/librispeech_ctc_decoder/inference.py`, comment out line 94, and run it. Stack trace is here.
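Length handling in models like wav2vec2 typically builds a padding mask from the per-sample lengths. The helper below is a simplified stand-in for that kind of operation, not torchaudio's actual code:

```python
import torch

def masked_mean(x, lengths):
    # Build a boolean padding mask from per-sample lengths; arange plus a
    # comparison against a tensor of lengths is the kind of data-dependent
    # shape logic that the report found breaking under torch.compile().
    mask = torch.arange(x.shape[1])[None, :] < lengths[:, None]
    return (x * mask).sum(dim=1) / lengths.clamp(min=1)

# backend="eager" is an assumption so the sketch runs without Inductor.
compiled = torch.compile(masked_mean, backend="eager")
out = compiled(torch.ones(2, 5), torch.tensor([3, 5]))
```

With all-ones input, each sample's masked mean is 1.0, so the sketch doubles as a quick correctness check for the compiled function.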
Versions
PyTorch version: 2.0.0.dev20221216
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-124-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB
Nvidia driver version: 470.141.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==2.0.0.dev20221216
[pip3] torchaudio==0.14.0.dev20221216
[pip3] torchvision==0.15.0.dev20221216
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.23.4 py310hd5efca6_0
[conda] numpy-base 1.23.4 py310h8e6c178_0
[conda] pytorch 2.0.0.dev20221216 py3.10_cuda11.6_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-cuda 11.6 h867d48c_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 0.14.0.dev20221216 py310_cu116 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py310 pytorch-nightly
[conda] torchvision 0.15.0.dev20221216 py310_cu116 pytorch-nightly
cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @mthrok