# LoRA from Scratch
In this notebook, we implement LoRA fine-tuning on a pre-trained PyTorch model.

## Imports

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt

from tqdm import tqdm

Seed PyTorch's random number generator so that the results are reproducible

In [2]:
# Make torch deterministic
_ = torch.manual_seed(0)
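
As a quick illustration of what the seed does, the sketch below (not part of the original notebook) re-seeds the global generator and checks that the same random tensor comes out twice. Seeding alone does not guarantee bit-for-bit identical results on every device or for every CUDA kernel, so treat it as a best-effort measure.

```python
import torch

# Re-seeding the global generator replays the same random sequence.
torch.manual_seed(0)
first = torch.rand(3)

torch.manual_seed(0)
second = torch.rand(3)

print(torch.equal(first, second))  # True
```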

## Dataset
We will be training a network to classify FashionMNIST and then fine-tune the network on a particular class on which it doesn't perform well.

In [3]:
# Note: 0.1307 and 0.3081 are the mean and std commonly used for MNIST; they are reused here as-is
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the FashionMNIST dataset
fmnist_trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(fmnist_trainset, batch_size=16, shuffle=True)

# Load the FashionMNIST test set
fmnist_testset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(fmnist_testset, batch_size=16, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 11494270.20it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 208933.34it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3882853.15it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 17771421.39it/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



## Model Building
Create the neural network that classifies the fashion products, making it deliberately oversized to better show the power of LoRA.

In [4]:
# Create an overly expensive neural network to classify FashionMNIST
class HeavyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(HeavyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

model = HeavyNet().to(device)

Train the network for 40 epochs to simulate a complete, general pre-training on the data

In [5]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, model, epochs=40)

Epoch 1: 100%|██████████| 3750/3750 [00:38<00:00, 97.16it/s, loss=0.475] 
Epoch 2: 100%|██████████| 3750/3750 [00:43<00:00, 86.58it/s, loss=0.369]
Epoch 3: 100%|██████████| 3750/3750 [00:37<00:00, 99.86it/s, loss=0.333]
Epoch 4: 100%|██████████| 3750/3750 [00:38<00:00, 97.47it/s, loss=0.308] 
Epoch 5: 100%|██████████| 3750/3750 [00:37<00:00, 100.45it/s, loss=0.293]
Epoch 6: 100%|██████████| 3750/3750 [00:37<00:00, 100.55it/s, loss=0.282]
Epoch 7: 100%|██████████| 3750/3750 [00:36<00:00, 101.92it/s, loss=0.273]
Epoch 8: 100%|██████████| 3750/3750 [00:37<00:00, 101.04it/s, loss=0.26]
Epoch 9: 100%|██████████| 3750/3750 [00:37<00:00, 100.57it/s, loss=0.245]
Epoch 10: 100%|██████████| 3750/3750 [00:37<00:00, 101.27it/s, loss=0.243]
Epoch 11: 100%|██████████| 3750/3750 [00:36<00:00, 102.24it/s, loss=0.235]
Epoch 12: 100%|██████████| 3750/3750 [00:36<00:00, 101.93it/s, loss=0.228]
Epoch 13: 100%|██████████| 3750/3750 [00:36<00:00, 102.03it/s, loss=0.223]
Epoch 14: 100%|██████████| 3750/3750 

Keep a copy of the original weights (cloning them) so that later we can verify that fine-tuning with LoRA doesn't alter them

In [6]:
original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()
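
To see why the copy is made with `clone()`, here is a minimal sketch on a toy layer (the layer and variable names are illustrative, not from the notebook). A tensor obtained with `detach()` alone still shares storage with the parameter, so it would silently follow any later update, whereas a cloned tensor stays fixed.

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)

alias = layer.weight.detach()             # shares storage with the parameter
snapshot = layer.weight.clone().detach()  # independent copy

with torch.no_grad():
    layer.weight.add_(1.0)  # simulate an update to the parameter

print(torch.equal(alias, layer.weight))     # True: the alias followed the update
print(torch.equal(snapshot, layer.weight))  # False: the snapshot kept the old values
```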

In [7]:
print(original_weights)

{'linear1.weight': tensor([[ 0.1223,  0.1458,  0.0794,  ...,  0.1505,  0.1159,  0.0941],
        [ 0.2599,  0.2648,  0.2694,  ...,  0.1045,  0.1344,  0.2050],
        [-0.2794, -0.2387, -0.2705,  ..., -0.1876, -0.0399, -0.2510],
        ...,
        [ 0.0669,  0.1504,  0.0941,  ..., -0.0678, -0.0242,  0.0169],
        [ 0.2425,  0.2048,  0.1815,  ...,  0.0613,  0.1004,  0.1938],
        [ 0.2224,  0.1993,  0.3989,  ...,  0.3100,  0.4096,  0.4887]],
       device='cuda:0'), 'linear1.bias': tensor([-1.4711e-01, -2.5835e-01,  2.6955e-01, -3.0588e-01, -3.0351e-01,
        -2.1274e-01, -1.9532e-01, -2.3322e-01, -3.0170e-01, -3.2989e-01,
        -4.2280e-01, -2.9504e-01, -1.6234e-01, -3.6246e-01, -2.8369e-01,
        -2.7464e-01, -2.4800e-01, -2.6639e-01, -2.1724e-01, -2.7043e-01,
        -2.3653e-01, -2.8108e-01, -6.6564e-02, -3.5127e-01, -1.7921e-01,
        -2.3703e-01, -2.8320e-01, -2.1528e-01, -2.8924e-01, -2.8332e-01,
        -3.7025e-01, -3.4820e-01, -2.1242e-01, -2.1207e-01, -3.4194e

Test the performance of the pretrained network.

In [8]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the class {i}: {wrong_counts[i]}')

test()

Testing: 100%|██████████| 625/625 [00:03<00:00, 163.24it/s]

Accuracy: 0.886
wrong counts for the class 0: 156
wrong counts for the class 1: 17
wrong counts for the class 2: 158
wrong counts for the class 3: 132
wrong counts for the class 4: 200
wrong counts for the class 5: 22
wrong counts for the class 6: 348
wrong counts for the class 7: 40
wrong counts for the class 8: 37
wrong counts for the class 9: 34





As we can see, the network performs poorly on class 6. Let's fine-tune it on that class.
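
The counts printed by `test()` can be turned into per-class error rates with a few lines of arithmetic. This is only a sketch: it assumes the FashionMNIST test split is balanced at 1,000 images per class, which is not stated in the notebook itself.

```python
# Misclassification counts copied from the test output above.
wrong_counts = [156, 17, 158, 132, 200, 22, 348, 40, 37, 34]

# Assumption: the FashionMNIST test split is balanced, 1,000 images per class.
samples_per_class = 1000

error_rates = [count / samples_per_class for count in wrong_counts]
worst_class = max(range(len(error_rates)), key=error_rates.__getitem__)
accuracy = 1 - sum(wrong_counts) / (samples_per_class * len(wrong_counts))

print(f'Worst class: {worst_class} ({error_rates[worst_class]:.1%} errors)')
print(f'Overall accuracy: {accuracy:.3f}')
```

Under that assumption, class 6 is misclassified about a third of the time, and the overall accuracy matches the 0.886 reported above.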

Let's count how many parameters are in the original network before introducing the LoRA matrices.

In [9]:
# Print the size of the weights matrices of the network
# Save the count of the total number of parameters
total_parameters_original = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010
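
The total can be checked by hand: each linear layer contributes `in_features * out_features` weights plus `out_features` biases. The short sketch below redoes that arithmetic without PyTorch, using the layer sizes from `HeavyNet`.

```python
# (in_features, out_features) for the three linear layers of HeavyNet
layer_dims = [(28 * 28, 1000), (1000, 2000), (2000, 10)]

total = 0
for n_in, n_out in layer_dims:
    weights = n_in * n_out
    biases = n_out
    total += weights + biases
    print(f'{n_in} x {n_out} weights + {n_out} biases = {weights + biases:,}')

print(f'Total: {total:,}')  # Total: 2,807,010
```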


## Implementing LoRA
Define the LoRA parameterization as described in the paper.
The full details of how PyTorch parametrizations work are here: https://pytorch.org/tutorials/intermediate/parametrizations.html
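
For reference, the reparameterized weight can be written as follows, using the paper's notation ($W_0$ is the frozen pre-trained weight, $r$ the rank, and $\alpha$ the scaling constant). Only $A$ and $B$ are trained.

$$
W = W_0 + \frac{\alpha}{r}\, B A, \qquad
W_0 \in \mathbb{R}^{d \times k},\quad
B \in \mathbb{R}^{d \times r},\quad
A \in \mathbb{R}^{r \times k},\quad
r \ll \min(d, k)
$$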

In [10]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper:
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        #   We then scale ∆Wx by α/r , where α is a constant in r.
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately.
        #   As a result, we simply set α to the first r we try and do not tune it.
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
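
Two properties of this construction can be checked on small tensors. The sketch below uses illustrative sizes, not the notebook's. It shows that a zero-initialized `B` leaves the weight unchanged at the start of training, and that the update `B @ A` never exceeds the chosen rank once `B` becomes non-zero. The numerical rank is estimated from the singular values with an arbitrary relative threshold of `1e-5`.

```python
import torch

torch.manual_seed(0)
d, k, rank, alpha = 6, 4, 2, 2   # illustrative sizes, not the notebook's

W = torch.randn(d, k)            # stands in for a frozen pre-trained weight
A = torch.randn(rank, k)         # Gaussian initialization
B = torch.zeros(d, rank)         # zero initialization

# At initialization B @ A is all zeros, so the effective weight equals W.
print(torch.equal(W + (B @ A) * (alpha / rank), W))  # True

# Once B is non-zero, the update still has rank at most `rank`.
B = torch.randn(d, rank)
delta = (B @ A) * (alpha / rank)
singular_values = torch.linalg.svdvals(delta)  # sorted in descending order
numerical_rank = int((singular_values > 1e-5 * singular_values[0]).sum())
print(numerical_rank <= rank)
```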

Add the parameterization to our network.

In [11]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.

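    # Note: nn.Linear stores its weight as (out_features, in_features), so the two names
    # below are swapped relative to the layer's attributes. The shapes still work out:
    # lora_B is (rows of W, rank) and lora_A is (rank, columns of W).
    features_in, features_out = layer.weight.shape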
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

In [12]:
parametrize.register_parametrization(
    model.linear1, "weight", linear_layer_parameterization(model.linear1, device)
)
parametrize.register_parametrization(
    model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
    model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParametrization()
    )
  )
)
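
To make the mechanism concrete, here is a minimal sketch of `register_parametrization` on a toy layer. `AddOne` is a hypothetical parametrization made up for this illustration. After registration, the stored tensor moves to `parametrizations.weight.original`, and every access to `layer.weight` goes through the parametrization's `forward`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class AddOne(nn.Module):
    # Toy parametrization (hypothetical, for illustration): returns W + 1.
    def forward(self, W):
        return W + 1.0

layer = nn.Linear(3, 2)
before = layer.weight.clone().detach()

parametrize.register_parametrization(layer, "weight", AddOne())

# The stored tensor is untouched and now lives under `.original` ...
print(torch.equal(layer.parametrizations.weight.original, before))  # True
# ... while `layer.weight` is recomputed through forward() on each access.
print(torch.equal(layer.weight, before + 1.0))                      # True
```

This is the same pattern the LoRA parametrization relies on: the original weight stays stored as-is, and the low-rank update is applied on the fly.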

In [13]:
def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

Display the number of parameters added by LoRA.

In [14]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters increment: 0.242%
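
The number of LoRA parameters grows linearly with the rank: a weight of shape `(rows, cols)` gets `rank * (rows + cols)` extra parameters. The sketch below repeats the count for a few ranks. Ranks other than 1 are not used in this notebook and appear only for comparison.

```python
# Weight shapes (rows, cols) of the three linear layers, as printed above.
weight_shapes = [(1000, 784), (2000, 1000), (10, 2000)]
total_original = 2_807_010  # weights + biases of the original network

def lora_parameter_count(rank):
    # B has shape (rows, rank) and A has shape (rank, cols).
    return sum(rank * (rows + cols) for rows, cols in weight_shapes)

for rank in (1, 2, 4, 8):
    added = lora_parameter_count(rank)
    print(f'rank={rank}: {added:,} LoRA parameters ({added / total_original:.3%} of the original)')
```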


Freeze all the parameters of the original network and fine-tune only the ones introduced by LoRA. Then fine-tune the model on class 6 for only 100 batches.

In [15]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the FashionMNIST dataset again, keeping only class 6
fmnist_trainset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
# Boolean mask that is True for the samples we keep (label == 6)
keep_mask = fmnist_trainset.targets == 6
fmnist_trainset.data = fmnist_trainset.data[keep_mask]
fmnist_trainset.targets = fmnist_trainset.targets[keep_mask]
print(fmnist_trainset)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original
Dataset FashionMNIST
    Number of datapoints: 6000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
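
The `train` function hands every parameter to Adam, frozen or not. That still works because a parameter with `requires_grad = False` never receives a gradient, and the optimizer skips parameters without one. The sketch below checks this on a toy two-layer model, which is a stand-in and not the notebook's network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))  # toy stand-in model

# Freeze the first layer; only the second layer stays trainable.
for param in net[0].parameters():
    param.requires_grad = False
frozen_before = net[0].weight.clone().detach()
trainable_before = net[1].weight.clone().detach()

# As in the notebook, the optimizer receives all parameters.
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
loss = net(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

print(torch.equal(net[0].weight, frozen_before))     # True: no gradient, no update
print(torch.equal(net[1].weight, trainable_before))  # False: updated by Adam
```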


In [16]:
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(fmnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the class 6 and only for 100 batches (hoping that it would improve the performance on the class 6)
train(train_loader, model, epochs=1, total_iterations_limit=100)

Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 109.67it/s, loss=0.0828]


Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.

In [17]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(model.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(model.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(model.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to model.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(model.linear1.weight, model.linear1.parametrizations.weight.original + (model.linear1.parametrizations.weight[0].lora_B @ model.linear1.parametrizations.weight[0].lora_A) * model.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(model.linear1.weight, original_weights['linear1.weight'])

Test the network with LoRA enabled (class 6 should be classified better)

In [18]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 625/625 [00:03<00:00, 179.24it/s]

Accuracy: 0.718
wrong counts for the class 0: 862
wrong counts for the class 1: 26
wrong counts for the class 2: 559
wrong counts for the class 3: 518
wrong counts for the class 4: 646
wrong counts for the class 5: 16
wrong counts for the class 6: 46
wrong counts for the class 7: 36
wrong counts for the class 8: 76
wrong counts for the class 9: 37
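
Putting the two test runs side by side makes the trade-off explicit. The sketch below only does arithmetic on the counts printed by the notebook. It shows that errors on class 6 drop sharply, while several other classes get worse, which is why the overall accuracy falls from 0.886 to 0.718.

```python
# Misclassification counts copied from the two test runs (before / after LoRA).
wrong_before = [156, 17, 158, 132, 200, 22, 348, 40, 37, 34]
wrong_after = [862, 26, 559, 518, 646, 16, 46, 36, 76, 37]

for cls, (before, after) in enumerate(zip(wrong_before, wrong_after)):
    print(f'class {cls}: {before:>3} -> {after:>3} ({after - before:+d})')

total_change = sum(wrong_after) - sum(wrong_before)
print(f'Total errors: {sum(wrong_before)} -> {sum(wrong_after)} ({total_change:+d})')
```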





Test the network with LoRA disabled (the accuracy and error counts must be the same as for the original network)

In [19]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 625/625 [00:03<00:00, 198.34it/s]

Accuracy: 0.886
wrong counts for the class 0: 156
wrong counts for the class 1: 17
wrong counts for the class 2: 158
wrong counts for the class 3: 132
wrong counts for the class 4: 200
wrong counts for the class 5: 22
wrong counts for the class 6: 348
wrong counts for the class 7: 40
wrong counts for the class 8: 37
wrong counts for the class 9: 34



