# LoRA Introduction

References:
- https://youtu.be/DhRoTONcyZE?si=V9fBdboA9VJXw9Fx
- https://youtu.be/PXWYUTMt-AU?si=Nj8mpWCR20TQmnQ2


<img src="attachments/decomposed.png" width="1000">

Generate a rank-deficient matrix W

In [11]:
import torch
import numpy as np

_ = torch.manual_seed(42)

d, k = 10, 10 # W has shape (d, k): it maps a k-dim input to a d-dim output

W_rank = 2 # W_rank < d & W_rank < k
B, A = torch.rand(d, W_rank), torch.randn(W_rank, k)
W = B @ A
print("W shape: ", W.shape)
print("B & A shapes: ", B.shape, A.shape)

print("Total parameters of W: ", W.nelement())
print("Total parameters of B & A: ", B.nelement() + A.nelement())

W shape:  torch.Size([10, 10])
B & A shapes:  torch.Size([10, 2]) torch.Size([2, 10])
Total parameters of W:  100
Total parameters of B & A:  40
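The saving generalises: a d × k matrix stores dk numbers, while rank-r factors B (d × r) and A (r × k) store only r(d + k). A minimal sketch of that arithmetic (the helper `factorized_param_counts` is a hypothetical name used only for illustration):

```python
def factorized_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Return (params in a full d x k matrix, params in B (d x r) plus A (r x k))."""
    return d * k, r * (d + k)

# The 10 x 10, rank-2 example above
print(factorized_param_counts(10, 10, 2))  # (100, 40)

# A layer-sized example: a 2000 x 1000 weight adapted at rank 8
full, low_rank = factorized_param_counts(2000, 1000, 8)
print(full, low_rank, f"{low_rank / full:.2%}")
```

The factorization only pays off while r(d + k) < dk, i.e. for ranks well below min(d, k).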


### B & A are not unique, and W can be decomposed without knowing the original B & A

In [3]:
W_rank = np.linalg.matrix_rank(W)
print(W_rank)

2
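Under the hood, `matrix_rank` counts the singular values that are larger than a small tolerance. A NumPy sketch of that idea, assuming NumPy's documented default tolerance `S.max() * max(M, N) * eps` (the matrix here is a fresh random rank-2 matrix, not the `W` above):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 10, 10, 2
W = rng.random((d, r)) @ rng.standard_normal((r, k))  # rank 2 by construction

S = np.linalg.svd(W, compute_uv=False)                 # singular values, largest first
tol = S.max() * max(W.shape) * np.finfo(W.dtype).eps   # threshold for "numerically zero"
rank = int((S > tol).sum())

print(S.round(6))                      # only the first two are clearly non-zero
print(rank, np.linalg.matrix_rank(W))  # both report the same rank
```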


Compute the singular value decomposition (SVD) of W and keep only the top-r components to obtain new factors B' and A'.

In [14]:
def svd(W, rank):
    # Full SVD: W = U @ diag(S) @ V^T
    # (torch.svd returns V, not V^T; torch.linalg.svd returns V^T directly)
    U, S, V = torch.svd(W)

    # For a rank-r factorization, keep only the first r singular values
    # and the corresponding columns of U and V
    U_r = U[:, :rank]
    S_r = torch.diag(S[:rank])
    V_r = V[:, :rank].t()

    # B' = U_r @ S_r has shape (d, r); A' = V_r^T has shape (r, k)
    B_ = U_r @ S_r
    A_ = V_r
    return B_, A_

B_, A_ = svd(W, W_rank)
print(f"Shape of B': {B_.shape}")
print(f"Shape of A': {A_.shape}")

# The factors differ from the originals, but B @ A and B' @ A' are equal up to floating-point error
print("Original B: \n", B)
print("Decomposed B': \n", B_)

Shape of B': torch.Size([10, 2])
Shape of A': torch.Size([2, 10])
Original B: 
 tensor([[0.8823, 0.9150],
        [0.3829, 0.9593],
        [0.3904, 0.6009],
        [0.2566, 0.7936],
        [0.9408, 0.1332],
        [0.9346, 0.5936],
        [0.8694, 0.5677],
        [0.7411, 0.4294],
        [0.8854, 0.5739],
        [0.2666, 0.6274]])
Decomposed B': 
 tensor([[-4.6980, -0.2061],
        [-3.6697, -1.4674],
        [-2.6475, -0.5715],
        [-2.8966, -1.3529],
        [-2.5655,  1.8451],
        [-3.8867,  0.7025],
        [-3.6607,  0.6154],
        [-2.9623,  0.6582],
        [-3.7158,  0.6372],
        [-2.4377, -0.9225]])


In [5]:
# Generate random bias and input
bias = torch.randn(d)
x = torch.randn(d)

# Compute y = Wx + b
y = W @ x + bias

# Compute y' = (BA)x + b
y_prime = (B @ A) @ x + bias

print("Original y using W: \n", y)
print("")
print("y, computed using BA:\n", y_prime)

# Check that the SVD factors B' & A' reproduce W, without using the original B & A
print(f"Both tensors have the same content: {torch.allclose(W, B_ @ A_, atol=1e-4)}")

Original y using W: 
 tensor([6.1220, 4.5818, 3.2029, 3.8161, 4.6506, 6.4817, 4.2482, 3.4647, 5.5003,
        2.5157])

y, computed using BA:
 tensor([6.1220, 4.5818, 3.2029, 3.8161, 4.6506, 6.4817, 4.2482, 3.4647, 5.5003,
        2.5157])
Both tensors have the same content: True
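The recovered factors can also be used directly in the forward pass as y = B'(A'x) + b, without ever forming W, which costs r(d + k) instead of dk multiply-adds. A NumPy sketch on a fresh random rank-2 matrix; note that `np.linalg.svd` returns Vᵀ (like `torch.linalg.svd`), whereas `torch.svd` used above returns V:

```python
import numpy as np

rng = np.random.default_rng(2)
d, k, r = 10, 10, 2
W = rng.random((d, r)) @ rng.standard_normal((r, k))  # rank-r matrix
x, bias = rng.standard_normal(k), rng.standard_normal(d)

# Recover rank-r factors from W alone: W = U @ diag(S) @ Vh
U, S, Vh = np.linalg.svd(W)
B_ = U[:, :r] * S[:r]   # (d, r): each kept column of U scaled by its singular value
A_ = Vh[:r, :]          # (r, k)

y_full = W @ x + bias               # d*k multiply-adds
y_factored = B_ @ (A_ @ x) + bias   # r*k + d*r multiply-adds, W never used

print(np.allclose(y_full, y_factored))  # True
```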




# Use Case

### Training

1. Finetuning<br>
<img src="attachments/lora1.png" width=300><br>

a. Freeze the original weights W (of shape [d, k]).<br>
b. Initialise new LoRA weights B and A of shapes [d, r] and [r, k], where r < d and r < k.<br>
c. Train only the new LoRA weights on the downstream task.

2. Better checkpointing<br>
<img src="attachments/lora2.png" width=300><br>

a. Instead of saving an entirely new copy of the whole model for each checkpoint, we only save the LoRA modules.<br>
b. We can also train the model on multiple tasks in parallel, using a separate LoRA module per task.
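The fine-tuning steps above amount to the forward pass h = W₀x + (α/r)BAx, with only B and A trainable. A NumPy sketch following the initialization described in the LoRA paper (Gaussian A, zero B); the shapes and α are arbitrary choices for illustration, and the training loop itself is omitted:

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, alpha = 6, 4, 2, 2        # W0 is d x k; r < min(d, k)

W0 = rng.standard_normal((d, k))   # a. frozen pretrained weight
A = rng.standard_normal((r, k))    # b. trainable, Gaussian init
B = np.zeros((d, r))               # b. trainable, zero init, so B @ A == 0
scale = alpha / r

def lora_forward(x):
    # h = W0 x + (alpha / r) * B (A x); in training only A and B would be updated (c.)
    return W0 @ x + scale * (B @ (A @ x))

x = rng.standard_normal(k)
print(np.allclose(lora_forward(x), W0 @ x))  # True: the adapter is a no-op at init
```

Zero-initialising B means fine-tuning starts exactly from the pretrained model's behaviour.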

### Inference

1. Model switching<br>
<img src="attachments/lora3.png" width=800><br>

2. Model Specialisation<br>
<img src="attachments/lora4.png" width=300><br>

We can train a different LoRA module for each task and switch between specialised models by swapping LoRA modules on top of the same frozen base model.
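The LoRA paper describes switching by merging: fold an adapter into the base weight, W = W₀ + (α/r)BA, and to switch tasks subtract BA and add another pair B'A'. Because the adapter is merged, inference has no extra latency. A NumPy sketch with two hypothetical adapters (`task_a`, `task_b`, `merge`, `unmerge` are illustrative names):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, scale = 8, 8, 2, 1.0
W0 = rng.standard_normal((d, k))   # shared frozen base weight

# Two hypothetical task adapters, each a (B, A) pair
adapters = {
    "task_a": (rng.standard_normal((d, r)), rng.standard_normal((r, k))),
    "task_b": (rng.standard_normal((d, r)), rng.standard_normal((r, k))),
}

def merge(W, name):
    B, A = adapters[name]
    return W + scale * B @ A       # fold the adapter into the weight

def unmerge(W, name):
    B, A = adapters[name]
    return W - scale * B @ A       # remove it again

W = merge(W0, "task_a")                      # serve task A
W = merge(unmerge(W, "task_a"), "task_b")    # switch to task B
print(np.allclose(unmerge(W, "task_b"), W0))  # True: base weight recovered
```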

# LoRA Implementation

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

_ = torch.manual_seed(0)

Load the MNIST dataset

In [2]:
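try:
    # Project-local helper from the original notebook's repository
    from core.utils.device import get_device
    device = get_device()
except ImportError:
    # Fallback when that helper is not available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")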

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)



Create a classifier model (an unnecessarily large network, for demonstration purposes).

In [3]:
class MnistClassifier(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = MnistClassifier().to(device)

Train the network for only one epoch to simulate general pre-training on the full dataset.

In [4]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:51<00:00, 115.57it/s, loss=0.234]


Keep a copy of the original weights (by cloning them) so that we can later verify that fine-tuning with LoRA does not alter them.

In [5]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()

Test the performance of the pretrained network. As the output shows, the network performs worst on the digit 4 (76 misclassifications), so let's fine-tune it on the digit 4.

In [6]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 167.89it/s]

Accuracy: 0.956
wrong counts for the digit 0: 44
wrong counts for the digit 1: 51
wrong counts for the digit 2: 15
wrong counts for the digit 3: 46
wrong counts for the digit 4: 76
wrong counts for the digit 5: 23
wrong counts for the digit 6: 28
wrong counts for the digit 7: 56
wrong counts for the digit 8: 52
wrong counts for the digit 9: 45
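The per-digit loop in `test()` can also be vectorized: take the argmax over the logits, keep the labels of the misclassified samples, and count them with `bincount`. A NumPy sketch (`per_class_errors` is a hypothetical helper; in PyTorch the equivalents would be `output.argmax(dim=1)` and `torch.bincount(..., minlength=10)`):

```python
import numpy as np

def per_class_errors(logits, labels, num_classes=10):
    preds = logits.argmax(axis=1)       # predicted class per sample
    wrong = labels[preds != labels]     # true labels of the misclassified samples
    return np.bincount(wrong, minlength=num_classes)

logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = np.array([1, 1, 0])
print(per_class_errors(logits, labels, num_classes=2))  # [1 1]
```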





Let's count how many parameters the original network has, before introducing the LoRA matrices.

In [7]:
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010
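Each `nn.Linear(n_in, n_out)` holds an `n_out × n_in` weight matrix plus `n_out` biases, which reproduces the total above:

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out   # weight matrix + bias vector

layers = [(784, 1000), (1000, 2000), (2000, 10)]  # (in_features, out_features)
total = sum(linear_params(n_in, n_out) for n_in, n_out in layers)
print(f"{total:,}")  # 2,807,010
```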


Define the LoRA parametrization as described in the LoRA paper. Full details on how PyTorch parametrizations work are here: https://pytorch.org/tutorials/intermediate/parametrizations.html

In [8]:
class LoRAParametrization(nn.Module):
    # Argument order follows the nn.Linear weight layout: weight has shape (features_out, features_in)
    def __init__(self, features_out, features_in, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper:
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank, features_in), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_out, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        #   We then scale ∆Wx by α/r , where α is a constant in r.
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately.
        #   As a result, we simply set α to the first r we try and do not tune it.
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B @ A) * scale; B @ A already has the same shape as W
            return original_weights + torch.matmul(self.lora_B, self.lora_A) * self.scale
        else:
            return original_weights

Add the parameterization to our network.

In [9]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.

    # nn.Linear stores its weight as (out_features, in_features)
    features_out, features_in = layer.weight.shape
    return LoRAParametrization(
        features_out, features_in, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)


def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

Display the number of parameters added by LoRA.

In [10]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters increment: 0.242%
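Each adapted layer adds r · (in_features + out_features) parameters (A is r × in, B is out × r), so the overhead grows linearly with the rank. A quick sketch for a few ranks (this notebook only uses rank 1; the other ranks are illustrative):

```python
layers = [(784, 1000), (1000, 2000), (2000, 10)]  # (in_features, out_features)
base_params = 2_807_010                            # original network, from above

def lora_params(rank):
    # lora_A is (rank, in_features), lora_B is (out_features, rank)
    return sum(rank * (n_in + n_out) for n_in, n_out in layers)

for r in (1, 2, 4, 8):
    extra = lora_params(r)
    print(f"rank {r}: {extra:,} LoRA parameters ({extra / base_params:.3%} increase)")
```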


Freeze all the parameters of the original network and fine-tune only the ones introduced by LoRA, training on the digit 4 for only 100 batches.

In [11]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 4
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
keep_mask = mnist_trainset.targets == 4
mnist_trainset.data = mnist_trainset.data[keep_mask]
mnist_trainset.targets = mnist_trainset.targets[keep_mask]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 4 and only for 100 batches (hoping that it would improve the performance on the digit 4)
train(train_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:01<00:00, 50.31it/s, loss=0.168]
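Only the LoRA tensors changed during this fine-tuning, so (as in the "Better checkpointing" point above) a checkpoint only needs those. In this notebook the LoRA parameters carry `lora` in their names, so a name filter on the state dict would select them. A NumPy sketch of the size difference, using hypothetical parameter names that loosely mirror a PyTorch `state_dict` and saving to an in-memory buffer:

```python
import io
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical parameter names and shapes, loosely mirroring linear1 above
params = {
    "linear1.weight": rng.standard_normal((1000, 784)).astype(np.float32),
    "linear1.bias": np.zeros(1000, dtype=np.float32),
    "linear1.lora_A": rng.standard_normal((1, 784)).astype(np.float32),
    "linear1.lora_B": np.zeros((1000, 1), dtype=np.float32),
}

def save(arrays):
    buf = io.BytesIO()
    np.savez(buf, **arrays)   # serialize all arrays into one .npz archive
    return buf.getvalue()

full_ckpt = save(params)
lora_ckpt = save({name: t for name, t in params.items() if "lora" in name})

print(len(full_ckpt), len(lora_ckpt))  # the LoRA-only checkpoint is far smaller
```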


Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.

In [12]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
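Once training is finished, an adapter can also be folded into the weight permanently, so inference runs a plain `nn.Linear` again. A sketch using `torch.nn.utils.parametrize.remove_parametrizations` with `leave_parametrized=True` on a toy layer; `AddLowRank` is a hypothetical, simplified LoRA-style parametrization (the α/r scale is omitted, and B is non-zero so the change is visible). Applied to the notebook's `net`, this would discard the LoRA parameters and the ability to toggle them:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class AddLowRank(nn.Module):
    """Simplified LoRA-style parametrization: W -> W + B @ A."""
    def __init__(self, out_f, in_f, r=1):
        super().__init__()
        self.A = nn.Parameter(torch.randn(r, in_f))
        self.B = nn.Parameter(torch.randn(out_f, r))

    def forward(self, W):
        return W + self.B @ self.A

torch.manual_seed(0)
layer = nn.Linear(4, 3)
parametrize.register_parametrization(layer, "weight", AddLowRank(3, 4))
merged = layer.weight.detach().clone()   # W0 + B @ A, computed on access

# Bake the current parametrized value into a plain parameter and drop A and B
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
print(parametrize.is_parametrized(layer))          # False
print(torch.equal(layer.weight.detach(), merged))  # True
```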

Test the network with LoRA enabled (the digit 4 should be classified better)

In [13]:
# Test with LoRA enabled (note: the fine-tuning made the digit 9 noticeably worse)
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:06<00:00, 162.00it/s]

Accuracy: 0.956
wrong counts for the digit 0: 46
wrong counts for the digit 1: 38
wrong counts for the digit 2: 15
wrong counts for the digit 3: 52
wrong counts for the digit 4: 14
wrong counts for the digit 5: 26
wrong counts for the digit 6: 21
wrong counts for the digit 7: 56
wrong counts for the digit 8: 47
wrong counts for the digit 9: 129





Test the network with LoRA disabled (the accuracy and error counts must be the same as the original network)

In [14]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 188.03it/s]

Accuracy: 0.956
wrong counts for the digit 0: 44
wrong counts for the digit 1: 51
wrong counts for the digit 2: 15
wrong counts for the digit 3: 46
wrong counts for the digit 4: 76
wrong counts for the digit 5: 23
wrong counts for the digit 6: 28
wrong counts for the digit 7: 56
wrong counts for the digit 8: 52
wrong counts for the digit 9: 45
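Comparing the two runs digit by digit makes the trade-off explicit. The numbers below are copied from the outputs above (LoRA disabled vs. enabled):

```python
before = [44, 51, 15, 46, 76, 23, 28, 56, 52, 45]   # LoRA disabled (original network)
after = [46, 38, 15, 52, 14, 26, 21, 56, 47, 129]   # LoRA enabled (fine-tuned on 4)

for digit, (b, a) in enumerate(zip(before, after)):
    print(f"digit {digit}: {b:>3} -> {a:>3} ({a - b:+d})")
print(sum(before), sum(after))  # 436 444 errors out of 10,000 test images
```

Errors on the digit 4 drop from 76 to 14, but errors on the digit 9 rise from 45 to 129, so the overall accuracy stays at 0.956. Fine-tuning on a single class with no other digits plausibly biases the model towards predicting 4, which would hurt a visually similar digit such as 9 the most.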



