## Implement an LSTM Model

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


In [2]:
# Generate synthetic sequential data
torch.manual_seed(42)
sequence_length = 10
num_samples = 100

# Create a sine wave dataset
X = torch.linspace(0, 4 * 3.14159, steps=num_samples).unsqueeze(1) + torch.randn(num_samples, 1)
y = torch.sin(X)
print(X.shape, y.shape)

torch.Size([100, 1]) torch.Size([100, 1])


In [3]:
in_seq, out_seq = [], []
for i in range(len(y)-sequence_length):
    # print(y[i:i+sequence_length].shape)
    in_seq.append(y[i:i+sequence_length])
    out_seq.append(y[i + sequence_length])


In [11]:
X_seq, y_seq = torch.stack(in_seq, dim=0), torch.stack(out_seq, dim=0)
print(X_seq.shape, y_seq.shape)


torch.Size([90, 10, 1]) torch.Size([90, 1])


In [12]:
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_units):
        super().__init__()
        weights_biases_init = lambda : (nn.Parameter(torch.randn(input_dim, hidden_units)),
                                        nn.Parameter(torch.randn(hidden_units, hidden_units)),
                                        nn.Parameter(torch.zeros(hidden_units)))
        self.input_dim = input_dim
        self.hidden_units = hidden_units
        self.Wxi, self.Whi, self.bi = weights_biases_init()
        self.Wxf, self.Whf, self.bf = weights_biases_init()
        self.Wxo, self.Who, self.bo = weights_biases_init()
        self.Wxc, self.Whc, self.bc = weights_biases_init()
        self.fc = nn.Linear(hidden_units, 1)
        
    def forward(self, inputs, H_C=None):
        if not H_C:
            H = torch.randn(inputs.shape[0], self.hidden_units)
            C = torch.randn(inputs.shape[0], self.hidden_units)
        else:
            H, C = H_C
        outputs = []
        for X in inputs:  
            # print(X.shape, self.Wxi.shape, self.Whi.shape, self.bi.shape)  
            self.It = torch.sigmoid(torch.matmul(X, self.Wxi) + torch.matmul(H, self.Whi) + self.bi)
            self.Ft = torch.sigmoid(torch.matmul(X, self.Wxf) + torch.matmul(H, self.Whf) + self.bf)
            self.Ot = torch.sigmoid(torch.matmul(X, self.Wxo) + torch.matmul(H, self.Who) + self.bo)
            self.Ct = torch.tanh(torch.matmul(X, self.Wxc) + torch.matmul(H, self.Whc) + self.bc)
            C = self.Ft * C + self.It * self.Ct
            H = self.Ot * torch.tanh(self.Ct)
            print(H.shape)
            outputs.append(H)
        pred = self.fc(torch.stack(outputs))
        print(pred.shape)
        return pred, (H, C)
        
    

In [13]:
# Define the LSTM Model
class LSTMModel_inbuilt(nn.Module):
    def __init__(self):
        super(LSTMModel_inbuilt, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=1, batch_first=True)
        self.fc = nn.Linear(50, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # Use the last output of the LSTM
        return out

In [14]:
criterion = nn.MSELoss()
model = LSTMModel(sequence_length, 8)
model_inbuilt = LSTMModel_inbuilt()
optimizer = optim.Adam(model.parameters(), lr=0.01)
print(X_seq.shape, y_seq.shape)

torch.Size([90, 10, 1]) torch.Size([90, 1])


In [15]:
epochs = 500
for epoch in range(epochs):
    state = None
    pred = model_inbuilt(X_seq)
    loss = criterion(pred, y_seq)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch % 50 == 0:
        print(f"Loss: {loss.item()}")

Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995
Loss: 0.45953768491744995


In [None]:
# Convert to numpy for plotting
with torch.no_grad():
    pred_for_vis, _ = model(X_seq) 
print(pred_for_vis.shape)
pred_np = pred_for_vis.squeeze(-1).squeeze(-1).detach().cpu().numpy() 
print(pred_np.shape)
y_np   = y_seq.squeeze(-1).squeeze(-1).detach().cpu().numpy()

plt.figure(figsize=(8, 4))
plt.plot(y_np, label='Ground Truth')
plt.plot(pred_np, label='Predictions')
plt.title("Sine Wave Fit")
plt.legend()
plt.show()

In [None]:

plt.figure()
