In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

#### We will use the MNIST dataset.
Each MNIST image is 28x28 pixels: 28 rows and 28 columns. We'll treat each of the 28 rows as a single time step in a sequence.
So the sequence length is 28 (one time step per row) and the input size is 28 (one feature per column, i.e. per pixel in that row).

In [2]:
# Hyperparameters and run configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_size = 28  # Features per time step: one 28-pixel row of an MNIST image
sequence_length = 28  # Time steps per sample: the 28 rows of an MNIST image
hidden_size = 128  # Number of features in the hidden state
num_layers = 2  # Number of stacked recurrent layers
num_classes = 10  # One class per digit
batch_size = 100
num_epochs = 5
learning_rate = 0.001

# Seed PyTorch's RNGs so weight initialisation and DataLoader shuffling are repeatable
# across Restart & Run All. GPU runs may still differ slightly because some cuDNN
# kernels are non-deterministic.
SEED = 42
torch.manual_seed(SEED)

print(f"Running on device: {device}")

Running on device: cuda


In [3]:
# MNIST train/test splits; ToTensor turns each grayscale image into a [1, 28, 28] float tensor.
# download=True is a no-op when the files are already present, so each split can be loaded
# on its own instead of relying on the train split having fetched the test files.
data_root = "./data3"
to_tensor = transforms.ToTensor()

train_set = torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=to_tensor)

test_set = torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=to_tensor)

In [4]:
# Shuffle only the training data; evaluation order has no effect on the accuracy sum.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

In [5]:
class RecurrentNet(nn.Module):
    """Sequence classifier: a stacked RNN/LSTM/GRU whose last-step output feeds a linear head."""

    # Accepted ``model_type`` strings mapped to their torch.nn recurrent layer classes.
    _LAYER_TYPES = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}

    def __init__(self, model_type, input_size, hidden_size, num_layers, num_classes):
        """
        A flexible recurrent neural network model.

        Args:
            model_type (str): The type of recurrent layer to use ('RNN', 'LSTM', 'GRU').
            input_size (int): The number of expected features in the input x per time step.
            hidden_size (int): The number of features in the hidden state h.
            num_layers (int): Number of recurrent layers.
            num_classes (int): Number of output classes.

        Raises:
            ValueError: If ``model_type`` is not one of 'RNN', 'LSTM', 'GRU'.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.model_type = model_type

        if model_type not in self._LAYER_TYPES:
            raise ValueError("Invalid model_type. Choose from 'RNN', 'LSTM', 'GRU'.")

        # batch_first=True: the layer consumes/produces [batch, seq_len, features].
        self.recurrent_layer = self._LAYER_TYPES[model_type](
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        # A fully connected layer to map the final hidden state to the class scores
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return class logits of shape [batch, num_classes].

        Args:
            x (Tensor): Either an image batch [batch, 1, seq_len, input_size]
                (e.g. MNIST after ToTensor) or a sequence batch [batch, seq_len, input_size].
        """
        # Merge the channel axis into the time axis for 4-D image batches:
        # [B, 1, 28, 28] -> [B, 28, 28]. 3-D input is already [B, seq_len, input_size]
        # and is left alone (unlike squeeze(1), this never drops a length-1 time axis).
        if x.dim() == 4:
            x = x.flatten(1, 2)

        # Zero initial hidden state, shape (num_layers, batch_size, hidden_size).
        # Allocated on x's device/dtype rather than a global `device`, so the model
        # works wherever it has been moved.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        # LSTMs also carry a cell state alongside the hidden state.
        if self.model_type == "LSTM":
            c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
            hidden_state = (h0, c0)
        else:
            hidden_state = h0

        # out: (batch_size, seq_length, hidden_size), the top layer's output at every step.
        out, _ = self.recurrent_layer(x, hidden_state)

        # Classify from the output of the very last time step only.
        return self.fc(out[:, -1, :])

In [6]:
def train_and_evaluate(model, model_name, num_epochs=None, learning_rate=None,
                       train_loader=None, test_loader=None, device=None, log_every=200):
    """Train ``model`` with Adam and cross-entropy loss, then report test-set accuracy.

    Every optional argument left as ``None`` falls back to the notebook-level variable of
    the same name, so ``train_and_evaluate(model, name)`` uses the config cell's settings.

    Args:
        model (nn.Module): Model mapping an input batch to class logits; assumed to
            already be on ``device``.
        model_name (str): Label used in the printed messages.
        num_epochs (int, optional): Passes over ``train_loader``.
        learning_rate (float, optional): Adam learning rate.
        train_loader (iterable, optional): Yields ``(inputs, labels)`` training batches.
        test_loader (iterable, optional): Yields ``(inputs, labels)`` evaluation batches.
        device (torch.device, optional): Device each batch is moved to.
        log_every (int): Print the current loss every this many training steps (> 0).

    Returns:
        float: Test accuracy in percent.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    # The parameters shadow the notebook globals, so resolve defaults via globals().
    settings = globals()
    num_epochs = settings["num_epochs"] if num_epochs is None else num_epochs
    learning_rate = settings["learning_rate"] if learning_rate is None else learning_rate
    train_loader = settings["train_loader"] if train_loader is None else train_loader
    test_loader = settings["test_loader"] if test_loader is None else test_loader
    device = settings["device"] if device is None else device

    print(f"--- Training {model_name} ---")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    model.train()
    n_total_steps = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            loss = criterion(model(images), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

    # Evaluation: eval mode and no autograd bookkeeping.
    model.eval()
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            predicted = model(images).argmax(dim=1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    if n_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the {model_name} on the {n_samples} test images: {acc:.2f} %')
    print("-" * 30)
    return acc

In [7]:
def _build_and_train(model_type, display_name):
    """Build a RecurrentNet of ``model_type`` on ``device``, train/evaluate it, and return it."""
    net = RecurrentNet(model_type, input_size, hidden_size, num_layers, num_classes).to(device)
    train_and_evaluate(net, display_name)
    return net


# Same architectures, same order, same shared hyperparameters.
rnn_model = _build_and_train('RNN', "Simple RNN")
gru_model = _build_and_train('GRU', "GRU")
lstm_model = _build_and_train('LSTM', "LSTM")

--- Training Simple RNN ---
Epoch [1/5], Step [200/600], Loss: 0.5506
Epoch [1/5], Step [400/600], Loss: 0.3888
Epoch [1/5], Step [600/600], Loss: 0.3115
Epoch [2/5], Step [200/600], Loss: 0.2869
Epoch [2/5], Step [400/600], Loss: 0.1837
Epoch [2/5], Step [600/600], Loss: 0.2794
Epoch [3/5], Step [200/600], Loss: 0.1549
Epoch [3/5], Step [400/600], Loss: 0.2392
Epoch [3/5], Step [600/600], Loss: 0.1022
Epoch [4/5], Step [200/600], Loss: 0.0831
Epoch [4/5], Step [400/600], Loss: 0.2049
Epoch [4/5], Step [600/600], Loss: 0.1817
Epoch [5/5], Step [200/600], Loss: 0.2457
Epoch [5/5], Step [400/600], Loss: 0.0835
Epoch [5/5], Step [600/600], Loss: 0.0229
Accuracy of the Simple RNN on the 10000 test images: 96.31 %
------------------------------
--- Training GRU ---
Epoch [1/5], Step [200/600], Loss: 0.4254
Epoch [1/5], Step [400/600], Loss: 0.2110
Epoch [1/5], Step [600/600], Loss: 0.0559
Epoch [2/5], Step [200/600], Loss: 0.0862
Epoch [2/5], Step [400/600], Loss: 0.1520
Epoch [2/5], Step [