
torch.jit.trace with pack_padded_sequence cannot do dynamic batch #68968

@wangyunxiaa


🐛 Bug

The relevant part of the code is as follows:

    pack_pad_x = nn_utils.rnn.pack_padded_sequence(x, seq_num.cpu(), batch_first=True, enforce_sorted=False)
    output, hn = self.bilstm(pack_pad_x)
    output, outlen = nn_utils.rnn.pad_packed_sequence(output, batch_first=True, total_length=sample_size)

When I trace my model with torch.jit.trace, I get a TorchScript model and can run inference on it with the same input.
But when I concatenate the inputs into a larger batch (batch size going from 1 to 2), I get an error:
[image]

However, I can also trace the model with batch size 2 and save it as a TorchScript model, but when I then run inference with batch size 1, I get a different error:
[image]

When I delete the nn_utils.rnn.pack_padded_sequence and nn_utils.rnn.pad_packed_sequence lines from the code, everything runs without error,
but the inference result is incorrect.
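
One likely reason the result changes once the packing calls are removed: without packing, a bidirectional LSTM also runs over the padded time steps, so the backward direction reaches the valid frames with a state that has already seen padding. The sketch below illustrates this on a toy example. The layer sizes, the zero padding, and the names `lstm`, `x`, and `lengths` are assumptions for illustration, not the model from this report.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Toy bidirectional LSTM; sizes are arbitrary and only for illustration.
lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True, bidirectional=True)

x = torch.randn(2, 5, 4)          # (batch, time, features)
lengths = torch.tensor([5, 2])    # second sequence has only 2 valid frames
x[1, 2:] = 0.0                    # zero-pad the invalid frames

with torch.no_grad():
    # With packing: padded frames are skipped entirely.
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    out_packed, _ = lstm(packed)
    out_packed, _ = pad_packed_sequence(out_packed, batch_first=True, total_length=5)

    # Without packing: the LSTM also steps through the padded frames.
    out_plain, _ = lstm(x)

# The full-length sequence is unaffected by packing.
full_match = torch.allclose(out_packed[0], out_plain[0], atol=1e-5)
# For the short sequence, the forward half (first 3 channels) still agrees
# on the valid frames ...
forward_match = torch.allclose(out_packed[1, :2, :3], out_plain[1, :2, :3], atol=1e-5)
# ... but the valid frames as a whole differ because of the backward direction.
short_match = torch.allclose(out_packed[1, :2], out_plain[1, :2], atol=1e-5)

print(full_match, forward_match, short_match)
```

If the two outputs disagree only for sequences shorter than the padded length, this effect is a plausible explanation; it does not by itself explain the tracing errors above.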

I also tried changing the LSTM part to use torch.jit.script:

    @torch.jit.script
    def forward(self, x, seq_num, sample_size):
        pack_pad_x = nn_utils.rnn.pack_padded_sequence(x, seq_num.cpu(), batch_first=True, enforce_sorted=False)
        output, hn = self.bilstm(pack_pad_x)
        output, outlen = nn_utils.rnn.pad_packed_sequence(output, batch_first=True, total_length=sample_size)
        return output

With this change I cannot run inference, and the plain PyTorch model also fails to run.
[image]

To Reproduce

Steps to reproduce the behavior:
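
The report does not include a full script. The following is a minimal sketch of the setup described above, not the original model: `BiLSTMWrapper`, the feature and hidden sizes, and the random inputs are all hypothetical. Because torch.jit.trace only accepts tensor inputs, `sample_size` is stored as a module attribute here instead of being passed to `forward`. The sketch traces with batch size 1, checks the traced module against eager mode on the same input, and then records whatever happens with batch size 2. Whether that second call raises, and with which message, may depend on the PyTorch version.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class BiLSTMWrapper(nn.Module):
    """Hypothetical stand-in for the reporter's model (sizes are made up)."""

    def __init__(self, feat_dim=8, hidden=16, sample_size=10):
        super().__init__()
        self.sample_size = sample_size
        self.bilstm = nn.LSTM(feat_dim, hidden, batch_first=True, bidirectional=True)

    def forward(self, x, seq_num):
        packed = pack_padded_sequence(
            x, seq_num.cpu(), batch_first=True, enforce_sorted=False
        )
        output, hn = self.bilstm(packed)
        output, outlen = pad_packed_sequence(
            output, batch_first=True, total_length=self.sample_size
        )
        return output


torch.manual_seed(0)
model = BiLSTMWrapper().eval()

# Batch size 1 is used for tracing, batch size 2 for the later call.
x1, len1 = torch.randn(1, 10, 8), torch.tensor([7])
x2, len2 = torch.randn(2, 10, 8), torch.tensor([10, 4])

with torch.no_grad():
    traced = torch.jit.trace(model, (x1, len1))

    # Same input as used for tracing: expected to agree with eager mode.
    same_batch_ok = torch.allclose(traced(x1, len1), model(x1, len1), atol=1e-5)

    # Different batch size: capture the error, if any, instead of crashing.
    try:
        traced(x2, len2)
        other_batch_error = None
    except Exception as exc:
        other_batch_error = exc

print("batch size 1 matches eager:", same_batch_ok)
print("batch size 2 error:", other_batch_error)
```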

Expected behavior

Environment

PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 5.4.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 460.73.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] pytorch-asr==0.3.1.dev48+g166d73d.d20201130
[pip3] pytorch-memlab==0.2.4
[pip3] pytorch-pcen==0.0.1
[pip3] pytorch-ranger==0.1.1
[pip3] pytorch-wpe==0.0.0
[pip3] torch==1.8.1
[pip3] torch-complex==0.2.0
[pip3] torch-optimizer==0.0.1a17
[pip3] torchaudio==0.8.0a0+e4e171a
[pip3] torchvision==0.9.1
[pip3] warpctc-pytorch==0.1
[conda] blas                      1.0                         mkl    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] cudatoolkit               10.1.243             h6bb024c_0    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] ffmpeg                    4.3                  hf484d3e_0    https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] mkl                       2020.2                      256    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] mkl-service               2.3.0            py37he8ac12f_0    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] mkl_fft                   1.3.0            py37h54f3939_0    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] mkl_random                1.1.1            py37h0573a6f_0    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] numpy                     1.19.2           py37h54aff64_0    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] numpy-base                1.19.2           py37hfa32c7d_0    https://mirrors.bfsu.edu.cn/anaconda/pkgs/main
[conda] pytorch                   1.8.1           py3.7_cuda10.1_cudnn7.6.3_0    https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] pytorch-asr               0.3.1.dev48+g166d73d.d20201130           dev_0    <develop>
[conda] pytorch-memlab            0.2.4                     dev_0    <develop>
[conda] pytorch-pcen              0.0.1                    pypi_0    pypi
[conda] pytorch-ranger            0.1.1                    pypi_0    pypi
[conda] pytorch-wpe               0.0.0                    pypi_0    pypi
[conda] torch-complex             0.2.0                    pypi_0    pypi
[conda] torch-optimizer           0.0.1a17                 pypi_0    pypi
[conda] torchaudio                0.8.1                      py37    https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] torchvision               0.9.1                py37_cu101    https://mirrors.bfsu.edu.cn/anaconda/cloud/pytorch
[conda] warpctc-pytorch           0.2                      pypi_0    pypi

Additional context


Labels: Stale, oncall: jit
