# LoRA implementation with PyTorch

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.utils.parametrize as parametrize

Make the model deterministic

In [2]:
# Make torch deterministic: seed PyTorch's global random number generator.
# The result is bound to `_` so the value returned by manual_seed is not echoed as cell output.
_ = torch.manual_seed(0)

We will train a network to classify MNIST digits and then fine-tune it on a particular digit on which it doesn't perform well

In [3]:
# Preprocessing pipeline: turn each image into a tensor, then normalize it with
# the given mean/std (presumably the usual MNIST statistics -- TODO confirm)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split (downloaded into ./data when missing), served in shuffled mini-batches of 10
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=mnist_trainset, batch_size=10, shuffle=True)

# Test split, preprocessed and batched the same way
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=mnist_testset, batch_size=10, shuffle=True)

# Device used by the rest of the notebook (CPU; the MPS variant is kept for reference)
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = torch.device('cpu')

Create the Neural Network to classify the digits, making it overly complicated to better show LoRA

In [4]:
# Create an overly expensive neural network to classify MNIST digits
# Daddy got money, so I don't care about efficiency
class RichBoyNet(nn.Module):
    """Deliberately oversized three-layer MLP for 28x28 digit images.

    Layout: 784 -> hidden_size_1 -> hidden_size_2 -> 10, with a ReLU after
    each of the first two linear layers. The last layer returns raw scores
    (no softmax is applied here).
    """

    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        input_size = 28 * 28
        num_classes = 10
        self.linear1 = nn.Linear(input_size, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten `img` to rows of 784 values and return one row of 10 scores per sample."""
        hidden = self.relu(self.linear1(img.view(-1, 28 * 28)))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)

In [5]:
# Instantiate the classifier with its default hidden sizes and move it to the selected device
net = RichBoyNet().to(device)

Train the network only for 1 epoch to simulate a complete general pre-training on the data

In [6]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train `net` on `train_loader` with Adam (lr=0.001) and cross-entropy loss.

    Batches are moved to the notebook-level `device` and flattened to rows of
    28*28 values before the forward pass. The running average loss of the
    current epoch is shown in the progress bar.

    Args:
        train_loader: iterable yielding `(inputs, labels)` batches.
        net: the module to train; it is switched to training mode every epoch.
        epochs: maximum number of passes over `train_loader`.
        total_iterations_limit: optional cap on the total number of optimizer
            steps across all epochs. `None` means no cap. A value <= 0 performs
            no step at all.

    Returns:
        None.
    """
    # Function-scope import (stdlib) so this cell does not depend on the import cell being edited.
    from itertools import islice

    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        if total_iterations_limit is None:
            batches = train_loader
            bar_total = None  # let tqdm infer the length from the loader
        else:
            remaining = total_iterations_limit - total_iterations
            if remaining <= 0:
                # The step budget is already used up: stop without opening a new bar.
                return
            # Bound the iterator instead of returning from inside the loop, so the
            # progress bar is exhausted normally and ends at 100% (not one short).
            batches = islice(train_loader, remaining)
            try:
                bar_total = min(len(train_loader), remaining)
            except TypeError:
                # The loader does not define a length; the budget is the best estimate.
                bar_total = remaining

        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(batches, desc=f'Epoch {epoch + 1}', total=bar_total)
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

In [7]:
# "Pre-train" the network: a single epoch over the full MNIST training loader
train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:37<00:00, 160.74it/s, loss=0.239]


Keep a copy of the original weights (cloning them) so that later we can prove that fine-tuning with LoRA doesn't alter the original weights

In [8]:
# Snapshot every parameter (detached copy) under its name, to compare against after fine-tuning
original_weights = {
    param_name: tensor.clone().detach()
    for param_name, tensor in net.named_parameters()
}

Performance of the pretrained network. As we can see, the network performs poorly on the digit 9 (and equally poorly on the digit 6), so let's fine-tune it on the digit 9

In [9]:
def test():
    """Evaluate the notebook-level `net` on `test_loader` and print a report.

    Prints the overall accuracy (rounded to 3 decimals) followed by, for each
    digit 0-9, how many samples with that true label were misclassified.
    Gradients are disabled during evaluation. Reads the module-level names
    `net`, `test_loader` and `device`; returns None.
    """
    correct = 0
    total = 0

    # One counter per true label
    wrong_counts = [0] * 10

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Testing'):
            images = images.to(device)
            labels = labels.to(device)
            # Predicted class = index of the largest score in each row
            predictions = net(images.view(-1, 784)).argmax(dim=1)
            for predicted, label in zip(predictions.tolist(), labels.tolist()):
                total += 1
                if predicted == label:
                    correct += 1
                else:
                    wrong_counts[label] += 1

    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'Wrong counts for the digit {i}: {wrong_counts[i]}')

In [10]:
# Baseline: accuracy and per-digit error counts of the pre-trained network
test()

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 778.30it/s]

Accuracy: 0.952
Wrong counts for the digit 0: 28
Wrong counts for the digit 1: 12
Wrong counts for the digit 2: 49
Wrong counts for the digit 3: 68
Wrong counts for the digit 4: 33
Wrong counts for the digit 5: 18
Wrong counts for the digit 6: 103
Wrong counts for the digit 7: 44
Wrong counts for the digit 8: 24
Wrong counts for the digit 9: 103





Number of parameters in the original network, before introducing the LoRA matrices

In [11]:
# Count the weights and biases of the three linear layers before LoRA is attached
total_parameters_original = 0
for index, layer in enumerate((net.linear1, net.linear2, net.linear3)):
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
    layer_parameters = layer.weight.numel() + layer.bias.numel()
    total_parameters_original = total_parameters_original + layer_parameters
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


Define the LoRA parameterization as described in the paper.

- We apply LoRA only to the weight matrices, not to the bias vectors

In [12]:
class LoRAParametrization(nn.Module):
    """Low-rank additive update of a weight, for use with `torch.nn.utils.parametrize`.

    When enabled, the parametrized weight is

        W + (lora_B @ lora_A) * (alpha / rank)

    and when disabled it is `W` unchanged.

    NOTE on argument names: this notebook builds the object from
    `layer.weight.shape`, which for `nn.Linear` is `(out_features, in_features)`.
    So `features_in` is the number of ROWS of the weight and `features_out`
    the number of COLUMNS. The names are kept for backward compatibility.

    Args:
        features_in: first dimension of the weight (rows); `lora_B` is `(features_in, rank)`.
        features_out: second dimension of the weight (columns); `lora_A` is `(rank, features_out)`.
        rank: rank of the update; must be positive.
        alpha: scaling constant; the update is multiplied by `alpha / rank`.
        device: device on which the LoRA matrices are created.

    Raises:
        ValueError: if `rank` is not positive.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        if rank <= 0:
            # Fail early with a clear message instead of a ZeroDivisionError on alpha / rank
            raise ValueError(f'rank must be a positive integer, got {rank}')

        # We use a random Gaussian initialization for A and zero for B, so ΔW = BA is zero at the
        # beginning of training. The tensors are allocated directly on the target device.
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)  # Gaussian initialization for A

        # We then scale ΔWx by α/r, where α is a constant in r.
        # When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we
        # scale the initialization appropriately. As a result, we simply set α to the first r we
        # try and do not tune it. This scaling helps to reduce the need to retune hyperparameters
        # when we vary r.
        self.scale = alpha / rank
        # Toggle used to switch the low-rank update on and off without removing the parametrization
        self.enabled = True

    def forward(self, original_weights):
        """Return `W + (B @ A) * scale` when enabled, otherwise `original_weights` itself."""
        if not self.enabled:
            return original_weights
        # B @ A has shape (features_in, features_out); view it in the shape of the weight
        delta = torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
        return original_weights + delta * self.scale

Add the parametrization to the network

In [13]:
def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    """Build a `LoRAParametrization` sized for `layer.weight`.

    Only the weight gets a parametrization; the bias is not touched.

    Background (paraphrasing the LoRA paper): the authors limit their study to
    adapting only the attention weights for downstream tasks and freeze the
    MLP modules, so those are not trained in the downstream tasks.

    Args:
        layer: module whose 2-D `weight` determines the size of the LoRA matrices.
        device: device on which the LoRA matrices are created.
        rank: rank of the low-rank update.
        lora_alpha: scaling constant, forwarded as `alpha`.
    """
    # The two dimensions of the weight, passed on in the same order
    rows, cols = layer.weight.shape
    return LoRAParametrization(rows, cols, rank=rank, alpha=lora_alpha, device=device)

# Attach a LoRA parametrization to the weight of each linear layer, in order.
# From here on `layer.weight` is produced by LoRAParametrization.forward and the
# trained matrix is kept in `layer.parametrizations.weight.original`.
for linear in (net.linear1, net.linear2, net.linear3):
    parametrize.register_parametrization(
        linear,
        'weight',
        linear_layer_parametrization(linear, device),
    )

def enable_disable_lora(enabled=True):
    """Switch the LoRA update on or off for all three linear layers of the notebook-level `net`."""
    for linear in (net.linear1, net.linear2, net.linear3):
        # The first (and only) parametrization registered on the weight is the LoRA one
        lora = linear.parametrizations["weight"][0]
        lora.enabled = enabled

Parameters introduced by LoRA

In [14]:
# Tally the parameters added by LoRA and the pre-existing ones, layer by layer
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    lora = layer.parametrizations["weight"][0]
    total_parameters_lora += sum(matrix.nelement() for matrix in (lora.lora_A, lora.lora_B))
    total_parameters_non_lora += sum(tensor.nelement() for tensor in (layer.weight, layer.bias))
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

# Sanity check: attaching LoRA must not change the size of the original network
assert total_parameters_non_lora == total_parameters_original

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')

# Relative size of the LoRA matrices, as a percentage of the original parameter count
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


Freeze all the parameters of the original network and only fine-tune the ones introduced by LoRA. Then fine-tune the model on the digit 9, for only 100 batches

In [15]:
# Freeze every parameter that is not one of the LoRA matrices
for name, param in net.named_parameters():
    if 'lora' in name:
        continue
    print(f'Freezing non-LoRA paramter {name}')
    param.requires_grad = False

# Reload the MNIST training split and keep only the samples labelled 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
keep_mask = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[keep_mask]
mnist_trainset.targets = mnist_trainset.targets[keep_mask]

# Loader over the digit-9 subset only
train_loader = DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Fine-tune with LoRA on the digit 9 for at most 100 batches
train(train_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-LoRA paramter linear1.bias
Freezing non-LoRA paramter linear1.parametrizations.weight.original
Freezing non-LoRA paramter linear2.bias
Freezing non-LoRA paramter linear2.parametrizations.weight.original
Freezing non-LoRA paramter linear3.bias
Freezing non-LoRA paramter linear3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 130.87it/s, loss=0.135]


Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA

In [16]:
# Snapshot of linear1's weight taken before the LoRA parametrization was attached
original_weights['linear1.weight']

tensor([[ 0.0276,  0.0470, -0.0015,  ...,  0.0498,  0.0316,  0.0299],
        [ 0.0206,  0.0254,  0.0299,  ...,  0.0201,  0.0344,  0.0104],
        [ 0.0055,  0.0405, -0.0076,  ...,  0.0053,  0.0268,  0.0337],
        ...,
        [-0.0077,  0.0559,  0.0542,  ...,  0.0231,  0.0493, -0.0033],
        [ 0.0593,  0.0226,  0.0083,  ...,  0.0426,  0.0362,  0.0370],
        [ 0.0337,  0.0100,  0.0611,  ...,  0.0705,  0.0440,  0.0544]])

In [17]:
# The base weight of linear1 as held by the parametrization; compare with original_weights['linear1.weight']
net.linear1.parametrizations.weight.original

Parameter containing:
tensor([[ 0.0276,  0.0470, -0.0015,  ...,  0.0498,  0.0316,  0.0299],
        [ 0.0206,  0.0254,  0.0299,  ...,  0.0201,  0.0344,  0.0104],
        [ 0.0055,  0.0405, -0.0076,  ...,  0.0053,  0.0268,  0.0337],
        ...,
        [-0.0077,  0.0559,  0.0542,  ...,  0.0231,  0.0493, -0.0033],
        [ 0.0593,  0.0226,  0.0083,  ...,  0.0426,  0.0362,  0.0370],
        [ 0.0337,  0.0100,  0.0611,  ...,  0.0705,  0.0440,  0.0544]])

In [18]:
# The frozen base weights must still equal the snapshot taken before fine-tuning
for layer_name in ('linear1', 'linear2', 'linear3'):
    frozen_weight = getattr(net, layer_name).parametrizations.weight.original
    assert torch.all(frozen_weight == original_weights[f'{layer_name}.weight'])

enable_disable_lora(enabled=True)
# With LoRA enabled, linear1.weight is whatever the "forward" function of our LoRA parametrization returns.
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
lora_1 = net.linear1.parametrizations.weight[0]
expected_weight = net.linear1.parametrizations.weight.original + (lora_1.lora_B @ lora_1.lora_A) * lora_1.scale
assert torch.equal(net.linear1.weight, expected_weight)

enable_disable_lora(enabled=False)
# With LoRA disabled, linear1.weight is the original one again
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

Test the network with LoRA enabled

In [19]:
# Test with LoRA enabled: the fine-tuned low-rank update is added to every linear layer's weight
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 237.27it/s]

Accuracy: 0.94
Wrong counts for the digit 0: 29
Wrong counts for the digit 1: 17
Wrong counts for the digit 2: 81
Wrong counts for the digit 3: 100
Wrong counts for the digit 4: 86
Wrong counts for the digit 5: 28
Wrong counts for the digit 6: 127
Wrong counts for the digit 7: 75
Wrong counts for the digit 8: 39
Wrong counts for the digit 9: 18





Test the network with LoRA disabled

In [20]:
# Test with LoRA disabled: the layers use their original (frozen) weights again
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 738.60it/s]

Accuracy: 0.952
Wrong counts for the digit 0: 28
Wrong counts for the digit 1: 12
Wrong counts for the digit 2: 49
Wrong counts for the digit 3: 68
Wrong counts for the digit 4: 33
Wrong counts for the digit 5: 18
Wrong counts for the digit 6: 103
Wrong counts for the digit 7: 44
Wrong counts for the digit 8: 24
Wrong counts for the digit 9: 103



