# LoRA

### Notebook to help understand how LoRA works.

In [1]:
import torch
import numpy as np
torch.manual_seed(1)

<torch._C.Generator at 0x10cc8b390>

In [2]:
d, k = 10, 10

rank = 2
W = torch.randn(d,rank) @ torch.randn(rank,k)
print(W)

tensor([[-0.4902,  1.2622, -1.0716,  0.5020, -0.4107,  1.2846,  2.4297,  0.8886,
          0.6491, -2.9506],
        [-0.5303,  2.8276, -1.8984, -0.1688, -0.3176,  1.3831,  0.4896, -1.5201,
          2.5795, -3.3387],
        [ 0.2331, -0.3944,  0.4055, -0.3389,  0.2131, -0.6118, -1.4564, -0.7718,
         -0.0443,  1.3823],
        [ 0.0687, -1.4636,  0.8006,  0.5560, -0.0539, -0.1743,  1.5414,  2.0590,
         -1.7430,  0.5429],
        [ 0.3786, -2.0593,  1.3758,  0.1403,  0.2232, -0.9872, -0.2900,  1.1542,
         -1.8937,  2.3876],
        [ 0.3281,  0.2501,  0.1638, -0.8692,  0.3698, -0.8648, -3.2282, -2.4531,
          0.9715,  1.8651],
        [-0.1414,  1.0995, -0.6809, -0.2133, -0.0547,  0.3672, -0.3750, -0.9918,
          1.1315, -0.9249],
        [-0.1916,  0.1972, -0.2691,  0.3404, -0.1862,  0.5034,  1.3828,  0.8498,
         -0.1265, -1.1235],
        [-0.0755, -0.6325,  0.2530,  0.4799, -0.1349,  0.2015,  1.5838,  1.5402,
         -0.9618, -0.3713],
        [ 0.4815, -

Recall that SVD is defined as the following: $$ W = U \Sigma V^{T}$$ where $W$ is an $m x n$ matrix, $U$ is an $m x m$ matrix, $\Sigma$ is an m x n matrix and $V^{T}$ is an $n x n$ matrix.

In [3]:
U, S, V = torch.svd(W)

U_r = U[:, :rank]
S_r = torch.diag(S[:rank])
V_r = V[:, :rank].t()

In the LoRA paper, recall that we created matrices A (r x k) and B (d x r). We can create these matrices using the rank factorized matrices from the SVD we performed.

In [4]:
A = V_r
B = U_r @ S_r

print(f"A shape: {A.shape}")
print(f"B shape: {B.shape}")


A shape: torch.Size([2, 10])
B shape: torch.Size([10, 2])


We can test how accurate BA is by comparing how similar the results are when providing input to W and BA.

In [19]:
x = torch.randn(d)

# y = Wx
y = W @ x
# y' = (B*A)x
y_prime = (B @ A) @ x

print("Original y using W:\n", y)
print("")
print("y' computed using BA:\n", y_prime)

print("Total parameters of W: ", W.nelement())
print("Total parameters of B and A: ", B.nelement() + A.nelement())

Original y using W:
 tensor([-6.3192, -4.3742,  3.3514, -1.2804,  3.0542,  6.0737, -0.5844, -2.9684,
        -2.3654,  3.1847])

y' computed using BA:
 tensor([-6.3192, -4.3742,  3.3514, -1.2804,  3.0542,  6.0737, -0.5844, -2.9684,
        -2.3654,  3.1847])
Total parameters of W:  100
Total parameters of B and A:  40


This is the basis of why LoRA is such a powerful technique. We were able to replicate the original matrix's behavior with far fewer parameters. This is why LoRA is significantly more memory efficient than other techniques that were created for the purpose of fine-tuning models.

# LoRA in application

Let's apply LoRA to a simple neural network

In [2]:
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torchvision import datasets, transforms

In [3]:
# Download MNIST
transform = transforms.Compose([transforms.ToTensor()])

# Download and load the training data
trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.MNIST(root='../data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

images, labels = next(iter(trainloader))

print(f"Image batch shape: {images.shape}")
print(f"Label batch shape: {labels.shape}")

Image batch shape: torch.Size([64, 1, 28, 28])
Label batch shape: torch.Size([64])


Simple CNN with three convolution layers.

In [4]:
class Network(nn.Module):
    def __init__(self, input_channels):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128 * 4 * 4, 10)
        
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = self.fc(x)
        
        return x

In [5]:
net = Network(input_channels=1)

optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

epochs = 1
def train():
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            
            optimizer.zero_grad()
            
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(trainloader)}], Loss: {running_loss / 100:.4f}')
                running_loss = 0.0

train()

Epoch [1/1], Step [100/938], Loss: 0.8411
Epoch [1/1], Step [200/938], Loss: 0.2673
Epoch [1/1], Step [300/938], Loss: 0.1966
Epoch [1/1], Step [400/938], Loss: 0.1685
Epoch [1/1], Step [500/938], Loss: 0.1367
Epoch [1/1], Step [600/938], Loss: 0.1187
Epoch [1/1], Step [700/938], Loss: 0.1034
Epoch [1/1], Step [800/938], Loss: 0.1004
Epoch [1/1], Step [900/938], Loss: 0.0831


Let's consider this to be our "pre-trained" weights. We can then apply LoRA on these original weights by fine tuning on a specific number.

In [6]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()

print(f'Weights shape: {len(original_weights)}')

total_parameters_original = 0
for index, layer in enumerate([net.conv1, net.conv2, net.conv3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Weights shape: 8
Layer 1: W: torch.Size([32, 1, 3, 3]) + B: torch.Size([32])
Layer 2: W: torch.Size([64, 32, 3, 3]) + B: torch.Size([64])
Layer 3: W: torch.Size([128, 64, 3, 3]) + B: torch.Size([128])
Total number of parameters: 92,672


In [7]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        output = net(inputs)
        for idx, i in enumerate(output):
            if torch.argmax(i) == labels[idx]:
                correct +=1
            else:
                wrong_counts[labels[idx]] +=1
            total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Accuracy: 0.977
wrong counts for the digit 0: 96
wrong counts for the digit 1: 208
wrong counts for the digit 2: 232
wrong counts for the digit 3: 155
wrong counts for the digit 4: 126
wrong counts for the digit 5: 165
wrong counts for the digit 6: 52
wrong counts for the digit 7: 69
wrong counts for the digit 8: 92
wrong counts for the digit 9: 161


Since there is high number of wrong counts for the number 3, let's fine tune our model on 3. First, we will have to implement LoRA which we can apply to each layer using parametrizations. This implementation follows the original LoRA paper (Hu et. al.). Similar to the paper, I am only applying LoRA to the weight matrix of the convolution layers (excluding the bias). 

In [8]:
class LoRA(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1):
        super(LoRA, self).__init__()
        # A (r x k)
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)))
        # B (d x r)
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)))
        # Section 4.1 states that A is initialized to a random gaussian distribution while B is intialized to 0. Thus initially, âˆ†W = BA = 0.
        nn.init.normal_(self.lora_A, mean=0, std=1)
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # W + BA * scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

def conv_layer_parameterization(layer, rank=1, alpha=1):
    out_channels, in_channels, kernel_height, kernel_width = layer.weight.shape
    features_in = in_channels * kernel_height * kernel_width
    features_out = out_channels
    return LoRA(features_in, features_out, rank=rank, alpha=alpha)

parametrize.register_parametrization(
    net.conv1, "weight", conv_layer_parameterization(net.conv1)
)
parametrize.register_parametrization(
    net.conv2, "weight", conv_layer_parameterization(net.conv2)
)
parametrize.register_parametrization(
    net.conv3, "weight", conv_layer_parameterization(net.conv3)
)

def enable_disable_lora(enabled=True):
    for layer in [net.conv1, net.conv2, net.conv3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [9]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.conv1, net.conv2, net.conv3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

assert total_parameters_non_lora == total_parameters_original

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

Layer 1: W: torch.Size([32, 1, 3, 3]) + B: torch.Size([32]) + Lora_A: torch.Size([1, 32]) + Lora_B: torch.Size([9, 1])
Layer 2: W: torch.Size([64, 32, 3, 3]) + B: torch.Size([64]) + Lora_A: torch.Size([1, 64]) + Lora_B: torch.Size([288, 1])
Layer 3: W: torch.Size([128, 64, 3, 3]) + B: torch.Size([128]) + Lora_A: torch.Size([1, 128]) + Lora_B: torch.Size([576, 1])
Total number of parameters (original): 92,672
Total number of parameters (original + LoRA): 93,769
Parameters introduced by LoRA: 1,097
Parameters increment: 1.184%


If we did not have LoRA, we would have had double the number of parameters as the original as we would have needed a new weight value for each parameter. However, with this approach, we are only adding in a marginal fraction of the original number of parameters, 1.184% to be exact.

Now, we can fine-tune our model by training specifically on the data with label = 3.

In [10]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)

label_3_indices = trainset.targets == 3
label_3_data = trainset.data[label_3_indices]
label_3_targets = trainset.targets[label_3_indices]

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

train()

Freezing non-LoRA parameter conv1.bias
Freezing non-LoRA parameter conv1.parametrizations.weight.original
Freezing non-LoRA parameter conv2.bias
Freezing non-LoRA parameter conv2.parametrizations.weight.original
Freezing non-LoRA parameter conv3.bias
Freezing non-LoRA parameter conv3.parametrizations.weight.original
Freezing non-LoRA parameter fc.weight
Freezing non-LoRA parameter fc.bias
Epoch [1/1], Step [100/938], Loss: 0.0713
Epoch [1/1], Step [200/938], Loss: 0.0823
Epoch [1/1], Step [300/938], Loss: 0.0889
Epoch [1/1], Step [400/938], Loss: 0.0779
Epoch [1/1], Step [500/938], Loss: 0.0789
Epoch [1/1], Step [600/938], Loss: 0.0774
Epoch [1/1], Step [700/938], Loss: 0.0860
Epoch [1/1], Step [800/938], Loss: 0.0732
Epoch [1/1], Step [900/938], Loss: 0.0762


In [11]:
def print_layer_weights(layer, name):
    print(f"--- {name} Weights ---")
    print(layer.weight)
    if "weight" in layer.parametrizations:
        print(f"--- {name} LoRA A Weights ---")
        print(layer.parametrizations["weight"][0].lora_A)
        print(f"--- {name} LoRA B Weights ---")
        print(layer.parametrizations["weight"][0].lora_B)
    print("\n")


print_layer_weights(net.conv1, "Conv1")
enable_disable_lora(True)
print_layer_weights(net.conv1, "Conv1 with LoRA Enabled")
enable_disable_lora(False)
print_layer_weights(net.conv1, "Conv1 with LoRA Disabled")

--- Conv1 Weights ---
tensor([[[[-0.3201,  0.2351, -0.0387],
          [ 0.2227,  0.1490,  0.0261],
          [ 0.1478,  0.1012,  0.1695]]],


        [[[-0.1382,  0.0292,  0.0362],
          [ 0.0689,  0.0742,  0.3690],
          [ 0.0837, -0.1061, -0.1688]]],


        [[[-0.0426, -0.0773, -0.0327],
          [ 0.0630,  0.3050,  0.2703],
          [-0.2952,  0.2926,  0.1694]]],


        [[[ 0.3924,  0.3351, -0.2210],
          [-0.3755, -0.1339,  0.3505],
          [-0.1859,  0.0849, -0.1109]]],


        [[[ 0.3868, -0.0481,  0.3359],
          [ 0.0497, -0.0952,  0.2387],
          [-0.1587,  0.1403, -0.0474]]],


        [[[-0.0246, -0.3671, -0.1821],
          [ 0.2409, -0.0310,  0.3586],
          [ 0.3353,  0.0215, -0.2283]]],


        [[[ 0.2614,  0.1904, -0.1782],
          [ 0.2923,  0.3085,  0.3480],
          [-0.1351,  0.0121,  0.0247]]],


        [[[ 0.0540, -0.3455, -0.3605],
          [ 0.2408, -0.0068,  0.1748],
          [ 0.0752,  0.2101,  0.1340]]],


        [[

We can now test our fine-tuned network:

In [12]:
enable_disable_lora(enabled=True)
test()

Accuracy: 0.977
wrong counts for the digit 0: 96
wrong counts for the digit 1: 208
wrong counts for the digit 2: 232
wrong counts for the digit 3: 155
wrong counts for the digit 4: 126
wrong counts for the digit 5: 165
wrong counts for the digit 6: 52
wrong counts for the digit 7: 69
wrong counts for the digit 8: 92
wrong counts for the digit 9: 161


In [13]:
enable_disable_lora(enabled=False)
test()

Accuracy: 0.977
wrong counts for the digit 0: 96
wrong counts for the digit 1: 208
wrong counts for the digit 2: 232
wrong counts for the digit 3: 155
wrong counts for the digit 4: 126
wrong counts for the digit 5: 165
wrong counts for the digit 6: 52
wrong counts for the digit 7: 69
wrong counts for the digit 8: 92
wrong counts for the digit 9: 161
