<a href="https://colab.research.google.com/github/samitha278/transformer-optim/blob/main/lora_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Dataset prp

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

100%|██████████| 9.91M/9.91M [00:01<00:00, 6.11MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 160kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.52MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.71MB/s]


## Simple Neural Network

In [23]:
class SimpleNN(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

simple_nn = SimpleNN().to(device)

### vanila SGD from scratch

In [5]:
class SGD:
    def __init__(self, parameters, lr: float = 0.01):

        self.params = parameters  # Store references to parameters
        self.lr = lr  # Learning rate

    def zero_grad(self):
        """Clear gradients of all parameters"""
        for param in self.params:
            # Check if parameter has gradients
            if param.grad is not None:
                # Set gradient tensor to zeros (in-place)
                param.grad.zero_()

    def step(self):
        """Update parameters using gradients"""
        for param in self.params:
            # Skip parameters without gradients
            if param.grad is None:
                continue

            # SGD update rule: param = param - lr * gradient
            # .data accesses the tensor data without affecting gradient computation
            param.data = param.data - self.lr * param.grad

## Trainer

In [6]:
cross_el = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(simple_nn.parameters())
epochs = 1

losses = []
val_losses = []
j = 0

for i in range(epochs):

    simple_nn.train()          # activate train mode

    for x , y in iter(train_loader):
        x,y = x.to(device),y.to(device)  # move to device

        logits = simple_nn(x)
        loss = cross_el(logits,y)

        optimizer.zero_grad()  # param.grad = 0
        loss.backward()        # update param.grad
        optimizer.step()       # update param.data



        # validation
        if j%100==0:
            total_val_loss,k  = 0,0
            simple_nn.eval()

            with torch.no_grad():
                for a,b in iter(test_loader):
                    a,b = a.to(device),b.to(device)

                    out = simple_nn(a)
                    error = F.cross_entropy(out,b)

                    k+=1
                    total_val_loss+=error.item()

                val_losses.append(total_val_loss/k)
            simple_nn.train()


        # extra
        loss = loss.item()
        losses.append(loss)
        j+=1
        if j% 1000==0:
            print(f'{j} loss: {loss}')



1000 loss: 0.05644175410270691


KeyboardInterrupt: 

In [None]:
plt.plot(losses)
plt.plot(val_losses)

In [49]:
simple_nn.state_dict().keys()

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])

In [60]:
next(iter(simple_nn.parameters()))

Parameter containing:
tensor([[-0.0242, -0.0309, -0.0184,  ..., -0.0270, -0.0259,  0.0354],
        [ 0.0036,  0.0334, -0.0028,  ...,  0.0171, -0.0355, -0.0274],
        [ 0.0137,  0.0197,  0.0243,  ...,  0.0232, -0.0323,  0.0021],
        ...,
        [-0.0055, -0.0012, -0.0032,  ...,  0.0128,  0.0101,  0.0280],
        [-0.0236, -0.0075,  0.0189,  ...,  0.0004,  0.0069,  0.0291],
        [-0.0227, -0.0200,  0.0302,  ...,  0.0206,  0.0128,  0.0034]],
       requires_grad=True)

In [61]:
next(iter(simple_nn.named_parameters()))     # contain name

('linear1.weight',
 Parameter containing:
 tensor([[-0.0242, -0.0309, -0.0184,  ..., -0.0270, -0.0259,  0.0354],
         [ 0.0036,  0.0334, -0.0028,  ...,  0.0171, -0.0355, -0.0274],
         [ 0.0137,  0.0197,  0.0243,  ...,  0.0232, -0.0323,  0.0021],
         ...,
         [-0.0055, -0.0012, -0.0032,  ...,  0.0128,  0.0101,  0.0280],
         [-0.0236, -0.0075,  0.0189,  ...,  0.0004,  0.0069,  0.0291],
         [-0.0227, -0.0200,  0.0302,  ...,  0.0206,  0.0128,  0.0034]],
        requires_grad=True))

In [65]:
next(iter(simple_nn.named_modules()))     # contain name

('',
 SimpleNN(
   (linear1): Linear(in_features=784, out_features=1000, bias=True)
   (linear2): Linear(in_features=1000, out_features=2000, bias=True)
   (linear3): Linear(in_features=2000, out_features=10, bias=True)
   (relu): ReLU()
 ))

## LoRA

In [20]:
class LoRALayer(nn.Module):

    def __init__(self,fan_in , fan_out, r=1,alpha=1 ):
        super().__init__()

        # random Gaussian initialization for A and zero for B
        # ∆W = BA is zero at the beginning of training

        self.A = nn.Parameter(torch.randn(r,fan_out)).to(device)
        self.B = nn.Parameter(torch.zeros(fan_in,r)).to(device)

        self.scale = alpha / r
        self.enabled = True


    def forward(self,original_weights):

        if self.enabled:
            return original_weights + self.scale* (self.B @ self.A)

        return original_weights


In [28]:
# lora layer test

layer1_w = simple_nn.linear1.weight
print(layer1_w.shape)

i,o = layer1_w.T.shape

lora1 = LoRALayer(i,o)

param1 = lora1.__dict__['_parameters']
A1,B1 = param1['A'], param1['B']

A1.shape,B1.shape,(B1@A1).shape

torch.Size([1000, 784])


(torch.Size([1, 1000]), torch.Size([784, 1]), torch.Size([784, 1000]))

In [29]:
detail_dict = simple_nn.__dict__
detail_dict

{'training': True,
 '_parameters': {},
 '_buffers': {},
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': {'linear1': Linear(in_features=784, out_features=1000, bias=True),
  'linear2': Linear(in_features=1000, out_features=2000, bias=True),
  'linear3': Linear(in_features=2000, out_features=10, bias=True),
  'relu': ReLU()}}

In [47]:
full_dict = simple_nn.__dict__['_modules']

# full_dict['linear1'].__dict__      # contains all info of Linear layer

config = {}
for name , module in simple_nn.__dict__['_modules'].items():
    if isinstance(module,nn.Linear):
        print(module)
        config[name] = (module.in_features,module.out_features)

config

Linear(in_features=784, out_features=1000, bias=True)
Linear(in_features=1000, out_features=2000, bias=True)
Linear(in_features=2000, out_features=10, bias=True)


{'linear1': (784, 1000), 'linear2': (1000, 2000), 'linear3': (2000, 10)}

In [52]:
simple_nn.state_dict().keys()

odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])

## NN with LoRA

In [66]:
class SimpleNNLoRA(nn.Module):

    def __init__(self,config):
        super().__init__()

        l1_in , l1_out = config['linear1']
        self.lora1 = LoRALayer(l1_in,l1_out)

        l2_in , l2_out = config['linear2']
        self.lora2 = LoRALayer(l2_in,l2_out)

        l3_in , l3_out = config['linear3']
        self.lora3 = LoRALayer(l3_in,l3_out)

        self.relu = nn.ReLU()


    def forward(self,x,state_dict):

        x = self.relu(torch.matmul(x,self.lora1(state_dict['linear1.weight']).T) + state_dict['linear1.bias'])
        x = self.relu(torch.matmul(x,self.lora2(state_dict['linear2.weight']).T) + state_dict['linear2.bias'])
        x = torch.matmul(x,self.lora3(state_dict['linear3.weight']).T) + state_dict['linear3.bias']

        return x



In [70]:
state_dict = simple_nn.state_dict()

config = {}
for name , module in simple_nn.__dict__['_modules'].items():
    if isinstance(module,nn.Linear):
        # print(module)
        config[name] = (module.in_features,module.out_features)




simple_nn_with_lora = SimpleNNLoRA(config)

simple_nn_with_lora.__dict__    # now only train LoRA params

{'training': True,
 '_parameters': {},
 '_buffers': {},
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': {'lora1': LoRALayer(),
  'lora2': LoRALayer(),
  'lora3': LoRALayer(),
  'relu': ReLU()}}