In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# make th torch deterministic
_ = torch.manual_seed(0)

In [3]:
import torch.utils


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081))])

# Load mnist dataset
mnist_train = datasets.MNIST(root="./data", train=True, download=TimeoutError, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10, shuffle=True)

mnist_test= datasets.MNIST(root="./data", download=True, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=10, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Create an overly expensive NN to classify MNMIST dataset
class RichBoyNet(nn.Module):
    def __init__(self, hidden_state_size1=1000, hidden_state_size2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_state_size1)
        self.linear2 = nn.Linear(hidden_state_size1, hidden_state_size2)
        self.linear3 = nn.Linear(hidden_state_size2, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
    

net = RichBoyNet().to(device)


In [5]:
# train the newtwork for only 1 epoch to simulate a complete general pre-training on the data

def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0
    for epoch in range(epochs):
        net.train()
        loss_sum = 0
        num_iterations = 0
        data_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1

            x, y = data
            x = x.to(device)
            y = y.to(device)

            # forward pass
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum/num_iterations
            data_iterator.set_postfix(loss=avg_loss)

            # backward pass
            loss.backward()
            optimizer.step()
            

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return


train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [04:57<00:00, 20.15it/s, loss=0.238]


In [6]:
# Keep the original weights (cloning them), so later we can prove that finetuning 
# with LoRA doesn't alter the oginal weights
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()


- The performance of the pretrained network

In [7]:
def test():
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    print(f"Accuracy: {round(correct/total, 3)}")
    for i in range(len(wrong_counts)):
        print(f"Wrong counts for the digit {i}: {wrong_counts[i]}")

test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 219.14it/s]

Accuracy: 0.961
Wrong counts for the digit 0: 10
Wrong counts for the digit 1: 18
Wrong counts for the digit 2: 41
Wrong counts for the digit 3: 83
Wrong counts for the digit 4: 28
Wrong counts for the digit 5: 23
Wrong counts for the digit 6: 60
Wrong counts for the digit 7: 34
Wrong counts for the digit 8: 17
Wrong counts for the digit 9: 79





- Visualize how many parameters are in the Original network, before introducing the LoRA matrices

In [8]:
# print the size of the wrights metrices of the network
# save the count of the total numer of parameters
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f"Layer {index +1}: W {layer.weight.shape} + B: {layer.bias.shape}")

print(f"Toal number of parameters: {total_parameters_original:,}")

Layer 1: W torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W torch.Size([10, 2000]) + B: torch.Size([10])
Toal number of parameters: 2,807,010


##### Define th LoRA parametrization

In [9]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # section 4.1 of the paper
        # Use random Gaussian initialization for A and Zero for B, so changeW = BA is zero at the beginning
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # We then scale changeW by alpha/rank, wher alpha is a constant in r. 
        # when optimizing with Adam, tuning alpha is roughly the same as tuning the learning rate if we scale
        # the initlization appropriately. 
        # As a result, we simply set alpha to the first r we try and do not tune it
        # THis scaling helps to reduce the need to retune hyperparameters when we try r. 
        self.scale = alpha/rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights


- Add the parametrization to out network

In [10]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    # only add the parametrization to the weight matrix, ignore the bias
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(features_in=features_in, features_out=features_out, rank=rank, alpha=lora_alpha, device=device)

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parametrization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parametrization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parametrization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

- Number of parameters added by LoRA

In [11]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations['weight'][0].lora_A.nelement() + layer.parametrizations['weight'][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(f"Layer {index+1} W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations['weight'][0].lora_A.shape} + Lora_B: {layer.parametrizations['weight'][0].lora_B.shape}")


# The non-LoRA parameters count should match the original network
assert total_parameters_non_lora == total_parameters_original
print(f"Total number of parameters (Original): {total_parameters_non_lora:,}")
print(f"Total parameters (original + lora): {total_parameters_lora + total_parameters_original:,}")
print(f"Parameters introduced by Lora: {total_parameters_lora:,}")
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f"Pramteres increment: {parameters_increment:.3f}")

Layer 1 W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2 W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3 W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (Original): 2,807,010
Total parameters (original + lora): 2,813,804
Parameters introduced by Lora: 6,794
Pramteres increment: 0.242


#### Freeze all the paramters of the original network and only FInetune the ones introduced by LoRA

In [12]:
# Freeze the non-lora paramters

for name, param in net.named_parameters():
    if 'lora' not in name: 
        print(f"Freezing non-LoRA parameter: {name}")
        param.requires_grad = False

# Load the Mnist dataset again, by keeping only digit 9
mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
exclude_indeces = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indeces]
mnist_trainset.targets = mnist_trainset.targets[exclude_indeces]

# dataloader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

Freezing non-LoRA parameter: linear1.bias
Freezing non-LoRA parameter: linear1.parametrizations.weight.original
Freezing non-LoRA parameter: linear2.bias
Freezing non-LoRA parameter: linear2.parametrizations.weight.original
Freezing non-LoRA parameter: linear3.bias
Freezing non-LoRA parameter: linear3.parametrizations.weight.original


In [13]:
# train
train(train_loader, net, epochs=1, total_iterations_limit=100)

Epoch 1:  99%|█████████▉| 99/100 [00:02<00:00, 42.72it/s, loss=0.0749]


#### Verify that fine-tuning didn't alter the original weights, but only the ones introduced by LoRA

In [14]:
# Check that frozen parameters are still unchanged
assert torch.all(net.linear1.parametrizations['weight'].original == original_weights['linear1.weight']), "Assertion on first line"
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)

'''
The New linear1.weight is obtained by the forward function of our LoRA parametrizations, 
The original weights have been moved to net.linear1.parametrizations.weight.original
'''
assert torch.equal(
    net.linear1.weight, net.linear1.parametrizations.weight.original +
    (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * 
    net.linear1.parametrizations.weight[0].scale), "Is not equal"
enable_disable_lora(enabled=False)
# if we disable LoRA, the linear1.weight is hte original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])               

In [15]:
# Add this debugging code before the assertion
print("Shapes:")
print(f"Stored original weights shape: {original_weights['linear1.weight'].shape}")
print(f"Current original weights shape: {net.linear1.parametrizations['weight'].original.shape}")

print("\nFirst few values:")
print("Stored original:")
print(original_weights['linear1.weight'][:5, :5])
print("\nCurrent original:")
print(net.linear1.parametrizations['weight'].original[:5, :5])

print("\nMax absolute difference:")
print(torch.max(torch.abs(net.linear1.parametrizations['weight'].original - original_weights['linear1.weight'])))

# Then try the assertion
assert torch.all(net.linear1.parametrizations['weight'].original == original_weights['linear1.weight'])

Shapes:
Stored original weights shape: torch.Size([1000, 784])
Current original weights shape: torch.Size([1000, 784])

First few values:
Stored original:
tensor([[0.0300, 0.0495, 0.0009, 0.0040, 0.0166],
        [0.0242, 0.0290, 0.0336, 0.0092, 0.0459],
        [0.0357, 0.0707, 0.0226, 0.0355, 0.0826],
        [0.0587, 0.0097, 0.0444, 0.0334, 0.0658],
        [0.0460, 0.0355, 0.0283, 0.0452, 0.0198]])

Current original:
tensor([[0.0300, 0.0495, 0.0009, 0.0040, 0.0166],
        [0.0242, 0.0290, 0.0336, 0.0092, 0.0459],
        [0.0357, 0.0707, 0.0226, 0.0355, 0.0826],
        [0.0587, 0.0097, 0.0444, 0.0334, 0.0658],
        [0.0460, 0.0355, 0.0283, 0.0452, 0.0198]])

Max absolute difference:
tensor(0.)


- Test the network with LoRA, enabled (digit 9 should be classified better)

In [16]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:09<00:00, 100.38it/s]

Accuracy: 0.943
Wrong counts for the digit 0: 14
Wrong counts for the digit 1: 22
Wrong counts for the digit 2: 58
Wrong counts for the digit 3: 109
Wrong counts for the digit 4: 98
Wrong counts for the digit 5: 50
Wrong counts for the digit 6: 88
Wrong counts for the digit 7: 78
Wrong counts for the digit 8: 37
Wrong counts for the digit 9: 14





- Test network with LoRA disabled (the accuracy and error counts should be same as the original netwowrk)

In [17]:
# Test without LORA
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:05<00:00, 193.03it/s]

Accuracy: 0.961
Wrong counts for the digit 0: 10
Wrong counts for the digit 1: 18
Wrong counts for the digit 2: 41
Wrong counts for the digit 3: 83
Wrong counts for the digit 4: 28
Wrong counts for the digit 5: 23
Wrong counts for the digit 6: 60
Wrong counts for the digit 7: 34
Wrong counts for the digit 8: 17
Wrong counts for the digit 9: 79



