In [1]:
import math
import numpy as np
import torch
T = 20
L = 1000
N = 100
np.random.seed(2)
x = np.empty((N, L), 'int64')
x[:] = np.array(range(L)) + np.random.randint(-4*T, 4*T, N).reshape(N, 1)
data = np.sin(x / 1.0 / T).astype('float64')
torch.save(data, open('traindata.pt', 'wb'))



In [2]:
from __future__ import print_function
import torch
import torch.nn as nn 
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

class Sequence(nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 1)

    def forward(self, input, future = 0):
        outputs = []
        h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
        c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
        h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
        c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            outputs += [c_t2]
        for i in range(future):# if we should predict the future
            h_t, c_t = self.lstm1(c_t2, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
            outputs += [c_t2]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs



if __name__ == '__main__':
    # set ramdom seed to 0
    np.random.seed(0)
    torch.manual_seed(0)
    # load data and make training set
    data = torch.load('traindata.pt')
    input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False)
    target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False)
    # build the model
    seq = Sequence()
    seq.double()
    criterion = nn.MSELoss()
    # use LBFGS as optimizer since we can load the whole data to train
    optimizer = optim.LBFGS(seq.parameters())
    #begin to train
    for i in range(15):
        print('STEP: ', i)
        def closure():
            optimizer.zero_grad()
            out = seq(input)
            loss = criterion(out, target)
            print('loss:', loss.data.numpy()[0])
            loss.backward()
            return loss
        optimizer.step(closure)
        # begin to predict
        future = 1000
        pred = seq(input[:3], future = future)
        y = pred.data.numpy()
        # draw the result
        plt.figure(figsize=(30,10))
        plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30) 
        plt.xlabel('x', fontsize=20)
        plt.ylabel('y', fontsize=20)
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)
        def draw(yi, color):
            plt.plot(np.arange(input.size(1)), yi[:input.size(1)], color, linewidth = 2.0)
            plt.plot(np.arange(input.size(1), input.size(1) + future), yi[input.size(1):], color + ':', linewidth = 2.0)
        draw(y[0], 'r')
        draw(y[1], 'g')
        draw(y[2], 'b')
        plt.savefig('predict%d.pdf'%i)
        plt.close()



STEP:  0


loss: 1.09054476065


loss: 0.826498360191


loss: 0.742258943851


loss: 0.46692670559


loss: 0.331289575869


loss: 0.236026240101


loss: 0.141259279307


loss: 0.0851531255285


loss: 0.0541043081422


loss: 0.040030847713


loss: 0.0348454304112


loss: 0.0298759609175


loss: 0.0270200336839


loss: 0.0249405874131


loss: 0.0236001364739


loss: 0.0191041066229


loss: 0.0133801897465


loss: 0.011366287259


loss: 0.00983567643843


loss: 0.00893716252079


STEP:  1


loss: 0.00857080080665


loss: 0.00803037675458


loss: 0.00710759490784


loss: 0.00924047828183


loss: 0.00576260438147


loss: 0.00500472024838


loss: 0.0213943098216


loss: 0.00430650519541


loss: 0.00388057702213


loss: 0.00434915494094


loss: 0.00298646983474


loss: 0.00266532448006


loss: 0.00259056949056


loss: 0.00245869588709


loss: 0.00244078378996


loss: 0.00239072969226


loss: 0.00233532586374


loss: 0.00224324638565


loss: 0.00193656497951


loss: 0.0203101774647


STEP:  2


loss: 0.00189117719409


loss: 0.00187708194707


loss: 0.00184734283652


loss: 0.00182429561972


loss: 0.00178668111921


loss: 0.00175956313739


loss: 0.00173671048036


loss: 0.00172798871144


loss: 0.00170404503125


loss: 0.00164831568901


loss: 0.00153444724131


loss: 0.00146448434996


loss: 0.00290571262104


loss: 0.00141937563372


loss: 0.00140850693287


loss: 0.00138915262341


loss: 0.00138241696371


loss: 0.00138046501639


loss: 0.00137327773756


loss: 0.00136058322738


STEP:  3


loss: 0.00133259520057


loss: 0.00126648678394


loss: 0.00142157183007


loss: 0.00132635507086


loss: 0.0013678727881


loss: 0.00114130025451


loss: 0.00110063038452


loss: 0.00103065551277


loss: 0.00102210559892


loss: 0.000941988650317


loss: 0.000857952733186


loss: 0.254636914979


loss: 10.8591828103


loss: 69.647614226


loss: 66.3576939207


loss: 49.7292864228


loss: 46.6363800139


loss: 37.2485869865


loss: 32.0848215053


loss: 30.5150727925


STEP:  4


loss: 4990.62313119


loss: 13774.1064247


loss: 275348.422344


loss: 308934.682728


loss: 322473.338302


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


STEP:  5


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


STEP:  6


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


STEP:  7


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


STEP:  8


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


STEP:  9


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.613281


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


STEP:  10


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


STEP:  11


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


STEP:  12


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


STEP:  13


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


STEP:  14


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328


loss: 324745.61328
