torch.jit.trace hardcodes batch size with packed input to LSTM #15319

@anna-hope

Description

🐛 Bug

Using torch.jit.trace on a function that feeds packed inputs to an LSTM produces a graph with a hard-coded batch size. This makes it impossible to use the traced function with inputs of a different batch size.

To Reproduce

Steps to reproduce the behavior:

import torch
import torch.nn as nn

# Two batches with different batch sizes (8 and 16) but the same
# sequence length (5) and feature dimension (30).
a = torch.randn((8, 5, 30))
a_lengths = torch.randint(low=1, high=a.shape[1], size=(len(a),))
a_lengths, _ = torch.sort(a_lengths, descending=True)

b = torch.randn((16, 5, 30))
b_lengths = torch.randint(low=1, high=b.shape[1], size=(len(b),))
b_lengths, _ = torch.sort(b_lengths, descending=True)

lstm = nn.LSTM(30, 25, batch_first=True)

def feed_rnn(X: torch.Tensor, sorted_lengths: torch.Tensor) -> torch.Tensor:
    # pack_padded_sequence expects lengths sorted in descending order
    X = nn.utils.rnn.pack_padded_sequence(X, sorted_lengths, batch_first=True)
    X, hidden_states = lstm(X)
    # pad_packed_sequence returns a tuple of sequences and lengths
    X, sorted_lengths = nn.utils.rnn.pad_packed_sequence(X, batch_first=True)
    return X

# Trace with batch size 16, then check the trace against both batch sizes.
func = torch.jit.trace(feed_rnn, (b, b_lengths), check_inputs=[(b, b_lengths), (a, a_lengths)])

This results in:

RuntimeError: 
The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0

...

ERROR: Graphs differed across invocations!
	Graph diff:
		  graph(%input.1 : Tensor
		        %lengths.1 : Tensor) {
		    %2 : Device = prim::Constant[value="cpu"]()
		    %3 : int = prim::Constant[value=4]()
		    %4 : bool = prim::Constant[value=0]()
		    %5 : bool = prim::Constant[value=0]()
		    %lengths : Tensor = aten::to(%lengths.1, %2, %3, %4, %5)
		    %7 : bool = prim::Constant[value=1]()
		    %input : Tensor, %batch_sizes : Tensor = aten::_pack_padded_sequence(%input.1, %lengths, %7)
		    %10 : int = prim::Constant[value=1](), scope: LSTM
		-   %11 : int = prim::Constant[value=16](), scope: LSTM
		?                                    ^^
		+   %11 : int = prim::Constant[value=8](), scope: LSTM
		?                                    ^
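
The diff above shows where the problem lives: tracing bakes the batch size into the graph as an integer constant for the LSTM call (%11). The same constant shows up if the function is traced with a single example and its graph is printed (a small sketch, assuming the traced callable exposes a .graph attribute as in the check output above):

# Trace with only the batch-size-16 example; no check_inputs, so tracing itself succeeds.
func = torch.jit.trace(feed_rnn, (b, b_lengths))
print(func.graph)  # the printed IR contains prim::Constant[value=16] inside the LSTM call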

Expected behavior

The traced function should work for inputs with different batch sizes.
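
Concretely, the traced function should behave like the eager one for any leading batch dimension. A minimal sketch of the expectation, reusing the tensors from the repro and the func traced without check_inputs above:

out_b = func(b, b_lengths)  # batch size 16: works, matches the traced example
out_a = func(a, a_lengths)  # batch size 8: should return an (8, max_len, 25) tensor, but
                            # currently fails because 16 is hard-coded in the graph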

Environment

Collecting environment information...
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Homebrew gcc 5.5.0_4) 5.5.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: GeForce GTX 1050
Nvidia driver version: 396.54
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.1
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn.so.7.1.1
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] Could not collect
[conda] cuda92                    1.0                           0    pytorch
[conda] pytorch                   1.0.0           py3.7_cuda9.0.176_cudnn7.4.1_1    pytorch
[conda] torch                     1.0.0                     <pip>
[conda] torchfile                 0.1.0                      py_0    conda-forge
[conda] torchfile                 0.1.0                     <pip>
[conda] torchnet                  0.0.4                     <pip>
[conda] torchtext                 0.3.1                     <pip>
[conda] torchvision               0.2.1                 py37_1000    conda-forge
