In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters 
# input_size = 784 # 28x28
num_classes = 10
num_epochs = 2
batch_size = 100
learning_rate = 0.001

input_size = 28
sequence_length = 28
hidden_size = 128
num_layers = 2

# MNIST dataset 
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)


# Fully connected neural network with one hidden layer
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        # -> x needs to be: (batch_size, seq, input_size)
        
        # or:
        #self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        #self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        # Set initial hidden states (and cell states for LSTM)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        #c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        
        # x: (n, 28, 28), h0: (2, n, 128)
        
        # Forward propagate RNN
        out, _ = self.rnn(x, h0)  
        # or:
        #out, _ = self.lstm(x, (h0,c0))  
        
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        # out: (n, 28, 128)
        
        # Decode the hidden state of the last time step
        out = out[:, -1, :]
        # out: (n, 128)
         
        out = self.fc(out)
        # out: (n, 10)
        return out

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

# Train the model
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # origin shape: [N, 1, 28, 28]
        # resized: [N, 28, 28]
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # max returns (value ,index)
        _, predicted = torch.max(outputs.data, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network on the 10000 test images: {acc} %')


Epoch [1/2], Step [100/600], Loss: 1.1801
Epoch [1/2], Step [200/600], Loss: 0.7050
Epoch [1/2], Step [300/600], Loss: 0.9363
Epoch [1/2], Step [400/600], Loss: 0.4060
Epoch [1/2], Step [500/600], Loss: 0.4849
Epoch [1/2], Step [600/600], Loss: 0.2999
Epoch [2/2], Step [100/600], Loss: 0.3964
Epoch [2/2], Step [200/600], Loss: 0.2736
Epoch [2/2], Step [300/600], Loss: 0.2977
Epoch [2/2], Step [400/600], Loss: 0.3016
Epoch [2/2], Step [500/600], Loss: 0.2762
Epoch [2/2], Step [600/600], Loss: 0.1917
Accuracy of the network on the 10000 test images: 93.63 %


In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters 
num_classes = 10
num_epochs = 2
learning_rate = 0.001

input_size = 28
sequence_length = 28
hidden_size = 128
num_layers = 2

# MNIST dataset 
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
batch_sizes = [100, 200]
train_loaders = {}
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_sizes[0], 
                                          shuffle=False)

for batch_size in batch_sizes:
    train_loaders[batch_size] = torch.utils.data.DataLoader(dataset=train_dataset, 
                                                             batch_size=batch_size, 
                                                             shuffle=True)

# Model definition for RNN, GRU, and LSTM
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, batch_first=True):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=batch_first)
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        out, _ = self.rnn(x, h0)  
        out = out[:, -1, :]
        out = self.fc(out)
        return out

class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, batch_first=True):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=batch_first)
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        out, _ = self.gru(x, h0)  
        out = out[:, -1, :]
        out = self.fc(out)
        return out

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, batch_first=True):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) 
        out, _ = self.lstm(x, (h0, c0))  
        out = out[:, -1, :]
        out = self.fc(out)
        return out

# Train and evaluate models
for batch_size in batch_sizes:
    for batch_first in [True, False]:
        print(f'Batch Size: {batch_size}, Batch First: {batch_first}')
        
        # Define models
        rnn_model = RNNModel(input_size, hidden_size, num_layers, num_classes, batch_first=batch_first).to(device)
        gru_model = GRUModel(input_size, hidden_size, num_layers, num_classes, batch_first=batch_first).to(device)
        lstm_model = LSTMModel(input_size, hidden_size, num_layers, num_classes, batch_first=batch_first).to(device)
        
        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=learning_rate)
        gru_optimizer = torch.optim.Adam(gru_model.parameters(), lr=learning_rate)
        lstm_optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)
        
        # Training loop for RNN model
        for epoch in range(num_epochs):
            for images, labels in train_loaders[batch_size]:
                images = images.reshape(-1, sequence_length, input_size).to(device)
                labels = labels.to(device)
                rnn_outputs = rnn_model(images)
                rnn_loss = criterion(rnn_outputs, labels)
                rnn_optimizer.zero_grad()
                rnn_loss.backward()
                rnn_optimizer.step()
        
        # Training loop for GRU model
        for epoch in range(num_epochs):
            for images, labels in train_loaders[batch_size]:
                images = images.reshape(-1, sequence_length, input_size).to(device)
                labels = labels.to(device)
                gru_outputs = gru_model(images)
                gru_loss = criterion(gru_outputs, labels)
                gru_optimizer.zero_grad()
                gru_loss.backward()
                gru_optimizer.step()
        
        # Training loop for LSTM model
        for epoch in range(num_epochs):
            for images, labels in train_loaders[batch_size]:
                images = images.reshape(-1, sequence_length, input_size).to(device)
                labels = labels.to(device)
                lstm_outputs = lstm_model(images)
                lstm_loss = criterion(lstm_outputs, labels)
                lstm_optimizer.zero_grad()
                lstm_loss.backward()
                lstm_optimizer.step()
        
        # Evaluation
        with torch.no_grad():
            rnn_correct = 0
            gru_correct = 0
            lstm_correct = 0
            total = 0
            for images, labels in test_loader:
                images = images.reshape(-1, sequence_length, input_size).to(device)
                labels = labels.to(device)
                rnn_outputs = rnn_model(images)
                _, rnn_predicted = torch.max(rnn_outputs.data, 1)
                rnn_correct += (rnn_predicted == labels).sum().item()
                
                gru_outputs = gru_model(images)
                _, gru_predicted = torch.max(gru_outputs.data, 1)
                gru_correct += (gru_predicted == labels).sum().item()
                
                lstm_outputs = lstm_model(images)
                _, lstm_predicted = torch.max(lstm_outputs.data, 1)
                lstm_correct += (lstm_predicted == labels).sum().item()
                
                total += labels.size(0)

            rnn_accuracy = 100 * rnn_correct / total
            gru_accuracy = 100 * gru_correct / total
            lstm_accuracy = 100 * lstm_correct / total

            print(f'RNN Accuracy: {rnn_accuracy:.2f}%, GRU Accuracy: {gru_accuracy:.2f}%, LSTM Accuracy: {lstm_accuracy:.2f}%')


Batch Size: 100, Batch First: True
RNN Accuracy: 92.10%, GRU Accuracy: 97.48%, LSTM Accuracy: 96.52%
Batch Size: 100, Batch First: False


RuntimeError: Expected hidden size (2, 28, 128), got [2, 100, 128]

In [2]:
import torch
import time

def time_measure(batch_first: bool):
    layer = torch.nn.RNN(10, 20, batch_first=batch_first)
    if batch_first:
        inputs = torch.randn(10000000, 7, 10)
    else:
        inputs = torch.randn(7, 10000000, 10)

    start = time.perf_counter()
    for chunk in torch.chunk(inputs, 100000 // 64, dim=0 if batch_first else 1):
        _, last = layer(chunk)
    return time.perf_counter() - start

print(f"Time taken for batch_first=False: {time_measure(False)}")
print(f"Time taken for batch_first=True: {time_measure(True)}")

Time taken for batch_first=False: 15.587970499997027
Time taken for batch_first=True: 10.071646099997452
