## LoRA implementation with PyTorch

In [1]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)

### We will be training a network to classify MNIST digits and then fine-tune the network on a particular digit on which it doesn't perform well.

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', 
                                train=True, 
                                download=True, 
                                transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, 
                                           batch_size=10, 
                                           shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', 
                               train=False, 
                               download=True, 
                               transform=transform)

test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Create the Neural Network to classify the digits, making it deep (more parameters) to better show the power of LoRA

In [4]:
# Create an deep neural network to classify MNIST digits

class NNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(NNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = NNet().to(device)

In [5]:
net

NNet(
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (linear2): Linear(in_features=1000, out_features=2000, bias=True)
  (linear3): Linear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)

A scaled down representation of the neural network

<div style="width: 1200px;">Neural Network scaled down</div>
<center><img src="./assets/nn.svg" width="1200"></center>

### Set the optimizer and loss for our network training.

In [6]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

#### Training function

Training for 1 epoch to get some pretraining on the data, so that we can fine-tune later.

In [7]:
def train(train_loader, net, epochs=1, num_iters=None):
    for epoch in range(epochs):
        net.train()
        total_loss = 0
        iterations = 0
        train_data = tqdm(train_loader, desc='Epoch {epoch}')
        for data in train_data:
            optimizer.zero_grad()
            iterations += 1
            X, y = data
            X, y = X.to(device), y.to(device)
            output = net(X.view(-1, 28*28))
            loss = loss_fn(output, y)
            total_loss += loss.item()
            running_loss = total_loss / iterations
            train_data.set_postfix(loss=running_loss)
            loss.backward()
            optimizer.step()
            
                    
train(train_loader, net, epochs=1)

Epoch {epoch}: 100%|██████████████████████████████████████████████████| 6000/6000 [00:32<00:00, 182.93it/s, loss=0.238]


### Keep a copy of the original weights (cloning them) so later we can prove that a fine-tuning with LoRA doesn't alter the original weights


In [8]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()

In [10]:
original_weights.keys()

dict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])

#### Get the number of parameters in the neural net

In [19]:
def get_param_count(neural_net):
    param_count = 0
    for name, param in neural_net.named_parameters():
        print(f'Layer: {name} Shape: {param.shape}')
        param_count += param.nelement()
    return param_count
param_count = get_param_count(neural_net=net)
print(f'\nTotal number of trainable params in Original neural net: {param_count}')

Layer: linear1.weight Shape: torch.Size([1000, 784])
Layer: linear1.bias Shape: torch.Size([1000])
Layer: linear2.weight Shape: torch.Size([2000, 1000])
Layer: linear2.bias Shape: torch.Size([2000])
Layer: linear3.weight Shape: torch.Size([10, 2000])
Layer: linear3.bias Shape: torch.Size([10])

Total number of trainable params in Original neural net: 2807010


In [20]:
# Testing on the testset to validate performance
def test(test_loader):

    with torch.no_grad():
        
        
    

SyntaxError: incomplete input (496762470.py, line 6)