In [2]:
!pip install labml --upgrade



Collecting labml
  Downloading labml-0.5.3-py3-none-any.whl.metadata (7.1 kB)
Downloading labml-0.5.3-py3-none-any.whl (94 kB)
Installing collected packages: labml
Successfully installed labml-0.5.3


In [4]:
!pip install labml_helpers


Collecting labml_helpers
  Downloading labml_helpers-0.4.89-py3-none-any.whl.metadata (1.4 kB)
Collecting labml>=0.4.158 (from labml_helpers)
  Downloading labml-0.5.3-py3-none-any.whl.metadata (7.1 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->labml_helpers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->labml_helpers)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->labml_helpers)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch->labml_helpers)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch->labml_helpers)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.m

In [5]:
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module

LabML is an open-source library for tracking, visualizing, and organizing machine learning experiments with minimal overhead. It is lightweight, integrates with PyTorch (and to some extent TensorFlow), and is built around simplicity and rapid iteration. The companion package `labml_helpers`, installed above, provides helper abstractions such as the `Module` base class that this notebook uses to define neural network components.

Compared with heavier platforms such as MLflow or TensorBoard, LabML emphasizes minimal setup: a few lines of code are enough to start tracking training loss, accuracy, gradients, and learning rates. It also offers a real-time web UI that can run locally, so metrics and model behavior can be inspected as training progresses. Its logging can capture hyperparameters, model structure, and performance metrics, which helps with rapid experimentation, reproducibility, and debugging. This makes it a good fit for educational settings and for researchers who want transparency and control without the complexity of a heavyweight framework.


**self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)**


This line creates a fully connected linear layer that takes in the previous hidden state hₜ₋₁ (with size hidden_size) and outputs a vector of size 4 × hidden_size.

Why 4×?
Because in an LSTM cell, we need to compute four separate vectors:

- Input gate iₜ: controls what new information to add to memory.
- Forget gate fₜ: controls what to remove from memory.
- Output gate oₜ: controls what to output as the hidden state.
- Candidate vector gₜ: proposed content to add to memory.

Rather than creating four separate Linear layers, this layer computes all four vectors in one operation. In `forward`, the combined vector is split into four equal parts with `chunk(4, dim=-1)`, in the order i, f, g, o.
Fusing the projections this way is more efficient and keeps the code cleaner.
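As a minimal sketch of why the fused layer is equivalent to four separate ones, the snippet below copies row blocks of a fused layer's weights into four smaller layers and compares the outputs. The names `fused`, `separate`, and `h_prev` are illustrative and do not appear in the notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 3

# One fused layer producing all four gate pre-activations at once.
fused = nn.Linear(hidden_size, 4 * hidden_size)

# Four separate layers that reuse the matching row blocks of the fused weights.
separate = [nn.Linear(hidden_size, hidden_size) for _ in range(4)]
with torch.no_grad():
    for k, layer in enumerate(separate):
        rows = slice(k * hidden_size, (k + 1) * hidden_size)
        layer.weight.copy_(fused.weight[rows])
        layer.bias.copy_(fused.bias[rows])

h_prev = torch.randn(2, hidden_size)  # a batch of 2 previous hidden states

# Splitting the fused output along the last dimension yields the same four vectors.
chunks = fused(h_prev).chunk(4, dim=-1)
expected = [layer(h_prev) for layer in separate]
matches = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(chunks, expected))
```

Each chunk has shape `[batch_size, hidden_size]`, so the fused layer does the work of four layers in a single matrix multiplication.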


**self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)**

This is another linear layer, similar to the one above, but it acts on the current input xₜ.

It takes a vector of size input_size (the current input features) and maps it to a vector of size 4 × hidden_size, for the same purpose: computing the pre-activations of iₜ, fₜ, oₜ, and gₜ.

This layer has no bias term (bias=False). The outputs of input_lin and hidden_lin are summed later, so the bias in hidden_lin already covers both, and a second bias would be redundant.
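The redundancy can be checked numerically. The sketch below builds a hypothetical variant with a bias on both projections (`input_lin_biased`, not part of the notebook), folds the second bias into the first, and confirms that the summed output is unchanged.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size = 5, 3

hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)       # has a bias
input_lin_biased = nn.Linear(input_size, 4 * hidden_size)  # hypothetical: a second bias

# Equivalent pair: one merged bias on the hidden projection, none on the input projection.
hidden_lin_merged = nn.Linear(hidden_size, 4 * hidden_size)
input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)
with torch.no_grad():
    hidden_lin_merged.weight.copy_(hidden_lin.weight)
    hidden_lin_merged.bias.copy_(hidden_lin.bias + input_lin_biased.bias)
    input_lin.weight.copy_(input_lin_biased.weight)

x = torch.randn(2, input_size)
h = torch.randn(2, hidden_size)

two_biases = hidden_lin(h) + input_lin_biased(x)
one_bias = hidden_lin_merged(h) + input_lin(x)
same_output = torch.allclose(two_biases, one_bias, atol=1e-6)
```

Since only the sum of the two biases affects the result, keeping a single bias removes `4 * hidden_size` parameters without reducing what the cell can represent.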

In [6]:
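class LSTMCell(Module):
    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
        super().__init__()
        # Each fused projection produces all four gate pre-activations at once.
        # Only hidden_lin carries a bias; a second bias on input_lin would be redundant.
        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)
        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)

        # Optional layer normalization for each gate and for the cell state.
        if layer_norm:
            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
            self.layer_norm_c = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
            self.layer_norm_c = nn.Identity()

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        # Sum both projections, then split into i, f, g, o along the feature dimension.
        ifgo = self.hidden_lin(h) + self.input_lin(x)
        ifgo = ifgo.chunk(4, dim=-1)

        # Normalize each chunk (a no-op when layer_norm=False).
        ifgo = [self.layer_norm[i](ifgo[i]) for i in range(4)]
        i, f, g, o = ifgo

        # Update the cell state, then derive the new hidden state from it.
        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next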
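
For reference, the computation this cell is meant to perform can be written out as equations. The symbols $W_h$ and $b$ (weight and bias of `hidden_lin`), $W_x$ (weight of `input_lin`), and $\mathrm{LN}$ (the optional layer normalization, which is the identity when `layer_norm=False`) are notation introduced here, not taken from the notebook.

$$
\begin{aligned}
[\,i_t,\ f_t,\ g_t,\ o_t\,] &= \mathrm{chunk}_4\left(W_h h_{t-1} + b + W_x x_t\right) \\
c_t &= \sigma\left(\mathrm{LN}_f(f_t)\right) \odot c_{t-1} + \sigma\left(\mathrm{LN}_i(i_t)\right) \odot \tanh\left(\mathrm{LN}_g(g_t)\right) \\
h_t &= \sigma\left(\mathrm{LN}_o(o_t)\right) \odot \tanh\left(\mathrm{LN}_c(c_t)\right)
\end{aligned}
$$

Here $\sigma$ is the sigmoid function and $\odot$ is element-wise multiplication; `c_next` and `h_next` in the code correspond to $c_t$ and $h_t$.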
    
        

In [9]:
class LSTM(Module):
    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        # The first cell consumes the input; each later cell consumes the
        # hidden state of the layer below it.
        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
                                   [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])

    def forward(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        # x has shape [n_steps, batch_size, input_size].
        n_steps, batch_size = x.shape[:2]

        if state is None:
            # Start from zero hidden and cell states, one pair per layer.
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            h, c = state
            # Unbind [n_layers, batch_size, hidden_size] tensors into per-layer lists.
            h, c = list(torch.unbind(h)), list(torch.unbind(c))

        out = []
        for t in range(n_steps):
            inp = x[t]
            for layer in range(self.n_layers):
                h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])
                inp = h[layer]
            # Collect the top layer's hidden state once per time step.
            out.append(h[-1])

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)

        return out, (h, c)
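
The class above appears to follow the same sequence-first shape convention as PyTorch's built-in `torch.nn.LSTM` with its default `batch_first=False`. The sketch below uses the built-in module as a stand-in (it is not the notebook's `LSTM` class) to show the expected shapes and how a returned state can be passed back in; the sizes are arbitrary illustrative values.

```python
import torch
from torch import nn

n_steps, batch_size, input_size, hidden_size, n_layers = 7, 2, 5, 3, 2

# Stand-in: PyTorch's built-in LSTM, used here only for its matching shape convention.
lstm = nn.LSTM(input_size, hidden_size, num_layers=n_layers)

x = torch.randn(n_steps, batch_size, input_size)
out, (h, c) = lstm(x)  # no initial state given, so zeros are used

# Feed the returned state back in to continue the same sequences.
x_more = torch.randn(4, batch_size, input_size)
out_more, (h2, c2) = lstm(x_more, (h, c))
```

`out` holds the top layer's hidden state for every time step, while `h` and `c` hold the final state of every layer, which is why `out[-1]` matches `h[-1]`.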
         
                 
         

