diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py index f4ae42b664..9bf64bad9e 100644 --- a/time_sequence_prediction/train.py +++ b/time_sequence_prediction/train.py @@ -23,12 +23,12 @@ def forward(self, input, future = 0): for i, input_t in enumerate(input.chunk(input.size(1), dim=1)): h_t, c_t = self.lstm1(input_t, (h_t, c_t)) - h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2)) - outputs += [c_t2] + h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2)) + outputs += [h_t2] for i in range(future):# if we should predict the future - h_t, c_t = self.lstm1(c_t2, (h_t, c_t)) - h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2)) - outputs += [c_t2] + h_t, c_t = self.lstm1(h_t2, (h_t, c_t)) + h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2)) + outputs += [h_t2] outputs = torch.stack(outputs, 1).squeeze(2) return outputs