34 changes: 34 additions & 0 deletions test/test_jit.py
@@ -12897,6 +12897,40 @@ def forward(self, input):

self.assertEqual(eager_out, script_out)

def test_nn_GRU(self):
from torch.nn.utils.rnn import PackedSequence
seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
tensor_input = torch.randn(5, 5, 5)

class SeqLengthGRU(torch.jit.ScriptModule):
def __init__(self):
super(SeqLengthGRU, self).__init__()
self.x = torch.nn.GRU(5, 5)

@torch.jit.script_method
def forward(self, input):
# type: (PackedSequence) -> Tuple[PackedSequence, Tensor]
return self.x(input)

class TensorGRU(torch.jit.ScriptModule):
def __init__(self):
super(TensorGRU, self).__init__()
self.x = torch.nn.GRU(5, 5)

@torch.jit.script_method
def forward(self, input):
# type: (Tensor) -> Tuple[Tensor, Tensor]
return self.x(input)

seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0]
seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0]
tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0]
tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0]

self.assertEqual(seq_eager_out, seq_script_out)
self.assertEqual(tensor_eager_out, tensor_script_out)


def test_torchscript_multi_head_attn(self):
@torch.jit.script
def jit_multihead_attn_forward(query, # type: Tensor
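For readers who want to reproduce the parity check outside the test suite, here is a minimal sketch that mirrors test_nn_GRU but replaces the suite's runAndSaveRNG helper with explicit manual seeding, so the eager and scripted GRU instances are built with identical random weights. The PackedGRU class name is illustrative and not part of the PR; the sketch assumes a PyTorch build that includes this change.

import torch
from typing import Tuple
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence, pack_sequence

class PackedGRU(torch.jit.ScriptModule):
    def __init__(self):
        super(PackedGRU, self).__init__()
        self.gru = torch.nn.GRU(5, 5)

    @torch.jit.script_method
    def forward(self, input):
        # type: (PackedSequence) -> Tuple[PackedSequence, Tensor]
        return self.gru(input)

seq_input = pack_sequence([torch.randn(5, 5)])

# Seed before each construction so both modules get the same random weights.
torch.manual_seed(0)
eager_out, eager_h = torch.nn.GRU(5, 5)(seq_input)
torch.manual_seed(0)
script_out, script_h = PackedGRU()(seq_input)

assert torch.allclose(eager_out.data, script_out.data)
assert torch.allclose(eager_h, script_h)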
69 changes: 68 additions & 1 deletion torch/nn/modules/rnn.py
@@ -11,7 +11,6 @@
from ..._jit_internal import _parameter_list

_rnn_impls = {
'GRU': _VF.gru,
'RNN_TANH': _VF.rnn_tanh,
'RNN_RELU': _VF.rnn_relu,
}
@@ -167,12 +166,14 @@ def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size
raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

def check_forward_args(self, input, hidden, batch_sizes):
# type: (Tensor, Tensor, Optional[Tensor]) -> None
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

self.check_hidden_size(hidden, expected_hidden_size)

def permute_hidden(self, hx, permutation):
# type: (Tensor, Optional[Tensor]) -> Tensor
if permutation is None:
return hx
return apply_permutation(hx, permutation)
@@ -369,6 +370,16 @@ def __init__(self, *args, **kwargs):
super(RNN, self).__init__(mode, *args, **kwargs)


# XXX: The LSTM and GRU implementations differ from RNNBase because:
# 1. we want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
#    its current state cannot support the Python Union or Any type.
# 2. TorchScript static typing does not allow a Function or Callable type as
#    a Dict value, so we have to call _VF separately instead of using _rnn_impls.
# 3. this is only a temporary, transitional state so that we can make it in
#    time for the release.
#
# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
# supports expressing these two modules generally.
class LSTM(RNNBase):
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
sequence.
@@ -655,10 +666,66 @@ class GRU(RNNBase):
>>> h0 = torch.randn(2, 3, 20)
>>> output, hn = rnn(input, h0)
"""
__overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

def __init__(self, *args, **kwargs):
super(GRU, self).__init__('GRU', *args, **kwargs)

def run_impl(self, input, hx, batch_sizes):
# type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
if batch_sizes is None:
result = _VF.gru(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
self.dropout, self.training, self.bidirectional, self.batch_first)
else:
result = _VF.gru(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
self.num_layers, self.dropout, self.training, self.bidirectional)
return result

def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
# type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa
if hx is None:
[Inline review comment on this hunk]
zou3519 (Contributor), Jul 24, 2019:
It's not so nice that all of this extra code needs to be added in, but I'm sure there are reasons.

PR author (Collaborator):
Yeah, that's kind of unfortunate. I originally wanted to script RNNBase directly instead of GRU separately, but the dict used to dispatch to the different modules is a dict(str, function), and JIT dictionaries do not currently support a function or module as a value, so we could only do it separately.

num_directions = 2 if self.bidirectional else 1
hx = torch.zeros(self.num_layers * num_directions,
max_batch_size, self.hidden_size,
dtype=input.dtype, device=input.device)
else:
# Each batch of the hidden state should match the input sequence that
# the user believes they are passing in.
hx = self.permute_hidden(hx, sorted_indices)

self.check_forward_args(input, hx, batch_sizes)
result = self.run_impl(input, hx, batch_sizes)
output = result[0]
hidden = result[1]
return output, hidden

@torch._jit_internal.export
def forward_packed(self, input, hx=None):
# type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)

@torch._jit_internal.export
def forward_tensor(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)

@torch._jit_internal.ignore
def forward(self, input, hx=None):
if isinstance(input, PackedSequence):
return self.forward_packed(input, hx)
else:
return self.forward_tensor(input, hx)


class RNNCellBase(Module):
__constants__ = ['input_size', 'hidden_size', 'bias']
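A note on the dispatch pattern in the hunk above (my reading of the diff, not text from the PR): the __overloads__ table tells the script compiler to compile forward_packed and forward_tensor as the typed overloads of forward, while the forward marked with @torch._jit_internal.ignore keeps the usual eager behaviour by dispatching on isinstance. In eager mode the two exported methods are also ordinary Python methods, so the sketch below, assuming a build containing this change, exercises both entry points directly; sizes are arbitrary.

import torch
from torch.nn.utils.rnn import pack_sequence

gru = torch.nn.GRU(input_size=5, hidden_size=5)

# Padded input: (seq_len, batch, input_size) with the default batch_first=False.
tensor_in = torch.randn(7, 3, 5)
out_t, h_t = gru.forward_tensor(tensor_in)      # same result as gru(tensor_in)
print(out_t.shape, h_t.shape)                   # (7, 3, 5) and (1, 3, 5)

# Packed input: lengths 7 and 4, already sorted, so the default enforce_sorted is fine.
packed_in = pack_sequence([torch.randn(7, 5), torch.randn(4, 5)])
out_p, h_p = gru.forward_packed(packed_in)      # same result as gru(packed_in)
print(type(out_p).__name__, h_p.shape)          # PackedSequence and (1, 2, 5)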
1 change: 1 addition & 0 deletions torch/nn/utils/rnn.py
@@ -387,6 +387,7 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):


def pack_sequence(sequences, enforce_sorted=True):
# type: (List[Tensor], bool) -> PackedSequence
r"""Packs a list of variable length Tensors

``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
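As a quick illustration of the annotated signature (List[Tensor], bool) -> PackedSequence, here is a small usage sketch with arbitrarily chosen values: with enforce_sorted=False the sequences may be passed in any order, and pad_packed_sequence restores the original ordering when unpacking.

import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

a = torch.tensor([1., 2., 3.])   # length 3
b = torch.tensor([4., 5.])       # length 2
c = torch.tensor([6.])           # length 1

# Unsorted input is allowed when enforce_sorted=False; pack_sequence sorts
# internally and records the permutation in the PackedSequence.
packed = pack_sequence([c, a, b], enforce_sorted=False)

padded, lengths = pad_packed_sequence(packed, batch_first=True)
print(padded)    # rows zero-padded to the longest length, in the original order
print(lengths)   # tensor([1, 3, 2])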