In [49]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [50]:
# Fix the global PyTorch seed so the run is repeatable; the return value is
# bound to `_` so the cell does not display it.
_ = torch.manual_seed(0)

In [51]:
# Load the MNIST train/test splits (downloaded into ./data when missing) and
# wrap each one in a shuffling DataLoader that yields batches of 100.
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=100,
    shuffle=True,
)
test_data = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=100,
    shuffle=True,
)


In [52]:
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [53]:
class SampleNet(nn.Module):
  """Fully connected classifier: 28*28 inputs -> two ReLU hidden layers -> 10 outputs.

  Args:
    hidden_size1: Number of units in the first hidden layer.
    hidden_size2: Number of units in the second hidden layer.
  """
  def __init__(self,hidden_size1=1000,hidden_size2=2000):
    super().__init__()
    input_size=28*28
    # Layer attribute names are part of the contract: other cells address
    # them as model.linear1 / linear2 / linear3.
    self.linear1=nn.Linear(input_size,hidden_size1)
    self.linear2=nn.Linear(hidden_size1,hidden_size2)
    self.linear3=nn.Linear(hidden_size2,10)
    self.relu=nn.ReLU()

  def forward(self,img):
    """Flatten ``img`` to rows of 28*28 values and return the raw 10-way scores."""
    flat=img.view(-1,28*28)
    hidden=self.relu(self.linear1(flat))
    hidden=self.relu(self.linear2(hidden))
    # No softmax here: the final linear layer's output is returned as-is.
    return self.linear3(hidden)
# Build the network with its default hidden sizes and move it to `device`.
model=SampleNet().to(device)

In [54]:
#Training the network to complete pretraining process
def train(train_loader,net,epochs=5,total_iterations_limit=None):
  """Train ``net`` with Adam (lr=0.001) and cross-entropy loss.

  Args:
    train_loader: Iterable yielding ``(images, labels)`` batches.
    net: Module to optimise. It receives each image batch flattened to
      ``(-1, 28*28)``.
    epochs: Maximum number of passes over ``train_loader``.
    total_iterations_limit: Optional cap on the number of batches processed
      across ALL epochs. ``None`` disables the cap; a value <= 0 trains nothing.

  Batches are moved to the module-level ``device`` before the forward pass.
  The running average loss of the current epoch is shown on the progress bar.
  """
  if total_iterations_limit is not None and total_iterations_limit<=0:
    return
  cross_en=nn.CrossEntropyLoss()
  optimizer=torch.optim.Adam(net.parameters(),lr=0.001)
  total_iterations=0
  for epoch in range(epochs):
    # Operate on the network that was passed in, not on a global model.
    net.train()
    loss_sum=0
    num_iterations=0
    data_iterator=tqdm(train_loader,desc=f'Epoch {epoch+1}')
    if total_iterations_limit is not None:
      # Only shrink the bar when a cap is requested; otherwise keep the
      # total tqdm derived from the loader itself.
      remaining=total_iterations_limit-total_iterations
      if data_iterator.total is None:
        data_iterator.total=remaining
      else:
        data_iterator.total=min(data_iterator.total,remaining)
    for data in data_iterator:
      num_iterations+=1
      total_iterations+=1
      x,y=data
      x=x.to(device)
      y=y.to(device)
      optimizer.zero_grad()
      output=net(x.view(-1,28*28))
      loss=cross_en(output,y)
      loss_sum+=loss.item()
      avg_loss=loss_sum/num_iterations
      data_iterator.set_postfix({'loss':f'{avg_loss:.2f}'})
      loss.backward()
      optimizer.step()
      if total_iterations_limit is not None and total_iterations>=total_iterations_limit:
        # Stop training entirely: a plain `break` would only leave this
        # epoch and let every later epoch run one more batch past the cap.
        return

# Pretrain the network for a single epoch on the training loader.
train(train_loader,model,epochs=1)

Epoch 1: 600it [00:57, 10.41it/s, loss=0.20]


In [55]:
#Snapshot every parameter of the pretrained model as a detached clone,
#keyed by parameter name, so later cells can compare against it
orignal_weights={name:param.clone().detach() for name,param in model.named_parameters()}


In [56]:
def test():
  """Evaluate the module-level ``model`` on ``test_loader`` and print a report.

  Prints the overall accuracy (rounded to 3 decimals) followed by, for each of
  the 10 classes, how many samples with that true label were misclassified.
  The model is switched to eval mode for the duration of the pass and its
  previous train/eval mode is restored afterwards.
  """
  correct=0
  total=0
  wrong_counts=torch.zeros(10,dtype=torch.long)
  was_training=model.training
  model.eval()
  with torch.no_grad():
    for data in tqdm(test_loader,desc='Testing'):
      x,y=data
      x=x.to(device)
      y=y.to(device)
      output=model(x.view(-1,28*28))
      # Score the whole batch at once instead of comparing one 0-dim tensor
      # per sample, which forces a host/device sync for every image.
      predicted=output.argmax(dim=1)
      missed=predicted!=y
      correct+=int((~missed).sum())
      total+=y.numel()
      # Tally the true labels of the misclassified samples on the CPU.
      wrong_counts+=torch.bincount(y[missed].cpu(),minlength=10)
  model.train(was_training)
  # Plain ints so the report prints "13" rather than "tensor(13)".
  wrong_counts=wrong_counts.tolist()
  print(f'Accuracy: {round(correct/total,3)}')
  for i in range(len(wrong_counts)):
    print(f'Wrong counts for class {i}: {wrong_counts[i]}')
# Evaluate the pretrained network on the test loader.
test()

Testing: 100%|██████████| 100/100 [00:03<00:00, 26.88it/s]

Accuracy: 0.97
Wrong counts for class 0: 13
Wrong counts for class 1: 13
Wrong counts for class 2: 34
Wrong counts for class 3: 51
Wrong counts for class 4: 30
Wrong counts for class 5: 13
Wrong counts for class 6: 16
Wrong counts for class 7: 43
Wrong counts for class 8: 47
Wrong counts for class 9: 40





In [57]:
#Class 8 doesn't perform that well
#Report each linear layer's weight/bias shapes and the overall parameter count
total_parameter=0
for index,layer in enumerate((model.linear1,model.linear2,model.linear3)):
  print(f'Layer{index+1}:W:{layer.weight.shape}+B:{layer.bias.shape}')
  weight_count=layer.weight.numel()
  bias_count=layer.bias.numel()
  total_parameter+=weight_count+bias_count
print(f'Total number of parameters: {total_parameter}')

Layer1:W:torch.Size([1000, 784])+B:torch.Size([1000])
Layer2:W:torch.Size([2000, 1000])+B:torch.Size([2000])
Layer3:W:torch.Size([10, 2000])+B:torch.Size([10])
Total number of parameters: 2807010


In [58]:
class LoRaParameterized(nn.Module):
  """Parametrization that adds a scaled low-rank product to a weight tensor.

  Two trainable factors are created: ``lora_A`` with shape
  ``(rank, features_out)``, drawn from a normal distribution (mean 0, std 1),
  and ``lora_B`` with shape ``(features_in, rank)``, left at zero. Because
  ``lora_B`` starts at zero, the product ``lora_B @ lora_A`` is zero until
  training updates it.

  Args:
    features_in: Row count of ``lora_B`` (and of the resulting product).
    features_out: Column count of ``lora_A`` (and of the resulting product).
    rank: Inner dimension shared by the two factors.
    alpha: Numerator of the scaling factor ``alpha / rank``.
    device: Device on which both factors are allocated.
  """
  def __init__(self,features_in,features_out,rank=1,alpha=1,device='cpu'):
    super().__init__()
    # Register A before B so parameter ordering stays lora_A, lora_B.
    self.lora_A=nn.Parameter(torch.zeros((rank,features_out),device=device))
    self.lora_B=nn.Parameter(torch.zeros((features_in,rank),device=device))
    nn.init.normal_(self.lora_A,mean=0,std=1)

    self.scale=alpha/rank
    # Toggle that lets callers bypass the low-rank update entirely.
    self.enabled=True

  def forward(self,orignal_weights):
    """Return ``orignal_weights`` plus the scaled update, or unchanged when disabled."""
    if not self.enabled:
      return orignal_weights
    low_rank_update=(self.lora_B@self.lora_A).view(orignal_weights.shape)
    return orignal_weights+low_rank_update*self.scale

In [59]:
import torch.nn.utils.parametrize as parametrize
def linear_layer_parameterization(layer,device,rank=1,alpha=1):
  """Create a ``LoRaParameterized`` sized after ``layer.weight``.

  The two dimensions of ``layer.weight.shape`` are forwarded, in order, as the
  ``features_in`` and ``features_out`` arguments of ``LoRaParameterized``.

  Args:
    layer: Module exposing a two-dimensional ``weight``.
    device: Device on which the LoRA factors are allocated.
    rank: Rank passed through to ``LoRaParameterized``.
    alpha: Scaling numerator passed through to ``LoRaParameterized``.
  """
  weight_rows,weight_cols=layer.weight.shape
  return LoRaParameterized(weight_rows,weight_cols,rank=rank,alpha=alpha,device=device)

In [60]:
# Attach a LoRA parametrization to the weight of each linear layer, in order.
for lora_target in (model.linear1,model.linear2,model.linear3):
  parametrize.register_parametrization(
      lora_target,'weight',linear_layer_parameterization(lora_target,device)
  )
# Show the last parametrized layer as the cell output.
model.linear3

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRaParameterized()
    )
  )
)

In [61]:
def enable_device(enabled=True):
  """Set the ``enabled`` flag of the weight parametrization on all three linear layers.

  Args:
    enabled: ``True`` to apply the LoRA update, ``False`` to bypass it.
  """
  targets=(model.linear1,model.linear2,model.linear3)
  for linear in targets:
    lora=linear.parametrizations["weight"][0]
    lora.enabled=enabled

In [62]:
# Count the LoRA factors separately from the layers' weights and biases.
total_parameters_lora=0
total_parameters_non_lora=0
for layer in (model.linear1,model.linear2,model.linear3):
  lora=layer.parametrizations["weight"][0]
  total_parameters_lora+=lora.lora_A.nelement()+lora.lora_B.nelement()
  total_parameters_non_lora+=layer.weight.nelement()+layer.bias.nelement()

# The weight/bias count must match the total computed before parametrization.
assert total_parameters_non_lora == total_parameter
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
# Size of the LoRA factors relative to the weight/bias count, in percent.
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [63]:
# Freeze every parameter whose name does not contain 'lora', so only the LoRA
# factors can receive gradient updates.
for name,param in model.named_parameters():
  if 'lora' not in name:
    param.requires_grad=False
# Load a fresh copy of the MNIST training split for fine-tuning.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE: despite its name, this boolean mask selects the samples to KEEP
# (label == 9); everything else is dropped by the two assignments below.
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the fine-tuning data (batches of 10).
# This rebinds the global `train_loader` used for pretraining earlier.
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)

Epoch 1:  99%|█████████▉| 99/100 [00:03<00:00, 26.09it/s, loss=0.02]


In [64]:
# The frozen base weights must still equal the snapshot taken after pretraining
for layer,key in (
    (model.linear1,'linear1.weight'),
    (model.linear2,'linear2.weight'),
    (model.linear3,'linear3.weight'),
):
  assert torch.all(layer.parametrizations.weight.original == orignal_weights[key])

enable_device(enabled=True)
# With LoRA on, linear1.weight is produced by the "forward" function of our LoRA parametrization
# The original weights have been moved to model.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
lora1=model.linear1.parametrizations.weight[0]
expected_weight=model.linear1.parametrizations.weight.original + (lora1.lora_B @ lora1.lora_A) * lora1.scale
assert torch.equal(model.linear1.weight, expected_weight)

enable_device(enabled=False)
# With LoRA off, linear1.weight is exactly the pretrained weight again
assert torch.equal(model.linear1.weight, orignal_weights['linear1.weight'])

In [65]:
# Test with LoRA enabled.
# In the recorded run below, class 9 errors drop sharply while overall accuracy
# falls -- presumably because fine-tuning only saw digit 9 (TODO confirm).
enable_device(enabled=True)
test()

Testing: 100%|██████████| 100/100 [00:04<00:00, 20.48it/s]

Accuracy: 0.771
Wrong counts for class 0: 128
Wrong counts for class 1: 44
Wrong counts for class 2: 99
Wrong counts for class 3: 244
Wrong counts for class 4: 636
Wrong counts for class 5: 101
Wrong counts for class 6: 225
Wrong counts for class 7: 344
Wrong counts for class 8: 461
Wrong counts for class 9: 3





In [66]:
# Test with LoRA disabled: the parametrization returns the frozen base weights,
# and the recorded output below matches the pretrained model's results.
enable_device(enabled=False)
test()

Testing: 100%|██████████| 100/100 [00:03<00:00, 31.01it/s]

Accuracy: 0.97
Wrong counts for class 0: 13
Wrong counts for class 1: 13
Wrong counts for class 2: 34
Wrong counts for class 3: 51
Wrong counts for class 4: 30
Wrong counts for class 5: 13
Wrong counts for class 6: 16
Wrong counts for class 7: 43
Wrong counts for class 8: 47
Wrong counts for class 9: 40



