## LoRA implementation with PyTorch

In [3]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [4]:
# Make torch deterministic
_ = torch.manual_seed(0)

### We will be training a network to classify MNIST digits and then fine-tune the network on a particular digit on which it doesn't perform well.

In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', 
                                train=True, 
                                download=True, 
                                transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, 
                                           batch_size=10, 
                                           shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', 
                               train=False, 
                               download=True, 
                               transform=transform)

test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Create the Neural Network to classify the digits, making it deep (more parameters) to better show the power of LoRA

In [6]:
# Create an deep neural network to classify MNIST digits

class NNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(NNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = NNet().to(device)

In [7]:
net

NNet(
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (linear2): Linear(in_features=1000, out_features=2000, bias=True)
  (linear3): Linear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)

A scaled down representation of the neural network

<div style="width: 1200px;">Neural Network scaled down</div>
<center><img src="./assets/nn.svg" width="1200"></center>

### Set the optimizer and loss for our network training.

In [8]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

#### Training function

Training for 1 epoch to get some pretraining on the data, so that we can fine-tune later.

In [9]:
def train(train_loader, net, epochs=1, num_iters=None):
    for epoch in range(epochs):
        net.train()
        total_loss = 0
        iterations = 0
        train_data = tqdm(train_loader, desc=f'Epoch {epoch}')
        if num_iters is not None:
            train_data.total = num_iters
        for data in train_data:
            optimizer.zero_grad()
            iterations += 1
            X, y = data
            X, y = X.to(device), y.to(device)
            output = net(X.view(-1, 28*28))
            loss = loss_fn(output, y)
            total_loss += loss.item()
            running_loss = total_loss / iterations
            train_data.set_postfix(loss=running_loss)
            loss.backward()
            optimizer.step()
            if num_iters is not None and iterations >= num_iters:
                return
                    
train(train_loader, net, epochs=1)

Epoch 0: 100%|█████████████████████████████████████████████████████████| 6000/6000 [01:22<00:00, 72.56it/s, loss=0.238]


### Keep a copy of the original weights (cloning them) so later we can prove that a fine-tuning with LoRA doesn't alter the original weights


In [10]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()

In [11]:
original_weights.keys()

dict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])

#### Get the number of parameters in the neural net

In [12]:
def get_param_count(neural_net):
    param_count = 0
    for name, param in neural_net.named_parameters():
        print(f'Layer: {name} Shape: {param.shape}')
        param_count += param.nelement()
    return param_count
param_count = get_param_count(neural_net=net)
print(f'\nTotal number of trainable params in original neural net: {param_count}')

Layer: linear1.weight Shape: torch.Size([1000, 784])
Layer: linear1.bias Shape: torch.Size([1000])
Layer: linear2.weight Shape: torch.Size([2000, 1000])
Layer: linear2.bias Shape: torch.Size([2000])
Layer: linear3.weight Shape: torch.Size([10, 2000])
Layer: linear3.bias Shape: torch.Size([10])

Total number of trainable params in original neural net: 2807010


In [13]:
# Testing on the testset to validate performance
def test(test_loader, net):
    num_correct, total = 0, 0
    wrong_counts = [0 for _ in range(10)]
    with torch.no_grad():
        for sample in tqdm(test_loader):
            X, y = sample
            X, y = X.to(device), y.to(device)
            output = net(X)
            for idx, out in enumerate(output):
                total += 1
                if torch.argmax(out) == y[idx]:
                    num_correct += 1
                else:
                    wrong_counts[y[idx]] += 1
    
    print(f'Accuracy: {round(num_correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')
       
            
test(test_loader, net)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:10<00:00, 93.36it/s]

Accuracy: 0.954
wrong counts for the digit 0: 12
wrong counts for the digit 1: 19
wrong counts for the digit 2: 36
wrong counts for the digit 3: 92
wrong counts for the digit 4: 29
wrong counts for the digit 5: 41
wrong counts for the digit 6: 38
wrong counts for the digit 7: 47
wrong counts for the digit 8: 35
wrong counts for the digit 9: 115





### Now Let's apply LoRA parametrization to fine-tune 
### this network for digit 9 as we saw poor performance on the digit 9


LoRA paper: [Link](https://arxiv.org/abs/2106.09685)

Paramtertrization in PyTorch: [Turtorial Link](https://pytorch.org/tutorials/intermediate/parametrizations.html)

In [14]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        '''
        Section 4.1 of the paper: 
          We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        '''
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        '''
        Section 4.1 of the paper: 
          We then scale ∆Wx by α/r , where α is a constant in r. 
          When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. 
          As a result, we simply set α to the first r we try and do not tune it. 
          This scaling helps to reduce the need to retune hyperparameters when we vary r.
        '''
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

### Add the LoRA parameterization to our network

In [15]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.
    
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)


def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [16]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == param_count
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [18]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, net, epochs=1, num_iters=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 0:  99%|███████████████████████████████████████████████████████████▍| 99/100 [00:02<00:00, 45.64it/s, loss=0.411]


In [19]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

In [21]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test(test_loader, net)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:15<00:00, 64.78it/s]

Accuracy: 0.954
wrong counts for the digit 0: 12
wrong counts for the digit 1: 19
wrong counts for the digit 2: 36
wrong counts for the digit 3: 92
wrong counts for the digit 4: 29
wrong counts for the digit 5: 41
wrong counts for the digit 6: 38
wrong counts for the digit 7: 47
wrong counts for the digit 8: 35
wrong counts for the digit 9: 115





In [None]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test(test_loader, net)

 25%|███████████████████▊                                                           | 251/1000 [00:02<00:08, 88.47it/s]