🐛 Describe the bug
An attempt to export the wav2vec2 model by following this PyTorch guide fails. The failure happens in the `forward` call of `Wav2Vec2`, where something goes wrong with `layer_drop`. The relevant part of the logs:
x = self.encoder(x, lengths)
x = self.transformer(x, attention_mask=mask)
if not (self.training and torch.rand(1).item() <= self.layer_drop):
This happens even though all dropout probabilities are 0 by default.
Code for reproduction
from torchaudio.models import wav2vec2_xlsr_300m
import torch
torch_model = wav2vec2_xlsr_300m()
example_inputs = (torch.rand(1, 64600),)
onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
Versions
PyTorch version: 2.8.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Versions of relevant libraries:
[pip3] numpy==2.3.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.8.0
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.8.0
[pip3] triton==3.4.0
[pip3] tritonclient==2.36.0
[conda] numpy 2.3.3 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.3 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.8.0 pypi_0 pypi
[conda] torch-tb-profiler 0.4.3 pypi_0 pypi
[conda] torchaudio 2.8.0 pypi_0 pypi
[conda] triton 3.4.0 pypi_0 pypi
[conda] tritonclient 2.36.0 pypi_0 pypi