# LoRA (Low-Rank Adaptation) on MNIST: implementation with PyTorch

Importing the libraries

In [2]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn

import matplotlib.pyplot as plt
from tqdm import tqdm

Make the results reproducible (using seeds)

In [3]:
# Seed PyTorch's random number generator so that repeated runs are comparable.
_ = torch.manual_seed(0)

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Train a network to classify MNIST digits, then fine-tune it on a particular digit on which it underperforms.

In [4]:
# Convert images to tensors and normalise them with the commonly used MNIST
# mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load MNIST (train)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Dataloader for training: reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load MNIST (test)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Dataloader for evaluation: accuracy and per-digit error counts do not depend
# on sample order, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=False)



Create the neural network that classifies the digits, deliberately making it larger than necessary to illustrate LoRA.

In [5]:
# Not really efficient: deliberately oversized so that LoRA has something to adapt.
class NNet(nn.Module):
    """Three-layer fully connected classifier mapping 28x28 images to 10 logits.

    Args:
        hidden_size1: width of the first hidden layer.
        hidden_size2: width of the second hidden layer.
    """

    def __init__(self, hidden_size1=1000, hidden_size2=2000):
        super().__init__()
        pixels = 28 * 28
        # Construction order is kept fixed so seeded initialisation stays reproducible.
        self.linear1 = nn.Linear(pixels, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to ``(-1, 784)`` and return unnormalised scores of shape ``(N, 10)``."""
        hidden = img.view(-1, 28 * 28)
        for hidden_layer in (self.linear1, self.linear2):
            hidden = self.relu(hidden_layer(hidden))
        return self.linear3(hidden)
    
# Instantiate the classifier with its default hidden sizes and move it to the selected device.
net = NNet().to(device)

Train the network for only 1 epoch to simulate a complete general pre-training on the data.

In [6]:
def train(train_loader, net, epochs=5, total_iteration_limit=None):
    """Train ``net`` in place with Adam (lr=0.001) on a cross-entropy objective.

    Args:
        train_loader: iterable of ``(x, y)`` batches; each ``x`` is flattened to
            ``(-1, 784)`` before the forward pass.
        net: model to optimise. Batches are moved to the device of its first
            parameter.
        epochs: number of passes over ``train_loader``.
        total_iteration_limit: optional cap on the number of optimisation steps,
            counted across all epochs; when given it is also used as the
            progress-bar total.

    Returns:
        None. The running average loss of the current epoch is shown on the
        progress bar.
    """
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    # Follow the model's own device instead of relying on a module-level global.
    model_device = next(net.parameters()).device

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        # With total=None tqdm falls back to len(train_loader). The context
        # manager closes the bar even when the iteration limit returns early.
        with tqdm(train_loader, desc=f"Epoch {epoch+1}", total=total_iteration_limit) as data_iterator:
            for x, y in data_iterator:
                num_iterations += 1
                total_iterations += 1
                x = x.to(model_device)
                y = y.to(model_device)
                optimizer.zero_grad()
                output = net(x.view(-1, 28*28))
                loss = cross_el(output, y)
                loss_sum += loss.item()
                data_iterator.set_postfix(loss=loss_sum / num_iterations)
                loss.backward()
                optimizer.step()

                if total_iteration_limit is not None and total_iterations >= total_iteration_limit:
                    return

# A single epoch over the training set stands in for a full general pre-training run.
train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:11<00:00, 520.74it/s, loss=0.238]


Keep a copy of the original weights.

In [7]:
# Snapshot every parameter as a standalone tensor, detached from autograd.
original_weights = dict()
for name, param in net.named_parameters():
    original_weights[name] = param.detach().clone()

Test the performance of the previously trained model.

In [8]:
def test():
    """Evaluate the global ``net`` on ``test_loader`` and print the results.

    Prints the overall accuracy (rounded to 3 decimals), followed by the number
    of misclassified test samples for each digit 0-9. Returns None.
    """
    # Inference runs in eval mode. The current mode is remembered so that a
    # caller in the middle of training gets the model back as it left it.
    was_training = net.training
    net.eval()

    correct = 0
    total = 0
    # One error counter per digit class, kept on the same device as the labels.
    wrong_counts = torch.zeros(10, dtype=torch.long, device=device)

    try:
        with torch.no_grad():
            for x, y in tqdm(test_loader, desc="Testing"):
                x = x.to(device)
                y = y.to(device)
                output = net(x.view(-1, 784))
                # Vectorised replacement for a per-sample Python loop.
                misclassified = output.argmax(dim=1) != y
                correct += int((~misclassified).sum())
                total += y.size(0)
                # Tally the true labels of the misclassified samples in one call.
                wrong_counts += torch.bincount(y[misclassified], minlength=10)
    finally:
        net.train(was_training)

    if total == 0:
        print('Accuracy: n/a (empty test set)')
    else:
        print(f'Accuracy: {round(correct/total, 3)}')
    for digit, count in enumerate(wrong_counts.tolist()):
        print(f'Wrong counts for the digit {digit}: {count}')

# Baseline evaluation of the pre-trained network, before any LoRA is added.
test()

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 782.61it/s]

Accuracy: 0.962
Wrong counts for the digit 0: 22
Wrong counts for the digit 1: 27
Wrong counts for the digit 2: 18
Wrong counts for the digit 3: 45
Wrong counts for the digit 4: 34
Wrong counts for the digit 5: 22
Wrong counts for the digit 6: 38
Wrong counts for the digit 7: 60
Wrong counts for the digit 8: 22
Wrong counts for the digit 9: 88





Before implementing LoRA, let's count how many parameters the original network has.

In [9]:
# Report the weight/bias shapes of each Linear layer and accumulate the total
# parameter count, which is kept for comparison once LoRA has been added.
total_parameters_original = 0
for index, layer in enumerate((net.linear1, net.linear2, net.linear3)):
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
    total_parameters_original += sum(t.nelement() for t in (layer.weight, layer.bias))
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


## Define the LoRA parametrization

In [15]:
class LoRA(nn.Module):
    """Low-rank additive update for a weight matrix, used as a parametrization.

    While ``enabled`` is True, ``forward(W)`` returns ``W + (B @ A) * (alpha / rank)``;
    otherwise it returns ``W`` unchanged.

    Args:
        features_in: number of rows of ``B @ A``; ``lora_B`` has shape
            ``(features_in, rank)``. When the arguments come from unpacking an
            ``nn.Linear`` weight shape ``(out_features, in_features)``, this is
            the layer's ``out_features``.
        features_out: number of columns of ``B @ A``; ``lora_A`` has shape
            ``(rank, features_out)``.
        rank: rank ``r`` of the update.
        alpha: scaling constant; the update is multiplied by ``alpha / rank``.
        device: device on which ``lora_A`` and ``lora_B`` are allocated.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device=device):
        super().__init__()

        # Section 4.1 of the paper:
        #   We use a random Gaussian init for A and zero for B, so deltaW = BA is
        #   zero at the beginning of training.
        # The tensors are allocated directly on ``device`` rather than being
        # created on the CPU and then copied over.
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        #   We then scale deltaW by alpha/r, where alpha is a constant in r.
        #   When optimizing with Adam, tuning alpha is roughly the same as tuning
        #   the learning rate if we scale the initialization appropriately.
        #   As a result, we simply set alpha to the first r we try and do not tune it.
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        # Set to False to bypass the low-rank update and return W unchanged.
        self.enabled = True

    def forward(self, original_weights):
        """Return ``original_weights``, plus the scaled ``B @ A`` update when enabled."""
        if not self.enabled:
            return original_weights
        # W + (B @ A) * scale, with B @ A viewed in W's shape.
        delta = torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
        return original_weights + delta * self.scale

Add LoRA to our network (the previously built NNet).

In [16]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    """Build a LoRA parametrization for ``layer.weight``; the bias is left alone.

    Args:
        layer: module with a 2-D ``weight`` of shape ``(out_features, in_features)``,
            e.g. ``nn.Linear``.
        device: device on which the LoRA matrices are allocated.
        rank: LoRA rank ``r``.
        lora_alpha: LoRA scaling constant ``alpha``.

    Returns:
        A ``LoRA`` module whose ``B @ A`` product has the same shape as ``layer.weight``.
    """
    # Section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream
    #   tasks and freeze the MLP modules (so they are not trained in downstream tasks).
    # This notebook's network has no attention layers, so its Linear weights are
    # adapted instead.

    out_features, in_features = layer.weight.shape
    # LoRA's B becomes (out_features, rank) and A becomes (rank, in_features),
    # so B @ A matches the weight's shape.
    return LoRA(
        out_features, in_features, rank=rank, alpha=lora_alpha, device=device
    )

# Attach a (default rank-1) LoRA parametrization to the weight of each Linear
# layer, in order linear1, linear2, linear3; the biases stay untouched.
for _lora_target in (net.linear1, net.linear2, net.linear3):
    parametrize.register_parametrization(
        _lora_target, "weight", linear_layer_parametrization(_lora_target, device)
    )
del _lora_target

def enable_disable_lora(enabled=True):
    """Turn the LoRA update on (default) or off for all three Linear layers of ``net``."""
    lora_modules = (
        linear.parametrizations["weight"][0]
        for linear in (net.linear1, net.linear2, net.linear3)
    )
    for lora in lora_modules:
        lora.enabled = enabled

Display the number of parameters added by LoRA