In [103]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [104]:
device= 'cuda' if torch.cuda.is_available() else 'cpu'

In [105]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

In [106]:
class DigiNet(nn.Module):
  def __init__(self, HiddenLayer1=1500):
    super(DigiNet, self).__init__()
    self.linear1=nn.Linear(28*28, HiddenLayer1)
    self.linear2=nn.Linear(HiddenLayer1, 10)
    self.relu=nn.ReLU()

  def forward(self, img):
    x=img.view(-1, 28*28)
    x=self.relu(self.linear1(x))
    x=self.linear2(x)
    return x

net=DigiNet().to(device)

In [107]:
def train(train_loader, net, epochs=7):
  CELoss=nn.CrossEntropyLoss()
  optim=torch.optim.Adam(net.parameters(), lr=0.001)
  total_loss=0
  num_iterations=0

  for epoch in range(epochs):
    net.train() #enable dropout and training
    data_iterator=tqdm(train_loader, desc=f'Epoch {epoch +1}')
    for data in data_iterator:
      num_iterations+=1
      x, y=data
      x=x.to(device)
      y=y.to(device)
      optim.zero_grad()
      output=net(x.view(-1, 28*28))
      loss=CELoss(output, y)
      total_loss+=loss.item()
      avg_loss=total_loss/num_iterations
      data_iterator.set_postfix(loss=avg_loss)
      loss.backward()
      optim.step()

In [108]:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()

In [109]:
def test():
  correct=0
  total=0

  wrong_counts=[0 for i in range(10)]
  with torch.no_grad():
    for data in tqdm(test_loader, desc='Testing'):
      x,y =data
      x=x.to(device)
      y=y.to(device)
      output=net(x.view(-1, 28*28))

      for idx, i in enumerate(output):
        if torch.argmax(i)==y[idx]:
          correct+=1
        else:
          wrong_counts[y[idx]]+=1
        total+=1
    print(f'Accuracy: {round(correct/total, 3)}')

    for i in range(len(wrong_counts)):
        print(f'wrong counts for digit {i}: {wrong_counts[i]}')
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 251.00it/s]

Accuracy: 0.12
wrong counts for digit 0: 894
wrong counts for digit 1: 1126
wrong counts for digit 2: 834
wrong counts for digit 3: 1005
wrong counts for digit 4: 961
wrong counts for digit 5: 787
wrong counts for digit 6: 956
wrong counts for digit 7: 1028
wrong counts for digit 8: 888
wrong counts for digit 9: 318





In [110]:
total_original_parameter=0
for index, layer in enumerate([net.linear1, net.linear2]):
  total_original_parameter+=layer.weight.nelement()+ layer.bias.nelement()
  print(f'Layer {index+1}: W: {layer.weight.shape} + B:{layer.bias.shape}')
print(f'Total parameters: {total_original_parameter}')


Layer 1: W: torch.Size([1500, 784]) + B:torch.Size([1500])
Layer 2: W: torch.Size([10, 1500]) + B:torch.Size([10])
Total parameters: 1192510


## Lora Parameterization:

In [111]:
class LoRAParameterization(nn.Module):
  def __init__(self, in_features, out_features,rank=1, alpha=1, device='cpu'):
    super().__init__()
    self.LoRA_A=nn.Parameter(torch.zeros((rank, in_features)).to(device))
    self.LoRA_B=nn.Parameter(torch.zeros((out_features, rank)).to(device))
    nn.init.normal_(self.LoRA_A, mean=0, std=1)
    self.scale=alpha/rank
    self.enabled=True

  def forward(self, original_weights):
    if self.enabled:
      return original_weights + torch.matmul(self.LoRA_B, self.LoRA_A).view(original_weights.shape)*self.scale
    else:
      return original_weights


In [112]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
  in_features, out_features=layer.weight.shape
  return LoRAParameterization(
      in_features, out_features,rank=rank, alpha=lora_alpha, device=device
  )

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)

parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)

ParametrizedLinear(
  in_features=1500, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParameterization()
    )
  )
)

In [113]:
def enable_disable_lora(enabled=True):
  for layer in [net.linear1, net.linear2]:
    layer.parametrizations['weight'][0].enabled=enabled

In [114]:
total_lora_params=0
for index, layers in enumerate([net.linear1, net.linear2]):
  total_lora_params+=layer.parametrizations['weight'][0].LoRA_A.nelement() + layer.parametrizations['weight'][0].LoRA_B.nelement()


print(f'Total LoRA parameters: {total_lora_params}')

Total LoRA parameters: 3020


In [115]:
#@ Freezing non-lora params:
for name, param in net.named_parameters():
  if 'lora' not in name:
    param.requires_grad=False

for layer in [net.linear1, net.linear2]:
    layer.parametrizations.weight[0].LoRA_A.requires_grad = True
    layer.parametrizations.weight[0].LoRA_B.requires_grad = True

In [116]:
#@ MNIST Data for digit 9:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 595/595 [00:03<00:00, 178.13it/s, loss=0.0144]


In [117]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])


enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].LoRA_B @ net.linear1.parametrizations.weight[0].LoRA_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

RuntimeError: The size of tensor a (784) must match the size of tensor b (1500) at non-singleton dimension 1

In [118]:

# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 244.22it/s]

Accuracy: 0.101
wrong counts for digit 0: 980
wrong counts for digit 1: 1135
wrong counts for digit 2: 1032
wrong counts for digit 3: 1010
wrong counts for digit 4: 982
wrong counts for digit 5: 892
wrong counts for digit 6: 958
wrong counts for digit 7: 1028
wrong counts for digit 8: 974
wrong counts for digit 9: 0





In [119]:
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 302.56it/s]

Accuracy: 0.12
wrong counts for digit 0: 894
wrong counts for digit 1: 1126
wrong counts for digit 2: 834
wrong counts for digit 3: 1005
wrong counts for digit 4: 961
wrong counts for digit 5: 787
wrong counts for digit 6: 956
wrong counts for digit 7: 1028
wrong counts for digit 8: 888
wrong counts for digit 9: 318



