In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Setting up LoRA class which adds two trainable matrices on top of the layers's original weights W
class LoRA(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device=device):
        super().__init__()

        self.mat_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.mat_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.mat_A, mean=0, std=1)

        self.scale = alpha / rank


    def forward(self, W):
        return W + torch.matmul(self.mat_B, self.mat_A).view(W.shape) * self.scale


In [4]:
#This function takes the layer as the input and sets the features_in,features_out
#equal to the shape of the weight matrix. This will help the LoRA class to
#initialize the A and B Matrices

def layer_parametrization(layer, device, rank = 1, lora_alpha = 1):
  features_in, features_out = layer.weight.shape
  return LoRA(features_in, features_out, rank = rank, alpha = lora_alpha, device = device)

In [5]:
_ = torch.manual_seed(42)

In [6]:
#Make a transform pipeline so that we can make the data training ready
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1301,),(0.3081,))])

mnist_train = datasets.MNIST(root = './data', train = True, download = True, transform = transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size = 10, shuffle = True)

mnist_test = datasets.MNIST(root = './data', train = False, download = True, transform = transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 10, shuffle = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 105164975.90it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 40418983.59it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 29321273.08it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6088376.08it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [7]:
#creating a demanding model so that it has more parameters
#1000 neurons for first layer and 2000 layers for the second

class exp_clf(nn.Module):

  def __init__(self, neurons_1 = 1000, neurons_2 = 2000):
    super(exp_clf, self).__init__()
    self.linear1 = nn.Linear(28*28, neurons_1)
    self.linear2 = nn.Linear(neurons_1, neurons_2)
    self.linear3 = nn.Linear(neurons_2, 10)
    self.relu = nn.ReLU()

  def forward(self, img):
    x = img.view(-1, 28*28)
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    return x

exp = exp_clf().to(device)

In [8]:
#let's create a training loop

def train(train_loader, mod, epochs = 5, total_iterations_limits = None):
  cel = nn.CrossEntropyLoss()
  optim = torch.optim.Adam(mod.parameters(), lr=0.001)

  total_iterations = 0

  for epoch in range(epochs):
    mod.train()

    loss_sum = 0
    num_iteration = 0

    data_iterator = tqdm(train_loader, desc = f'Number of Epoch:{epoch+1}')
    if total_iterations_limits is not None:
      data_iterator.total = total_iterations_limits
    for data in data_iterator:
      num_iteration += 1
      total_iterations += 1
      x, y = data
      x = x.to(device)
      y = y.to(device)
      optim.zero_grad()
      output = mod(x.view(-1, 28*28))
      loss = cel(output, y)
      loss_sum += loss.item()
      avg_loss = loss_sum / num_iteration
      data_iterator.set_postfix(loss = avg_loss)
      loss.backward()
      optim.step()

      if total_iterations_limits is not None and total_iterations >= total_iterations_limits:
        return

In [9]:
train(train_loader, exp, epochs = 1)

Number of Epoch:1: 100%|██████████| 6000/6000 [00:46<00:00, 129.25it/s, loss=0.237]


In [10]:
original_weights = {}
for name, params in exp.named_parameters():
  original_weights[name] = params.clone().detach()

In [11]:
# testing the model and looking out for the wrongly identified cases
def test():
  correct = 0
  total = 0

  wrong_counts = [0 for i in range(10)]

  with torch.no_grad():
    for data in tqdm(test_loader, desc = 'testing'):
      x, y = data
      x = x.to(device)
      y = y.to(device)
      output = exp(x.view(-1, 784))
      for idx, i in enumerate(output):
        if torch.argmax(i) == y[idx]:
          correct += 1
        else:
          wrong_counts[y[idx]] += 1
        total +=1

  print(f'\nAccuracy: {round(correct/total, 3)}')
  for i in range(len(wrong_counts)):
    print(f'wrong counts for the digit {i}: {wrong_counts[i]}')
test()

testing: 100%|██████████| 1000/1000 [00:03<00:00, 256.86it/s]


Accuracy: 0.953
wrong counts for the digit 0: 49
wrong counts for the digit 1: 5
wrong counts for the digit 2: 43
wrong counts for the digit 3: 29
wrong counts for the digit 4: 42
wrong counts for the digit 5: 42
wrong counts for the digit 6: 96
wrong counts for the digit 7: 67
wrong counts for the digit 8: 50
wrong counts for the digit 9: 48





In [12]:
total_params = 0

for idx, layer in enumerate([exp.linear1, exp.linear2, exp.linear3]):
  total_params += layer.weight.nelement() + layer.bias.nelement()
  print(f'Layer {idx+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')

print('The total trainable parameters of our model are:', total_params)

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
The total trainable parameters of our model are: 2807010


In [13]:
#Parametrizing such as when the model calls for weights, it gets weights plus the two trainable matrices introduced.

import torch.nn.utils.parametrize as parametrize

parametrize.register_parametrization(exp.linear1, 'weight', layer_parametrization(exp.linear1, device))
parametrize.register_parametrization(exp.linear2, 'weight', layer_parametrization(exp.linear2, device))
parametrize.register_parametrization(exp.linear3, 'weight', layer_parametrization(exp.linear3, device))


ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRA()
    )
  )
)

In [14]:
total_parameters_lora = 0
total_parameters_non_lora = 0

for index, layer in enumerate([exp.linear1, exp.linear2, exp.linear3]):
  total_parameters_lora += layer.parametrizations['weight'][0].mat_A.nelement() + layer.parametrizations['weight'][0].mat_B.nelement()
  total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
  print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + mat_A: {layer.parametrizations["weight"][0].mat_A.shape} + mat_B: {layer.parametrizations["weight"][0].mat_B.shape}'
    )

assert total_parameters_non_lora == total_params
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + mat_A: torch.Size([1, 784]) + mat_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + mat_A: torch.Size([1, 1000]) + mat_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + mat_A: torch.Size([1, 2000]) + mat_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [15]:
for name, param in exp.named_parameters():
    if 'mat' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 7
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 7
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

for layer in [exp.linear1, exp.linear2, exp.linear3]:
  layer.parametrizations["weight"][0].requires_grad = True

# Train the network with LoRA only on the digit 7 and only for 100 batches (hoping that it would improve the performance on the digit 7)
train(train_loader, exp, epochs=1, total_iterations_limits=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Number of Epoch:1:  99%|█████████▉| 99/100 [00:00<00:00, 124.36it/s, loss=0.0307]


In [16]:
test()

testing: 100%|██████████| 1000/1000 [00:03<00:00, 280.70it/s]


Accuracy: 0.893
wrong counts for the digit 0: 52
wrong counts for the digit 1: 247
wrong counts for the digit 2: 144
wrong counts for the digit 3: 55
wrong counts for the digit 4: 66
wrong counts for the digit 5: 55
wrong counts for the digit 6: 111
wrong counts for the digit 7: 1
wrong counts for the digit 8: 137
wrong counts for the digit 9: 205





In [17]:
# The wrong counts for 7 has reduced to a great extent