[jit] Support nn.GRU in script #23266
Changes from all commits: eb75dbb, a5300ba, ba7a908, cc77585, 4278731, 59795b6, d0d1c69, 6c679e5, d46146c, e0aab4b, 74be1ee, 14fd07d, 7eea76b, 46fe4dc
@@ -11,7 +11,6 @@

```python
from ..._jit_internal import _parameter_list


_rnn_impls = {
    'GRU': _VF.gru,
    'RNN_TANH': _VF.rnn_tanh,
    'RNN_RELU': _VF.rnn_relu,
}
```
@@ -167,12 +166,14 @@ def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size

```python
            raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tensor, Optional[Tensor]) -> None
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size)

    def permute_hidden(self, hx, permutation):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)
```
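For context, the check above enforces that a user-supplied hidden state has the shape (num_layers * num_directions, batch, hidden_size). A minimal sketch, not part of this diff, with arbitrarily chosen sizes:

```python
import torch
import torch.nn as nn

# Hypothetical configuration, chosen only for illustration.
rnn = nn.GRU(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)

x = torch.randn(5, 3, 10)        # (seq_len, batch, input_size)
h0 = torch.randn(2 * 2, 3, 20)   # (num_layers * num_directions, batch, hidden_size)

out, hn = rnn(x, h0)             # shapes satisfy check_forward_args
# A wrongly shaped hidden state, e.g. torch.randn(2, 3, 20), would instead hit
# the RuntimeError raised by check_hidden_size above.
```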
@@ -369,6 +370,16 @@ def __init__(self, *args, **kwargs):

```python
        super(RNN, self).__init__(mode, *args, **kwargs)


# XXX: LSTM and GRU implementation is different from RNNBase, this is because:
# 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in
#    its current state could not support the python Union Type or Any Type.
# 2. TorchScript static typing does not allow a Function or Callable type in
#    Dict values, so we have to separately call _VF instead of using _rnn_impls
# 3. This is temporary only and in the transition state that we want to make it
#    on time for the release
#
# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
# support expressing these two modules generally.
class LSTM(RNNBase):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
    sequence.
```
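As a hedged toy illustration (not from the PR) of the second point in the comment above: TorchScript cannot type a dict whose values are functions, so compiled code has to call each backend explicitly rather than index into a table like _rnn_impls. The mode name and the tanh/relu stand-ins below are purely illustrative:

```python
import torch
from torch import Tensor

# Eager Python can dispatch through a dict of functions, but TorchScript has no
# Dict[str, Callable] type, so a scripted function spells out each branch. The
# separate GRU/LSTM overrides in this PR apply the same workaround by calling
# _VF directly instead of going through _rnn_impls.
@torch.jit.script
def apply_nonlinearity(mode: str, x: Tensor) -> Tensor:
    if mode == 'RNN_TANH':
        return torch.tanh(x)
    else:
        return torch.relu(x)

print(apply_nonlinearity('RNN_TANH', torch.randn(3)))
```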
@@ -655,10 +666,66 @@ class GRU(RNNBase):

```python
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """
    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    def __init__(self, *args, **kwargs):
        super(GRU, self).__init__('GRU', *args, **kwargs)

    def run_impl(self, input, hx, batch_sizes):
        # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        if batch_sizes is None:
            result = _VF.gru(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
                             self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.gru(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
                             self.num_layers, self.dropout, self.training, self.bidirectional)
        return result

    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
        # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        result = self.run_impl(input, hx, batch_sizes)
        output = result[0]
        hidden = result[1]
        return output, hidden

    @torch._jit_internal.export
    def forward_packed(self, input, hx=None):
        # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)
        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch._jit_internal.export
    def forward_tensor(self, input, hx=None):
        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None
        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch._jit_internal.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)


class RNNCellBase(Module):
    __constants__ = ['input_size', 'hidden_size', 'bias']
```

Review thread (anchored on the `hx is None` branch in `forward_impl`):

Contributor: It's not so nice that all of this extra code needs to be added in, but I'm sure there are reasons.

wanchaol (author): Yeah, that's kind of unfortunate. I originally wanted to script RNNBase directly instead of GRU separately, but the dict used to dispatch to the different implementations is a dict(str, function), and the JIT dictionary does not support a function or module as a value right now, so we could only do it separately.
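With the forward_tensor / forward_packed split and the __overloads__ entry above, a scripted nn.GRU should accept either a plain Tensor or a PackedSequence when called from Python. A rough usage sketch, assuming post-merge behavior; all sizes and lengths here are made up:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

gru = torch.jit.script(nn.GRU(input_size=4, hidden_size=8))

# Plain tensor input: routed to forward_tensor.
x = torch.randn(6, 2, 4)                      # (seq_len, batch, input_size)
out, hn = gru(x)

# PackedSequence input: routed to forward_packed.
lengths = torch.tensor([6, 3])                # sorted, as enforce_sorted expects
packed = pack_padded_sequence(torch.randn(6, 2, 4), lengths)
packed_out, hn_packed = gru(packed)
padded_out, out_lengths = pad_packed_sequence(packed_out)
```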