diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py index 4829a93ec4..06aa1bd9a3 100644 --- a/time_sequence_prediction/train.py +++ b/time_sequence_prediction/train.py @@ -42,6 +42,8 @@ def forward(self, input, future = 0): data = torch.load('traindata.pt') input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False) target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False) + test_input = Variable(torch.from_numpy(data[:3, :-1]), requires_grad=False) + test_target = Variable(torch.from_numpy(data[:3, 1:]), requires_grad=False) # build the model seq = Sequence() seq.double() @@ -61,7 +63,9 @@ def closure(): optimizer.step(closure) # begin to predict future = 1000 - pred = seq(input[:3], future = future) + pred = seq(test_input, future = future) + loss = criterion(pred[:, :-future], test_target) + print('test loss:', loss.data.numpy()[0]) y = pred.data.numpy() # draw the result plt.figure(figsize=(30,10))