## Implement an LSTM Model

In [5]:
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim


In [32]:
# Generate synthetic sequential data
torch.manual_seed(42)
sequence_length = 10
num_samples = 100
noise_std = 0.1  # standard deviation of the additive observation noise on y

# Create a noisy sine wave dataset.
# The sample positions X stay evenly spaced and sorted, so consecutive rows of y
# are consecutive points on the curve. The noise is added to the observations y,
# not to X. X is spaced only ~0.13 apart, so unit-variance noise on X would
# scramble the temporal order that the sliding windows built from y rely on.
X = torch.linspace(0, 4 * math.pi, steps=num_samples).unsqueeze(1)
y = torch.sin(X) + noise_std * torch.randn(num_samples, 1)
print(X.shape, y.shape)

torch.Size([100, 1]) torch.Size([100, 1])


In [37]:
# Slice y into overlapping windows. Each input is `sequence_length` consecutive
# points, and its target is the point immediately after the window.
window_starts = range(len(y) - sequence_length)
in_seq = [y[start:start + sequence_length] for start in window_starts]
out_seq = [y[start + sequence_length] for start in window_starts]


In [40]:
# Batch the windows along a new leading dimension
# (the output below shows (90, 10, 1) inputs and (90, 1) targets).
X_seq = torch.stack(in_seq)
y_seq = torch.stack(out_seq)
print(X_seq.shape, y_seq.shape)


torch.Size([90, 10, 1]) torch.Size([90, 1])


In [None]:
class LSTMModel(nn.Module):
    """Single-layer LSTM implemented from scratch with explicit gate parameters.

    Args:
        input_dim: Number of features per time step in the input.
        hidden_units: Size of the hidden state H and the cell state C.
    """

    def __init__(self, input_dim, hidden_units):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_units = hidden_units
        # One (input->hidden, hidden->hidden, bias) triple per gate:
        # input gate (i), forget gate (f), output gate (o), candidate cell (c).
        self.Wxi, self.Whi, self.bi = self._gate_params()
        self.Wxf, self.Whf, self.bf = self._gate_params()
        self.Wxo, self.Who, self.bo = self._gate_params()
        self.Wxc, self.Whc, self.bc = self._gate_params()

    def _gate_params(self, sigma=0.01):
        """Create freshly initialised (W_x, W_h, b) parameters for one gate.

        Weights are drawn from N(0, sigma**2) and biases start at zero.
        Unscaled standard-normal weights can saturate the sigmoid/tanh gates.
        """
        return (
            nn.Parameter(torch.randn(self.input_dim, self.hidden_units) * sigma),
            nn.Parameter(torch.randn(self.hidden_units, self.hidden_units) * sigma),
            nn.Parameter(torch.zeros(self.hidden_units)),
        )

    def forward(self, X, H_C=None):
        """Run the LSTM over a batch of sequences.

        Args:
            X: Tensor of shape (batch, seq_len, input_dim). A 2-D tensor of
                shape (batch, input_dim) is treated as a single time step.
            H_C: Optional (H, C) tuple of initial hidden and cell states, each
                of shape (batch, hidden_units). Defaults to zeros.

        Returns:
            ``(outputs, (H, C))``. ``outputs`` has shape
            (batch, seq_len, hidden_units) and holds the hidden state after
            every step. ``(H, C)`` is the final state, which can be passed back
            as ``H_C`` to continue the sequence.

        The gate activations of the last processed step are also stored on
        ``self.It``, ``self.Ft``, ``self.Ot`` and ``self.Ct``.
        """
        if X.dim() == 2:
            X = X.unsqueeze(1)
        batch_size, seq_len, _ = X.shape

        # Compare with None explicitly. The recurrence below must run whether
        # or not a state is supplied.
        if H_C is None:
            # Deterministic zero initial state, matching X's dtype and device.
            H = X.new_zeros(batch_size, self.hidden_units)
            C = X.new_zeros(batch_size, self.hidden_units)
        else:
            H, C = H_C

        outputs = []
        for t in range(seq_len):
            x_t = X[:, t, :]
            self.It = torch.sigmoid(x_t @ self.Wxi + H @ self.Whi + self.bi)
            self.Ft = torch.sigmoid(x_t @ self.Wxf + H @ self.Whf + self.bf)
            self.Ot = torch.sigmoid(x_t @ self.Wxo + H @ self.Who + self.bo)
            self.Ct = torch.tanh(x_t @ self.Wxc + H @ self.Whc + self.bc)
            # Gates combine element-wise (Hadamard product), not by matrix product.
            C = self.Ft * C + self.It * self.Ct
            # The hidden state reads out the updated cell state C, not the candidate.
            H = self.Ot * torch.tanh(C)
            outputs.append(H)

        if not outputs:  # seq_len == 0: torch.stack cannot take an empty list
            return X.new_zeros(batch_size, 0, self.hidden_units), (H, C)
        return torch.stack(outputs, dim=1), (H, C)