Status: Open
Labels: Stale, oncall: jit
🐛 Bug
The relevant part of the code is as follows:

```python
pack_pad_x = nn_utils.rnn.pack_padded_sequence(x, seq_num.cpu(), batch_first=True, enforce_sorted=False)
output, hn = self.bilstm(pack_pad_x)
output, outlen = nn_utils.rnn.pad_packed_sequence(output, batch_first=True, total_length=sample_size)
```
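For context, here is a minimal self-contained sketch of the same pack/pad round trip. The `bilstm` layer, its sizes, and the inputs are hypothetical stand-ins for the model's actual values:

```python
import torch
import torch.nn as nn
import torch.nn.utils as nn_utils

# Hypothetical stand-in for self.bilstm: input size 8, hidden size 16.
bilstm = nn.LSTM(8, 16, batch_first=True, bidirectional=True)

sample_size = 5                     # max sequence length (used as total_length)
x = torch.randn(2, sample_size, 8)  # batch of 2 padded sequences
seq_num = torch.tensor([5, 3])      # true length of each sequence

pack_pad_x = nn_utils.rnn.pack_padded_sequence(
    x, seq_num.cpu(), batch_first=True, enforce_sorted=False)
output, hn = bilstm(pack_pad_x)
output, outlen = nn_utils.rnn.pad_packed_sequence(
    output, batch_first=True, total_length=sample_size)

# output has shape (batch, total_length, 2 * hidden_size) = (2, 5, 32);
# outlen returns the original lengths in the original batch order.
```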
When I traced my model with `torch.jit.trace`, I got a TorchScript model and could run inference on the scripted model with the same input.
But when I concatenated the input into a larger batch (batch size 1 → 2), I got an error:

However, when I traced the model with batch size 2 and saved it as a TorchScript model, inference with batch size 1 raised a different error:

When I remove the `nn_utils.rnn.pack_padded_sequence` and `nn_utils.rnn.pad_packed_sequence` lines from the code, everything runs without errors, but the inference results are incorrect.
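The incorrect results without packing are expected for a bidirectional LSTM: the backward direction starts by reading the padding, so even the valid timesteps of shorter sequences get different outputs. A small sketch (with hypothetical sizes) demonstrating the difference:

```python
import torch
import torch.nn as nn
import torch.nn.utils as nn_utils

torch.manual_seed(0)
bilstm = nn.LSTM(4, 8, batch_first=True, bidirectional=True)

x = torch.randn(2, 6, 4)
lengths = torch.tensor([6, 3])
x[1, 3:] = 0.0  # zero-pad the shorter sequence

# With packing: the padded steps never enter the recurrence.
packed = nn_utils.rnn.pack_padded_sequence(
    x, lengths, batch_first=True, enforce_sorted=False)
out_packed, _ = bilstm(packed)
out_packed, _ = nn_utils.rnn.pad_packed_sequence(out_packed, batch_first=True)

# Without packing: the backward direction consumes the zero padding first.
out_plain, _ = bilstm(x)

# The full-length sequence matches either way, but the valid timesteps of
# the padded sequence differ.
full_diff = (out_packed[0] - out_plain[0]).abs().max()
short_diff = (out_packed[1, :3] - out_plain[1, :3]).abs().max()
```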
I also tried changing the LSTM part of the model and compiling it with `torch.jit.script`:

```python
@torch.jit.script
def forward(self, x, seq_num, sample_size):
    pack_pad_x = nn_utils.rnn.pack_padded_sequence(x, seq_num.cpu(), batch_first=True, enforce_sorted=False)
    output, hn = self.bilstm(pack_pad_x)
    output, outlen = nn_utils.rnn.pad_packed_sequence(output, batch_first=True, total_length=sample_size)
    return output
```
With this version I cannot run inference at all, and the model also fails when run in plain PyTorch.
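Note that `@torch.jit.script` is not meant to be applied to a bound `forward` method; the usual approach is to script the whole module with `torch.jit.script(module)`, annotating `sample_size` as `int` so `total_length` type-checks. A hedged sketch of that alternative, using a hypothetical minimal wrapper module in place of the real model:

```python
import torch
import torch.nn as nn
from torch.nn.utils import rnn as rnn_utils


class BiLSTMWrapper(nn.Module):
    """Hypothetical minimal module mirroring the forward() above."""

    def __init__(self):
        super().__init__()
        self.bilstm = nn.LSTM(8, 16, batch_first=True, bidirectional=True)

    def forward(self, x: torch.Tensor, seq_num: torch.Tensor,
                sample_size: int) -> torch.Tensor:
        pack_pad_x = rnn_utils.pack_padded_sequence(
            x, seq_num.cpu(), batch_first=True, enforce_sorted=False)
        output, hn = self.bilstm(pack_pad_x)
        output, outlen = rnn_utils.pad_packed_sequence(
            output, batch_first=True, total_length=sample_size)
        return output


# Script the module itself rather than decorating its forward method.
scripted = torch.jit.script(BiLSTMWrapper())
out = scripted(torch.randn(2, 5, 8), torch.tensor([5, 3]), 5)
```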

To Reproduce
Steps to reproduce the behavior:
Expected behavior
Environment
PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A
OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 5.4.0
Clang version: Could not collect
CMake version: Could not collect
Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 460.73.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] pytorch-asr==0.3.1.dev48+g166d73d.d20201130
[pip3] pytorch-memlab==0.2.4
[pip3] pytorch-pcen==0.0.1
[pip3] pytorch-ranger==0.1.1
[pip3] pytorch-wpe==0.0.0
[pip3] torch==1.8.1
[pip3] torch-complex==0.2.0
[pip3] torch-optimizer==0.0.1a17
[pip3] torchaudio==0.8.0a0+e4e171a
[pip3] torchvision==0.9.1
[pip3] warpctc-pytorch==0.1
[conda] blas 1.0 mkl https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] cudatoolkit 10.1.243 h6bb024c_0 https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] ffmpeg 4.3 hf484d3e_0 https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] mkl 2020.2 256 https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] mkl-service 2.3.0 py37he8ac12f_0 https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] mkl_fft 1.3.0 py37h54f3939_0 https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] mkl_random 1.1.1 py37h0573a6f_0 https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] numpy 1.19.2 py37h54aff64_0 https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] numpy-base 1.19.2 py37hfa32c7d_0 https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] pytorch 1.8.1 py3.7_cuda10.1_cudnn7.6.3_0 https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] pytorch-asr 0.3.1.dev48+g166d73d.d20201130 dev_0 <develop>
[conda] pytorch-memlab 0.2.4 dev_0 <develop>
[conda] pytorch-pcen 0.0.1 pypi_0 pypi
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] pytorch-wpe 0.0.0 pypi_0 pypi
[conda] torch-complex 0.2.0 pypi_0 pypi
[conda] torch-optimizer 0.0.1a17 pypi_0 pypi
[conda] torchaudio 0.8.1 py37 https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] torchvision 0.9.1 py37_cu101 https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] warpctc-pytorch 0.2 pypi_0 pypi
Additional context