
torch.jit.trace hardcodes batch size with packed input to LSTM #15319

Closed
anna-hope opened this issue Dec 17, 2018 · 12 comments
Labels: oncall: jit

🐛 Bug

Using torch.jit.trace on a function that feeds inputs to an LSTM produces a graph with a hard-coded batch size. This makes it impossible to use the traced function with inputs of a different batch size.

To Reproduce

Steps to reproduce the behavior:

import torch
import torch.nn as nn

a = torch.randn((8, 5, 30))
a_lengths = torch.randint(low=1, high=a.shape[1], size=(len(a),))
a_lengths, _ = torch.sort(a_lengths, descending=True)

b = torch.randn((16, 5, 30))
b_lengths = torch.randint(low=1, high=b.shape[1], size=(len(b),))
b_lengths, _ = torch.sort(b_lengths, descending=True)

lstm = nn.LSTM(30, 25, batch_first=True)

def feed_rnn(X: torch.Tensor, sorted_lengths: torch.Tensor) -> torch.Tensor:
    X = nn.utils.rnn.pack_padded_sequence(X, sorted_lengths, batch_first=True)
    X, hidden_states = lstm(X)
    # pad_packed_sequence returns a tuple of sequences and lengths
    X, sorted_lengths = nn.utils.rnn.pad_packed_sequence(X, batch_first=True)
    return X

# Trace with the batch-of-16 input, then check the trace against the batch of 8
func = torch.jit.trace(feed_rnn, (b, b_lengths), check_inputs=[(b, b_lengths), (a, a_lengths)])

This results in:

RuntimeError: 
The size of tensor a (8) must match the size of tensor b (16) at non-singleton dimension 0

...

ERROR: Graphs differed across invocations!
	Graph diff:
		  graph(%input.1 : Tensor
		        %lengths.1 : Tensor) {
		    %2 : Device = prim::Constant[value="cpu"]()
		    %3 : int = prim::Constant[value=4]()
		    %4 : bool = prim::Constant[value=0]()
		    %5 : bool = prim::Constant[value=0]()
		    %lengths : Tensor = aten::to(%lengths.1, %2, %3, %4, %5)
		    %7 : bool = prim::Constant[value=1]()
		    %input : Tensor, %batch_sizes : Tensor = aten::_pack_padded_sequence(%input.1, %lengths, %7)
		    %10 : int = prim::Constant[value=1](), scope: LSTM
		-   %11 : int = prim::Constant[value=16](), scope: LSTM
		?                                    ^^
		+   %11 : int = prim::Constant[value=8](), scope: LSTM
		?                                    ^

Expected behavior

The traced function should work for inputs of both batch sizes.
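
Concretely (an illustrative check built from the repro above, not part of the original report), a shape-generic trace would accept both batches:

# With a shape-generic trace, both calls succeed; with the current trace
# the batch size is baked in, so the batch-of-8 input hits the size
# mismatch shown above.
out_b = func(b, b_lengths)  # batch size 16, the size seen at trace time
out_a = func(a, a_lengths)  # batch size 8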

Environment

Collecting environment information...
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Homebrew gcc 5.5.0_4) 5.5.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: GeForce GTX 1050
Nvidia driver version: 396.54
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.1
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn.so.7.1.1
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] Could not collect
[conda] cuda92                    1.0                           0    pytorch
[conda] pytorch                   1.0.0           py3.7_cuda9.0.176_cudnn7.4.1_1    pytorch
[conda] torch                     1.0.0                     <pip>
[conda] torchfile                 0.1.0                      py_0    conda-forge
[conda] torchfile                 0.1.0                     <pip>
[conda] torchnet                  0.0.4                     <pip>
[conda] torchtext                 0.3.1                     <pip>
[conda] torchvision               0.2.1                 py37_1000    conda-forge
@facebook-github-bot added the oncall: jit label on Dec 17, 2018
@suo self-assigned this on Dec 17, 2018
@suo (Member) commented Dec 18, 2018

Unfortunately, tracing only works for fixed batch sizes here, due to some quirks in how PackedSequence is used in RNNs.

When #14831 lands, using the TorchScript implementation of LSTM should solve your problem, but there is no workaround today.
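
For context, a sketch of where the baked constant likely comes from: nn.LSTM sizes its default hidden state from the packed input, and converting a tensor element to a Python int during tracing records it as a constant. Paraphrasing the pattern in torch/nn/modules/rnn.py (variable names approximate, not the exact source):

# int() on a tensor at trace time bakes the value into the graph as
# prim::Constant -- hence the hardcoded 16 vs. 8 in the graph diff above.
max_batch_size = int(batch_sizes[0])
hx = input.new_zeros(num_layers * num_directions, max_batch_size, hidden_size)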

@suo (Member) commented Dec 18, 2018

Assigning to @driazati to track scripting LSTM.

@suo removed their assignment on Dec 18, 2018
@anna-hope (Author) commented Dec 18, 2018

@suo I see. This limitation complicates the production path for RNN-based models exported with the JIT. Together with #15272, these issues make it very difficult, if not impossible, to use the JIT with RNNs.

I will wait for #14831.

@jamesr66a (Collaborator) commented
Fixing #16663 should resolve this issue.

@suo (Member) commented Feb 28, 2019

This should be fixed in both the traced version (we now properly trace the constructor, #16779) and the scripted version (LSTM is now supported in TorchScript, #15744). Closing.

@suo closed this as completed on Feb 28, 2019
@shane-carroll commented
Should this be fixed in the nightly, then? I'm getting the same issue with 1.0.0.dev20190301.

@suo (Member) commented Mar 1, 2019

Yep! The changes are in the nightly.

@shane-carroll commented
Hmm... am I doing something wrong? I run the OP's example and get the same error (although I wrap it in an nn.Module to avoid an error like #17583). This is the environment:

PyTorch version: 1.0.0.dev20190301
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: CentOS Linux release 7.4.1708 (Core)
GCC version: (GCC) 6.3.1 20170216 (Red Hat 6.3.1-3)
CMake version: version 2.8.12.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: Graphics Device
GPU 1: Graphics Device
GPU 2: Graphics Device
GPU 3: Graphics Device

Nvidia driver version: 387.34
cuDNN version: /usr/local/cuda-9.0/lib64/libcudnn.so.7.1.1

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] torch-nightly==1.0.0.dev20190301
[pip] torchtext==0.4.0
[pip] torchvision==0.2.1
[conda] torch-nightly             1.0.0.dev20190301           <pip>
[conda] torchtext                 0.4.0                     <pip>
[conda] torchvision               0.2.1                     <pip>

@suo (Member) commented Mar 1, 2019

Oh hm, it seems like tracing still fixes the batch size. cc @driazati.

Can you try using @torch.jit.script as I originally suggested? That should work.

@shane-carroll commented
Thanks. It seems the next problem is that @torch.jit.script does not support pack_padded_sequence (#16664).

@driazati (Contributor) commented Mar 1, 2019

Re-opening, since the example above is still tripping a bug somewhere; it gives me:

RuntimeError: !ref.requires_grad() ASSERT FAILED at ../torch/csrc/jit/constants.cpp:29, please report a bug to PyTorch.

It looks like we also have some other problems related to tracing optional Tensors (#14455). The JIT currently supports PackedSequence by reducing it to a regular tuple, so something like this works:

import torch
import torch.nn as nn

class FeedRNN(torch.jit.ScriptModule):
    def __init__(self):
        super(FeedRNN, self).__init__()
        self.lstm = nn.LSTM(30, 25, batch_first=True)

    @torch.jit.script_method
    def forward(self, X):
        # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]])
        # The JIT sees the PackedSequence as a plain tuple of its fields
        return self.lstm(X)


feed_rnn = FeedRNN()

# b and b_lengths are the batch-of-16 tensors from the repro above
X = nn.utils.rnn.pack_padded_sequence(b, b_lengths, batch_first=True)
X, hidden_states = feed_rnn(X)
print(X, hidden_states)
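
This sidesteps the baked-in constant because scripting compiles the method rather than recording a single execution, so the batch size is read from the input at run time.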

@driazati reopened this on Mar 1, 2019
@suo (Member) commented Mar 1, 2019

@driazati I think this is related: #17583.
