🐛 Bug
Using torch.jit.trace on a function that feeds inputs to an LSTM produces a graph with a hard-coded batch size. This makes it impossible to call the traced function with inputs of a different batch size.
To Reproduce
Steps to reproduce the behavior:
import torch
import torch.nn as nn
a = torch.randn((8, 5, 30))
a_lengths = torch.randint(low=1, high=a.shape[1], size=(len(a),))
a_lengths, _ = torch.sort(a_lengths, descending=True)
b = torch.randn((16, 5, 30))
b_lengths = torch.randint(low=1, high=b.shape[1], size=(len(b),))
b_lengths, _ = torch.sort(b_lengths, descending=True)
lstm = nn.LSTM(30, 25, batch_first=True)
def feed_rnn(X: torch.Tensor, sorted_lengths: torch.Tensor) -> torch.Tensor:
    X = nn.utils.rnn.pack_padded_sequence(X, sorted_lengths, batch_first=True)
    X, hidden_states = lstm(X)
    # pad_packed_sequence returns a tuple of sequences and lengths
    X, sorted_lengths = nn.utils.rnn.pad_packed_sequence(X, batch_first=True)
    return X
func = torch.jit.trace(feed_rnn, (b, b_lengths), check_inputs=[(b, b_lengths), (a, a_lengths)])
This results in:
RuntimeError:
The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0
...
ERROR: Graphs differed across invocations!
Graph diff:
graph(%input.1 : Tensor
%lengths.1 : Tensor) {
%2 : Device = prim::Constant[value="cpu"]()
%3 : int = prim::Constant[value=4]()
%4 : bool = prim::Constant[value=0]()
%5 : bool = prim::Constant[value=0]()
%lengths : Tensor = aten::to(%lengths.1, %2, %3, %4, %5)
%7 : bool = prim::Constant[value=1]()
%input : Tensor, %batch_sizes : Tensor = aten::_pack_padded_sequence(%input.1, %lengths, %7)
%10 : int = prim::Constant[value=1](), scope: LSTM
- %11 : int = prim::Constant[value=16](), scope: LSTM
? ^^
+ %11 : int = prim::Constant[value=8](), scope: LSTM
? ^
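For context, the hard-coded value appears to come from the LSTM turning the packed sequence's largest batch size into a plain Python int when it builds the initial hidden state; anything converted to a Python number during tracing gets recorded as a literal in the graph. A minimal sketch of the same failure mode (the function below is hypothetical, not the actual nn.LSTM code):
import torch

def bakes_in_a_constant(x, batch_sizes):
    # int() pulls the value out of the traced graph, so whatever number the
    # example input happened to contain gets frozen in as a constant
    max_batch_size = int(batch_sizes[0])
    return x.new_zeros(1, max_batch_size, 25)

example = (torch.randn(16, 5, 30), torch.tensor([16, 16, 12, 9, 3]))
traced = torch.jit.trace(bakes_in_a_constant, example)
print(traced.graph)  # the 16 from the example input is recorded as a constant
out = traced(torch.randn(8, 5, 30), torch.tensor([8, 8, 5]))
print(out.shape)     # still (1, 16, 25): the traced constant wins over the new input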
Expected behavior
The traced function should work for inputs with any batch size, not just the batch size of the example used for tracing.
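Concretely, once traced with the batch-16 example, both of the calls below should succeed and preserve each input's batch size (reusing the tensors from the repro above):
out_b = func(b, b_lengths)  # batch size 16, same as the trace example
out_a = func(a, a_lengths)  # batch size 8 - currently fails as shown above
assert out_b.shape[0] == 16 and out_a.shape[0] == 8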
Environment
Collecting environment information...
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 18.04.1 LTS
GCC version: (Homebrew gcc 5.5.0_4) 5.5.0
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: GeForce GTX 1050
Nvidia driver version: 396.54
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.1
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn.so.7.1.1
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn_static.a
Versions of relevant libraries:
[pip] Could not collect
[conda] cuda92 1.0 0 pytorch
[conda] pytorch 1.0.0 py3.7_cuda9.0.176_cudnn7.4.1_1 pytorch
[conda] torch 1.0.0 <pip>
[conda] torchfile 0.1.0 py_0 conda-forge
[conda] torchfile 0.1.0 <pip>
[conda] torchnet 0.0.4 <pip>
[conda] torchtext 0.3.1 <pip>
[conda] torchvision 0.2.1 py37_1000 conda-forge