# LSTM implementation testing

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
class LSTMcell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMcell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # weights for forget gate
        self.W_f = nn.Parameter(torch.randn(self.input_size+self.hidden_size, self.hidden_size))
        self.b_f = nn.Parameter(torch.zeros(self.hidden_size))

        # weigths for input gate
        self.W_i = nn.Parameter(torch.randn(self.input_size+self.hidden_size, self.hidden_size))
        self.b_i = nn.Parameter(torch.zeros(self.hidden_size))

        # weights for candidate gate
        self.W_c = nn.Parameter(torch.randn(self.input_size+self.hidden_size, self.hidden_size))
        self.b_c = nn.Parameter(torch.zeros(self.hidden_size))

        # weights for output gate
        self.W_o = nn.Parameter(torch.randn(self.input_size+self.hidden_size, self.hidden_size))
        self.b_o = nn.Parameter(torch.zeros(self.hidden_size))

    def forward(self, x, h_prev, c_prev):
        # x -> (batch_size, embd_size)
        # h_prev -> (batch_size, vector_size)
        # c_prev -> (batch_size, vector_size)

        x_h_prev = torch.cat((x, h_prev), dim=1)
        # --- forget gate ------------------------
        # compute ft
        ft = torch.sigmoid(
            x_h_prev @ self.W_f + self.b_f
        )

        c_state = c_prev * ft
        # ------ input gate ----------------------
        it = torch.sigmoid(
            x_h_prev @ self.W_i + self.b_i
        )
    
        ct_cap = torch.tanh(
            x_h_prev @ self.W_c + self.b_c
        )
        c_state = c_state + it * ct_cap 
        #--------output_gate --------------------
        ot = torch.sigmoid(
            x_h_prev @ self.W_o + self.b_o
        )
        h_state = torch.tanh(c_state) * ot

        return c_state, h_state

## Simple LSTM

In [4]:
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.lstm_cell = LSTMcell(self.input_size, self.hidden_size)
        self.fcl = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input_seq):
        # input_seq -> (batch_size, seq_length, input_size)
        self.batch_size = input_seq.size(0)
        self.seq_length = input_seq.size(1)

        h_prev = torch.zeros((self.batch_size, self.hidden_size)).to(input_seq.device)
        c_prev = torch.zeros((self.batch_size, self.hidden_size)).to(input_seq.device)

        outputs = []
        for t_step in range(self.seq_length):
            x = input_seq[:, t_step, :]
            h_prev, c_prev = self.lstm_cell(x, h_prev, c_prev)
            y_pred = self.fcl(h_prev)            
            outputs.append(y_pred.unsqueeze(1))

        return torch.cat(outputs, axis=1)

## Dummy task 

In [58]:
def next_number_prediction(model_class):
    # ------ hyper parameter --------------
    seq_len = 5
    batch_size = 16
    input_size = 1
    hidden_size = 32
    epochs = 1000

    # ------- data preparation ------------
    # Pick random starting numbers for each batch (e.g. 0–99)
    starts = torch.randint(0, 10, (2, 1, 1), dtype=torch.float)
    # Create offset sequence [0, 1, 2, ..., seq_len-1]
    offsets = torch.arange(5, dtype=torch.float).view(1, 5, 1)
    # Add start + offsets to form sequences
    x = starts + offsets
    y = x + 1

    # ------- model config -------------
    model = model_class(input_size, hidden_size, input_size)
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    
    # ------- model training -------------
    for epoch in range(epochs):
        # forward pass
        seq_out = model(x)
        
        # calculate loss
        loss = loss_fn(seq_out, y)
        
        if epoch % 100 == 0:
            print(f"epoch: {epoch} - loss: {loss.item()}")
        
        # back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # model testing
    print("====predictions =======")
    test_x = torch.tensor([[[1.], [2], [3], [4], [5]]])
    y_pred = model(test_x)
    for i, j in zip(test_x.flatten(), y_pred.flatten()):
        print(f" {round(i.item())} -> {round(j.item())}")

In [59]:
trained_model = next_number_prediction(SimpleLSTM)

epoch: 0 - loss: 46.84516525268555
epoch: 100 - loss: 0.03893999755382538
epoch: 200 - loss: 1.5845322423047037e-06
epoch: 300 - loss: 7.054268193890101e-12
epoch: 400 - loss: 6.707523533995563e-13
epoch: 500 - loss: 5.939568975543708e-11
epoch: 600 - loss: 7.736520046819351e-07
epoch: 700 - loss: 8.70841159600344e-12
epoch: 800 - loss: 2.283400205027597e-11
epoch: 900 - loss: 3.304146218852111e-07
 1 -> 2
 2 -> 3
 3 -> 4
 4 -> 5
 5 -> 6
