Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermeleobas committed Nov 16, 2020
1 parent d9296e5 commit 1d07a2a
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torch/nn/modules/rnn.py
Expand Up @@ -531,6 +531,8 @@ class LSTM(RNNBase):
def __init__(self, *args, **kwargs):
super(LSTM, self).__init__('LSTM', *args, **kwargs)

# In the future, we should prevent mypy from applying contravariance rules here.
# See torch/nn/modules/module.py::_forward_unimplemented
def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]): # type: ignore
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
Expand All @@ -540,17 +542,20 @@ def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch
self.check_hidden_size(hidden[1], expected_hidden_size,
'Expected hidden[1] size {}, got {}')

# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]: # type: ignore
if permutation is None:
return hx
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
@overload # type: ignore
@torch._jit_internal._overload_method # noqa: F811
def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811
pass

# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
Expand Down

0 comments on commit 1d07a2a

Please sign in to comment.