<a href="https://colab.research.google.com/github/samitha278/transformer-optim/blob/main/lora_from_scratch_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [77]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [78]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

## Dataset

In [79]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

In [110]:
# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

In [80]:
# Load the MNIST dataset again, by keeping only the digit 7
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
indices = mnist_trainset.targets == 7
mnist_trainset.data = mnist_trainset.data[indices]
mnist_trainset.targets = mnist_trainset.targets[indices]
# Create a dataloader for the training
train_loader_7 = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

## Model

In [81]:
class SimpleNN(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [82]:
model = SimpleNN()

path = '/content/drive/MyDrive/model_weights.pth'
state_dict = torch.load(path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)

SimpleNN(
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (linear2): Linear(in_features=1000, out_features=2000, bias=True)
  (linear3): Linear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)

---
## Parametrization

In [83]:
class Norm1(nn.Module):
    def forward(self,W):
        return W+W.mean(dim=-1,keepdim=True)

In [84]:
P.register_parametrization(model.linear1,'weight',Norm1())

ParametrizedLinear(
  in_features=784, out_features=1000, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Norm1()
    )
  )
)

In [85]:
model.__dict__['_modules']

{'linear1': ParametrizedLinear(
   in_features=784, out_features=1000, bias=True
   (parametrizations): ModuleDict(
     (weight): ParametrizationList(
       (0): Norm1()
     )
   )
 ),
 'linear2': Linear(in_features=1000, out_features=2000, bias=True),
 'linear3': Linear(in_features=2000, out_features=10, bias=True),
 'relu': ReLU()}

#### Check original vs parametrized weights

In [86]:
model.linear1.weight

tensor([[ 0.0216,  0.0764,  0.0437,  ...,  0.0337,  0.0645,  0.0849],
        [ 0.0627,  0.0119,  0.0187,  ...,  0.0414,  0.0102,  0.0744],
        [-0.0130,  0.0348,  0.0329,  ...,  0.0154,  0.0023, -0.0019],
        ...,
        [ 0.0954,  0.0566,  0.0446,  ...,  0.1086,  0.0640,  0.0441],
        [ 0.0401,  0.0019,  0.0490,  ...,  0.0571, -0.0048,  0.0222],
        [ 0.0135, -0.0111,  0.0175,  ...,  0.0358, -0.0039, -0.0060]],
       grad_fn=<AddBackward0>)

In [87]:
model.linear1.parametrizations.weight.original

Parameter containing:
tensor([[ 0.0031,  0.0579,  0.0253,  ...,  0.0152,  0.0460,  0.0665],
        [ 0.0492, -0.0016,  0.0052,  ...,  0.0278, -0.0033,  0.0608],
        [-0.0084,  0.0394,  0.0375,  ...,  0.0200,  0.0069,  0.0028],
        ...,
        [ 0.0701,  0.0314,  0.0193,  ...,  0.0833,  0.0388,  0.0188],
        [ 0.0326, -0.0056,  0.0415,  ...,  0.0496, -0.0123,  0.0147],
        [ 0.0119, -0.0128,  0.0158,  ...,  0.0342, -0.0056, -0.0076]],
       requires_grad=True)

In [89]:
model.linear1.parametrizations["weight"][0]

Norm1()

#### Remove parametrizations

In [90]:
P.remove_parametrizations(model.linear1,"weight")

Linear(in_features=784, out_features=1000, bias=True)

In [91]:
model.__dict__['_modules']

{'linear1': Linear(in_features=784, out_features=1000, bias=True),
 'linear2': Linear(in_features=1000, out_features=2000, bias=True),
 'linear3': Linear(in_features=2000, out_features=10, bias=True),
 'relu': ReLU()}

---

## LoRA parametrization

In [92]:
class LoRA(nn.Module):

    def __init__(self,fan_in , fan_out, r=1,alpha=1 ):
        super().__init__()

        # random Gaussian initialization for A and zero for B
        # âˆ†W = BA is zero at the beginning of training

        self.A = nn.Parameter(torch.randn(r,fan_out))
        self.B = nn.Parameter(torch.zeros(fan_in,r))

        self.scale = alpha / r
        self.enabled = True


    def forward(self,original_weights):

        if self.enabled:

            d_w = self.scale* (self.B @ self.A)
            return original_weights + d_w.T

        return original_weights

### Apply LoRA to model

In [102]:
def apply_lora(layer):
    fan_in,fan_out = layer.in_features,layer.out_features

    return LoRA(fan_in,fan_out).to(device)

In [103]:
P.register_parametrization(model.linear1,"weight",apply_lora(model.linear1))
P.register_parametrization(model.linear2,"weight",apply_lora(model.linear2))
P.register_parametrization(model.linear3,"weight",apply_lora(model.linear3))

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRA()
    )
  )
)

In [104]:
model.__dict__['_modules']

{'linear1': ParametrizedLinear(
   in_features=784, out_features=1000, bias=True
   (parametrizations): ModuleDict(
     (weight): ParametrizationList(
       (0): LoRA()
     )
   )
 ),
 'linear2': ParametrizedLinear(
   in_features=1000, out_features=2000, bias=True
   (parametrizations): ModuleDict(
     (weight): ParametrizationList(
       (0): LoRA()
     )
   )
 ),
 'linear3': ParametrizedLinear(
   in_features=2000, out_features=10, bias=True
   (parametrizations): ModuleDict(
     (weight): ParametrizationList(
       (0): LoRA()
     )
   )
 ),
 'relu': ReLU()}

### Enable Disable LoRA

In [107]:
def enable_lora(enabled = True):
    for layer in [model.linear1,model.linear2,model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [108]:
enable_lora()

## Train parametrized model

In [75]:
def train(model,train_loader):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters())
    cross_el = nn.CrossEntropyLoss()

    for i in range(1000):
        x,y = next(iter(train_loader))
        x,y = x.to(device),y.to(device)

        logits = model(x)
        loss = cross_el(logits,y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i%100==0 :
            print(f'{i}, {loss.item()}')

In [109]:
train(model,train_loader_7)

0, 0.1605968177318573
100, 0.0
200, 0.0
300, 0.0
400, 0.0
500, 0.0
600, 0.0
700, 0.0
800, 0.0
900, 0.0


### Test when LoRA enable

In [111]:
def test(model):
    model.eval()

    total = 0
    correct = 0
    wrong_counts = [0 for i in range(10)]

    for x,y in iter(test_loader):
        x,y = x.to(device) , y.to(device)

        logits = model(x)
        predicts = torch.argmax(logits,dim=-1)

        for i,predict in enumerate(predicts):
            if predict==y[i] :
                correct+=1
            else:
                wrong_counts[y[i]]+=1
            total+=1

    accuracy = correct/total
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

In [112]:
test(model)

wrong counts for the digit 0: 980
wrong counts for the digit 1: 1135
wrong counts for the digit 2: 1006
wrong counts for the digit 3: 385
wrong counts for the digit 4: 967
wrong counts for the digit 5: 364
wrong counts for the digit 6: 348
wrong counts for the digit 7: 0
wrong counts for the digit 8: 504
wrong counts for the digit 9: 1009


### Test when LoRA disable

In [113]:
enable_lora(enabled=False)  # Disable lora : now original weights

test(model)

wrong counts for the digit 0: 978
wrong counts for the digit 1: 1124
wrong counts for the digit 2: 881
wrong counts for the digit 3: 304
wrong counts for the digit 4: 883
wrong counts for the digit 5: 285
wrong counts for the digit 6: 292
wrong counts for the digit 7: 0
wrong counts for the digit 8: 495
wrong counts for the digit 9: 1003
