Skip to content

Commit

Permalink
fix CI
Browse files Browse the repository at this point in the history
Former-commit-id: e5ef1a4075a4bae41984e77c6aad49ddaa1ff225 [formerly f70c669a530aee77d1b961e9528e9c621bb7b5c0]
Former-commit-id: de5a9ecd936088fce8c380dbeb1c46051552d448
  • Loading branch information
zuoxingdong committed May 23, 2019
1 parent e6ea323 commit 1072780
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions lagom/networks/ln_rnn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.jit as jit
Expand All @@ -19,7 +21,7 @@ def __init__(self, input_size, hidden_size):

@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
# (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = state
igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
Expand All @@ -44,7 +46,7 @@ def __init__(self, cell, *cell_args):

@jit.script_method
def forward(self, input, state):
# type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
# (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
inputs = input.unbind(0)
outputs = []
for i in range(len(inputs)):
Expand All @@ -64,9 +66,9 @@ def __init__(self, num_layers, layer, first_layer_args, other_layer_args):

@jit.script_method
def forward(self, input, states):
# type: (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# (Tensor, List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]
# List[LSTMState]: One state per layer
output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
output_states = jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], [])
output = input
# XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
i = 0
Expand Down

0 comments on commit 1072780

Please sign in to comment.