In [None]:
import torch
from torch import nn

class My_LSTM_cell(torch.nn.Module):
    """
    Hand-written single-step LSTM cell, built for educational (AI Summer)
    purposes.

    Each gate combines an affine map of the current input ``x`` (with bias)
    and an affine map of the previous hidden state ``h`` (without bias):

        i      = sigmoid(W2 x + R2 h)                  # input gate
        f      = sigmoid(W1 x + R1 h)                  # forget gate
        c_next = f * c_prev + i * tanh(W3 x + R3 h)    # cell memory
        o      = sigmoid(W4 x + R4 h)                  # output gate
        h_next = o * tanh(c_next)

    Args:
        input_length: size of the last dimension of ``x``.
        hidden_length: size of the last dimension of ``h`` and ``c``.
    """

    def __init__(self, input_length=10, hidden_length=20):
        super().__init__()
        self.input_length = input_length
        self.hidden_length = hidden_length
        n_in, n_hidden = input_length, hidden_length

        # Layers are registered in a fixed order (forget, input, memory,
        # output). nn.Linear draws its initial weights from the torch RNG when
        # constructed, so this order determines seeded initialisation.

        # forget gate: W1 x + b1 and R1 h
        self.linear_forgot_w1 = nn.Linear(n_in, n_hidden, bias=True)
        self.linear_forgot_r1 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.sigmoid_forgot = nn.Sigmoid()

        # input gate: W2 x + b2 and R2 h
        self.linear_input_w2 = nn.Linear(n_in, n_hidden, bias=True)
        self.linear_input_r2 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.sigmoid_input = nn.Sigmoid()

        # candidate cell memory: W3 x + b3 and R3 h
        self.linear_memory_w3 = nn.Linear(n_in, n_hidden, bias=True)
        self.linear_memory_r3 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.activation_memory = nn.Tanh()

        # output gate: W4 x + b4 and R4 h
        self.linear_out_w4 = nn.Linear(n_in, n_hidden, bias=True)
        self.linear_out_r4 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.sigmoid_out = nn.Sigmoid()

        # squashes the new cell state before it is gated into h_next
        self.activation_final = nn.Tanh()

    def forget(self, x, h):
        """Forget gate: f = sigmoid(W1 x + b1 + R1 h)."""
        return self.sigmoid_forgot(self.linear_forgot_w1(x) + self.linear_forgot_r1(h))

    def input_gate(self, x, h):
        """Input gate: i = sigmoid(W2 x + b2 + R2 h)."""
        return self.sigmoid_input(self.linear_input_w2(x) + self.linear_input_r2(h))

    def cell_memory_gate(self, i, f, x, h, c_prev):
        """New cell state: c_next = f * c_prev + i * tanh(W3 x + b3 + R3 h)."""
        kept = f * c_prev
        candidate = self.activation_memory(
            self.linear_memory_w3(x) + self.linear_memory_r3(h)
        )
        return kept + i * candidate

    def out_gate(self, x, h):
        """Output gate: o = sigmoid(W4 x + b4 + R4 h)."""
        return self.sigmoid_out(self.linear_out_w4(x) + self.linear_out_r4(h))

    def forward(self, x, tuple_in):
        """
        Advance the cell by one time step.

        Args:
            x: current input; its last dimension must equal ``input_length``.
            tuple_in: ``(h, c_prev)``, the previous hidden and cell state, each
                with last dimension ``hidden_length``.

        Returns:
            ``(h_next, c_next)``.
        """
        h, c_prev = tuple_in
        in_g = self.input_gate(x, h)
        forget_g = self.forget(x, h)
        c_next = self.cell_memory_gate(in_g, forget_g, x, h, c_prev)
        h_next = self.out_gate(x, h) * self.activation_final(c_next)
        return h_next, c_next

# Train LSTM (Sine Wave)

In [1]:
import random
import numpy as np
import torch
from torch import nn
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import torch.optim as optim

# Seed every RNG the notebook draws from so the results can be replicated.
# generate_sin_wave_data() draws from NumPy's global RNG, so NumPy must be
# seeded too; seeding only `random` and `torch` leaves the data random.
seed = 172
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

def generate_sin_wave_data(T=20, L=1000, N=200):
    """
    Build a batch of phase-shifted sine waves.

    Row ``r`` is ``sin((k + offset_r) / T)`` for ``k = 0 .. L-1``. Each
    integer ``offset_r`` is drawn uniformly from ``[-4*T, 4*T)`` using NumPy's
    global RNG, so seed ``np.random`` beforehand for reproducible data.

    Args:
        T: period scale; one full cycle spans ``2*pi*T`` samples. Should be
            an int, because it bounds ``np.random.randint``.
        L: number of samples per wave (columns).
        N: number of waves (rows).

    Returns:
        float64 ndarray of shape ``(N, L)``.
    """
    # one random integer phase offset per row, as a column for broadcasting
    offsets = np.random.randint(-4 * T, 4 * T, N).reshape(N, 1)
    x = np.arange(L, dtype='int64') + offsets
    return np.sin(x / 1.0 / T).astype('float64')


As mentioned above, the model stacks two LSTM cells. The first cell receives an input of length 1 and produces a hidden state of length 51. The second cell receives that length-51 vector and also produces a hidden state of length 51. A final linear layer maps the second cell's output to a single predicted value.

Have a closer look at the `forward` method. Did you notice that it can generate future predictions? The first for-loop runs over every data point in the input. The second for-loop then keeps stepping the model past the end of the input, producing a new prediction for each future time step.

For each training step, we optimize the model and then generate 1000 new (future) data points.

After each epoch, we plot the predicted data points to visualize our results.

In [2]:
class Sequence(nn.Module):
    """
    Two stacked LSTM cells followed by a linear read-out, used for
    one-step-ahead prediction of a scalar time series.

    Layers: ``rnn1`` (1 -> hidden_size), ``rnn2`` (hidden_size ->
    hidden_size), ``linear`` (hidden_size -> 1).

    Args:
        hidden_size: width of both LSTM cells (default 51, the original value).
    """

    def __init__(self, hidden_size=51):
        super(Sequence, self).__init__()

        # Creation order (rnn1, rnn2, linear) is kept so that seeded parameter
        # initialisation is unchanged.
        self.rnn1 = nn.LSTMCell(1, hidden_size)
        self.rnn2 = nn.LSTMCell(hidden_size, hidden_size)

        self.linear = nn.Linear(hidden_size, 1)

    def _step(self, x_t, state):
        """Advance both cells by one step; returns ``(prediction, new_state)``."""
        h_t, c_t, h_t2, c_t2 = state
        h_t, c_t = self.rnn1(x_t, (h_t, c_t))
        h_t2, c_t2 = self.rnn2(h_t, (h_t2, c_t2))
        return self.linear(h_t2), (h_t, c_t, h_t2, c_t2)

    def forward(self, input, future=0):
        """
        Run the model over ``input`` and optionally extrapolate.

        Args:
            input: 2-D tensor ``(batch, steps)`` with one scalar per time step.
                Its dtype and device must match the model's parameters.
            future: number of extra steps to predict after the end of
                ``input``. Each extra step is fed the model's own previous
                prediction (autoregressive roll-out).

        Returns:
            Tensor of shape ``(batch, steps + future)``.

        Raises:
            ValueError: if ``input`` has no time steps.
        """
        if input.size(1) == 0:
            raise ValueError("Sequence.forward needs at least one input time step")

        batch = input.size(0)
        hidden = self.rnn1.hidden_size
        # Zero initial states that follow the input's dtype/device
        # (previously hard-coded to torch.double).
        state = tuple(input.new_zeros(batch, hidden) for _ in range(4))

        outputs = []
        # Observed part: feed the ground-truth value at each step.
        for input_t in input.split(1, dim=1):
            output, state = self._step(input_t, state)
            outputs.append(output)

        # Future part: feed back the previous prediction, not the stale
        # last observed input, so the roll-out is genuinely autoregressive.
        for _ in range(future):
            output, state = self._step(output, state)
            outputs.append(output)

        # each output is (batch, 1) -> stack to (batch, T, 1) -> (batch, T)
        return torch.stack(outputs, 1).squeeze(2)


def _plot_predictions(y, n_observed, future):
    """Plot the first three rows of ``y``: solid over the observed span, dotted over the extrapolation."""
    fig, ax = plt.subplots(figsize=(30, 10))
    ax.set_title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
    ax.set_xlabel('x', fontsize=20)
    ax.set_ylabel('y', fontsize=20)
    ax.tick_params(axis='both', labelsize=20)

    for yi, color in zip(y, ('r', 'g', 'b')):
        ax.plot(np.arange(n_observed), yi[:n_observed], color, linewidth=2.0)
        ax.plot(np.arange(n_observed, n_observed + future), yi[n_observed:], color + ':', linewidth=2.0)
    plt.show()


def train(steps=1, future=1000):
    """
    Fit a ``Sequence`` model to synthetic sine waves with L-BFGS. After every
    optimisation step, print the test loss and plot the model's predictions.

    The first 3 waves from ``generate_sin_wave_data()`` are held out for
    testing and the rest are used for training. Each (input, target) pair is
    the same wave shifted by one sample, so the model learns one-step-ahead
    prediction.

    Args:
        steps: number of ``optimizer.step`` calls (default 1, as before).
        future: number of samples to extrapolate past the test inputs
            (default 1000).
    """
    # load data and make training / test sets (target = input shifted by one)
    data = generate_sin_wave_data()
    train_input = torch.from_numpy(data[3:, :-1])
    train_target = torch.from_numpy(data[3:, 1:])
    test_input = torch.from_numpy(data[:3, :-1])
    test_target = torch.from_numpy(data[:3, 1:])
    n_observed = test_input.size(1)

    seq = Sequence()
    seq.double()
    criterion = nn.MSELoss()
    # use LBFGS as optimizer since we can load the whole data to train
    optimizer = optim.LBFGS(seq.parameters(), lr=0.8)

    def closure():
        # LBFGS may re-evaluate the loss several times within one step
        optimizer.zero_grad()
        out = seq(train_input)
        loss = criterion(out, train_target)
        print('loss:', loss.item())
        loss.backward()
        return loss

    for i in range(steps):
        print('STEP: ', i)
        optimizer.step(closure)

        # begin to predict, no need to track gradient here
        with torch.no_grad():
            pred = seq(test_input, future=future)
            # Slice by observed length: `pred[:, :-future]` is empty when future == 0.
            loss = criterion(pred[:, :n_observed], test_target)
            print('test loss:', loss.item())
            y = pred.detach().numpy()

        _plot_predictions(y, n_observed, future)


if __name__ == '__main__':
    # train() generates its own dataset, so nothing else needs to run here.
    train()

STEP:  0
loss: 0.5014640293663556
loss: 0.49865314833625013
loss: 0.19580064637837746
loss: 0.3981928283278413
loss: 0.03893307468453568
loss: 0.031036286462091025
loss: 0.026284214629450284
loss: 0.025777950984984557
loss: 0.025162068537741427
loss: 0.019594104597675854
loss: 0.013388312694249952
loss: 0.008929125547073905
loss: 0.006269382987125089
loss: 0.004663699764098029
loss: 0.002585471178688211
loss: 0.0010907439706587241
loss: 0.0006870894258155633
loss: 0.0005395663233956395
loss: 0.00046112457580549816
loss: 0.000448608219614525
test loss: 0.0005382116290604778


: 