## LoRA - Low-Rank Adaptation of Large Language Models

### Why is LoRA important?

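- LoRA significantly reduces the computational cost of fine-tuning a model, because it trains only a small number of added parameters while the pre-trained weights stay frozen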

## What is Fine-Tuning?

##### Fine-tuning Definition:
- Training a pre-trained network on new data
- Enhances performance on a specific task

##### Example: Fine-tuning a Large Language Model (LLM):

- Initially trained on multiple programming languages.
- Target: Improved performance specifically for SQL.

##### Process

- Utilize existing knowledge from initial training
- Adjust parameters for the new task
- Train on task-specific data to refine performance

##### Outcome

- Enhanced model proficiency in the targeted area

##### Thought:

- Fine-tuning leverages the foundation laid by pre-training, making AI models adaptable and specialized for diverse applications.


## Problems with fine-tuning

##### Computational Expense:

- Training the entire network during fine-tuning is computationally expensive
- Particularly for users working with a large language model (LLM)

##### Storage Challenges:

- Each fine-tuned checkpoint has to store every parameter of the model
- Saving a complete copy of the model for each checkpoint requires a large amount of storage
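To put the storage problem in numbers, here is a back-of-the-envelope sketch. All figures are hypothetical: a 7-billion-parameter model stored in 16-bit floats (2 bytes per parameter), three tasks, and a task-specific update of 4 million parameters standing in for a small add-on such as a LoRA adapter.

```python
# All figures below are hypothetical, chosen only to illustrate the arithmetic.
BYTES_PER_PARAM = 2  # assuming 16-bit (fp16/bf16) weights


def checkpoint_gb(num_params: int, bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Size of a checkpoint in gigabytes (1 GB = 10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


base_params = 7_000_000_000  # assumed size of the pre-trained model
adapter_params = 4_000_000   # assumed size of one small task-specific update
num_tasks = 3                # e.g. SQL, JavaScript, and one more task

# Full fine-tuning: every task stores a complete copy of the model.
full_ft_total = num_tasks * checkpoint_gb(base_params)

# Small updates: one shared base model plus a small file per task.
low_rank_total = checkpoint_gb(base_params) + num_tasks * checkpoint_gb(adapter_params)

print(f"Full fine-tuning: {full_ft_total:.2f} GB")
print(f"Shared base + small updates: {low_rank_total:.3f} GB")
```

Under these assumptions, three fully fine-tuned copies take 42 GB, while one base model plus three small updates take about 14.02 GB.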

##### Model-Switching Complexity:

- Switching between fine-tuned models requires reloading all of the model weights
- Can be both time-consuming and resource-intensive.
- Example: Different models for SQL queries and JavaScript code assistance.

## How does LoRA solve the problem?

##### Basic Overview
- Neural networks contain dense layers performing matrix multiplication.
- Weight matrices typically have full rank.
- Aghajanyan et al. (2020) demonstrated that pre-trained language models have a low "intrinsic dimension."
- **Hypothesis of LoRA**: Updates to weights during adaptation have a low "intrinsic rank."
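A quick way to see what "low rank" means: a dense random matrix is full rank, whereas any product $BA$ of a $d \times r$ and an $r \times k$ matrix has rank at most $r$. The sketch below uses hypothetical sizes $d = 100$, $k = 80$, $r = 4$ and checks this with `torch.linalg.matrix_rank`. It only illustrates the constraint; it does not test the hypothesis that learned updates are low rank.

```python
import torch

torch.manual_seed(0)

d, k, r = 100, 80, 4  # hypothetical layer shape and rank

# A dense random matrix is (almost surely) full rank: min(d, k) = 80.
W = torch.randn(d, k)

# A product of a (d x r) and an (r x k) matrix can never exceed rank r.
B = torch.randn(d, r)
A = torch.randn(r, k)
delta_W = B @ A

print(torch.linalg.matrix_rank(W).item())        # 80
print(torch.linalg.matrix_rank(delta_W).item())  # 4
```

Both matrices have the same shape ($100 \times 80$), but `delta_W` is described by only $r(d + k) = 720$ numbers instead of $d \cdot k = 8000$.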

##### Weight Update Representation:
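- For a pre-trained weight matrix $W_0 \in \mathbb{R}^{d \times k}$:
  - Constrain its update with a low-rank decomposition: $$W_0 + \Delta W = W_0 + BA$$
  - where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$
  - $W_0$ is frozen during training and only $A$ and $B$ are trainable, so the forward pass becomes $h = W_0 x + BAx$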
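The decomposition above can be sketched as a PyTorch module that wraps a frozen `nn.Linear`. This is a minimal illustration, not this notebook's own implementation: the class name `LoRALinear`, the `0.01` scale of the Gaussian initialization of $A$, and the default `alpha` are assumptions. Following the LoRA paper, $B$ starts at zero so that $\Delta W = BA = 0$ before training, and the update is scaled by $\alpha / r$.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical wrapper: y = W0 x + b + (alpha / r) * B A x, with W0 and b frozen."""

    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base
        # Freeze the pre-trained weight W0 (and bias): they receive no gradient updates.
        for p in self.base.parameters():
            p.requires_grad = False
        d, k = base.out_features, base.in_features  # W0 has shape (d, k)
        # A: small Gaussian init (0.01 is an assumed scale), B: zeros, so B A = 0 at the start.
        self.lora_A = nn.Parameter(torch.randn(r, k) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(d, r))
        self.scale = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, k); x @ A.T @ B.T computes (B A x) for every row of x.
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scale

    def merged_weight(self) -> torch.Tensor:
        # W = W0 + (alpha / r) * B A, which can be folded into one matrix for inference.
        return self.base.weight + self.scale * (self.lora_B @ self.lora_A)


# Usage: wrap a 784 -> 1000 layer with a rank-1 update.
torch.manual_seed(0)
base = nn.Linear(784, 1000)
layer = LoRALinear(base, r=1)
x = torch.randn(2, 784)

# With B = 0 the wrapped layer initially matches the pre-trained layer.
print(torch.allclose(layer(x), base(x)))  # True

# Only A (1 x 784) and B (1000 x 1) are trainable: r * (d + k) = 1784 parameters.
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 1784
```

Because `merged_weight()` folds $BA$ into a single matrix, a merged layer costs the same as the original one at inference time; keeping $W_0$ separate instead lets several small $(A, B)$ pairs share one base model.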



## Libraries

```python
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from torchviz import make_dot
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
```

- Make the run deterministic: `torch.manual_seed(0)` sets the seed of PyTorch's random number generator to 0, so random processes (such as weight initialization and data shuffling) are reproducible.

```python
_ = torch.manual_seed(0)
```
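A quick check of what seeding buys: re-seeding with the same value replays the same random sequence. This sketch only covers PyTorch's own generator; some GPU operations can still be nondeterministic, which seeding alone does not control.

```python
import torch

# Seeding the generator makes the following random draws repeatable.
torch.manual_seed(0)
first = torch.randn(3)

# Re-seeding with the same value restarts the same sequence.
torch.manual_seed(0)
second = torch.randn(3)

print(torch.equal(first, second))  # True
```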

#### We will train a network to classify MNIST digits, then fine-tune it on a particular digit on which the model performs poorly

The transformation:

- converts each image into a tensor (`ToTensor` also rescales pixel values from 0-255 to [0, 1])
- normalizes the tensor with a mean of 0.1307 and a standard deviation of 0.3081
- uses these specific values because they are the statistics of the MNIST training set (computed below)
- for other standard datasets, their known mean and standard deviation can be used in the same way
- centers the pixel values around zero with roughly unit variance, which tends to make optimization more stable
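Per pixel, `ToTensor` followed by `Normalize` computes $(p / 255 - \mu) / \sigma$. The pure-Python sketch below (the helper `normalize_pixel` is hypothetical, written only for illustration) shows where black and white pixels end up with $\mu = 0.1307$ and $\sigma = 0.3081$.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(value: int, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Hypothetical helper mimicking ToTensor (0-255 -> [0, 1]) then Normalize."""
    scaled = value / 255.0        # what ToTensor does to a uint8 pixel
    return (scaled - mean) / std  # what Normalize does for a single channel


for pixel in (0, 255):
    print(pixel, round(normalize_pixel(pixel), 4))
```

A black pixel maps to about -0.42 and a white pixel to about 2.82, so the normalized values are no longer confined to [0, 1].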

```python
# Convert images to tensors, then normalize them with the MNIST mean (0.1307) and std (0.3081).
# Named `mnist_transform` so it does not shadow the `torchvision.transforms` module.
mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
```

```python
# Load the MNIST training set
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)

# Create a data loader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=64, shuffle=True)
```

```python
# Compute the mean and standard deviation of the raw MNIST pixels (uint8, 0-255), rescaled to [0, 1]
mnist_mean = mnist_trainset.data.float().mean() / 255
mnist_std = mnist_trainset.data.float().std() / 255

print("Mean:", mnist_mean)
print("Standard Deviation:", mnist_std)
```

```
Mean: tensor(0.1307)
Standard Deviation: tensor(0.3081)
```

```python
# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=64, shuffle=True)
```

```python
# Use the first GPU if one is available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
```

```
device(type='cpu')
```

#### Create an overly expensive neural network to classify MNIST digits

```python
class ExpensiveNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)  # flatten each 28x28 image into a 784-dim vector
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)      # raw logits for the 10 digit classes
        return x


net = ExpensiveNet().to(device)
```

```python
# Dummy input for visualization
dummy_input = torch.randn(1, 1, 28, 28).to(device)

# Build the computation graph of the network once
graph = make_dot(net(dummy_input), params=dict(net.named_parameters()))

# Optionally save the graph as an image (requires the Graphviz binaries)
# graph.render(filename='expensive_net', format='png', cleanup=True)

# Display the graph inline (the last expression of a cell is rendered by Jupyter)
graph
```

Train the network for only a few epochs to simulate general pre-training on the data.

```python
def train(train_loader, net, epochs=5, total_iteration_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iteration = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Training (epoch = {epoch+1}/{epochs})')
        if total_iteration_limit is not None:
            data_iterator.total = total_iteration_limit
        for x, y in data_iterator:
            num_iterations += 1
            total_iteration += 1

            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            # Stop training entirely (not just this epoch) once the iteration budget is used up
            if total_iteration_limit is not None and total_iteration >= total_iteration_limit:
                return
```

```python
train(train_loader, net, epochs=3)
```


#### Keep a copy of the original weights (by cloning them), so that we can later verify that fine-tuning with LoRA does not alter them

```python
original_weights = {}

for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()
```
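The check this snapshot enables can be sketched on a tiny stand-in layer: freeze the original parameters, train only an extra low-rank pair, then compare every parameter with `torch.equal`. The layer sizes, the names `lora_A`/`lora_B`, and the toy loss are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # hypothetical stand-in for a pre-trained layer

# Snapshot the weights, as in the cell above.
original = {name: p.clone().detach() for name, p in layer.named_parameters()}

# Freeze the pre-trained parameters and train only an extra (hypothetical) rank-1 pair.
for p in layer.parameters():
    p.requires_grad = False
lora_B = nn.Parameter(torch.zeros(3, 1))
lora_A = nn.Parameter(torch.randn(1, 4))
optimizer = torch.optim.SGD([lora_A, lora_B], lr=0.1)

# One toy training step on random data.
x = torch.randn(8, 4)
loss = (layer(x) + x @ lora_A.T @ lora_B.T).pow(2).mean()
loss.backward()
optimizer.step()

# The original weights are bit-for-bit identical; only the low-rank factors moved.
unchanged = all(torch.equal(original[name], p) for name, p in layer.named_parameters())
print(unchanged)  # True
```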

#### Let's see the performance of the pre-trained network

```python
def test():
    correct = 0
    total = 0

    # Number of misclassified test images per true digit
    wrong_counts = torch.zeros(10)

    net.eval()  # switch to evaluation mode
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 28*28))
            for idx, logits in enumerate(output):
                if torch.argmax(logits) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx].item()] += 1
                total += 1

    print("Accuracy: ", round(correct / total, 3))
    for i in range(10):
        print(f"Number {i} wrong count: {int(wrong_counts[i])}")


test()
```


```
Accuracy:  0.98
Number 0 wrong count: 8
Number 1 wrong count: 5
Number 2 wrong count: 29
Number 3 wrong count: 18
Number 4 wrong count: 13
Number 5 wrong count: 29
Number 6 wrong count: 16
Number 7 wrong count: 20
Number 8 wrong count: 24
Number 9 wrong count: 34
```
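The per-digit loop in `test()` can also be written without Python-level iteration. This is a hedged sketch using `torch.bincount` on a small hand-made batch; the helper name `wrong_counts_per_class` is hypothetical. Summed over all test batches, it should give the same per-digit counts as `test()`.

```python
import torch


def wrong_counts_per_class(logits: torch.Tensor, labels: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Hypothetical helper: count misclassified samples per true label, without a Python loop."""
    preds = logits.argmax(dim=1)
    wrong_labels = labels[preds != labels]  # true labels of the mistakes
    return torch.bincount(wrong_labels, minlength=num_classes)


# Hand-made batch: 4 samples, 3 classes.
logits = torch.tensor([[2.0, 0.1, 0.0],   # predicts 0, label 0 -> correct
                       [0.1, 2.0, 0.0],   # predicts 1, label 2 -> wrong (class 2)
                       [0.0, 0.1, 2.0],   # predicts 2, label 2 -> correct
                       [2.0, 0.0, 0.1]])  # predicts 0, label 1 -> wrong (class 1)
labels = torch.tensor([0, 2, 2, 1])

print(wrong_counts_per_class(logits, labels, num_classes=3).tolist())  # [0, 1, 1]
```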


```python
def count_parameters(model):
    total_params = 0
    for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
        total_params += layer.weight.nelement() + layer.bias.nelement()
        print(f"Layer {index} : Weight: {layer.weight.shape} + Bias: {layer.bias.shape}")
    print(f"Total Parameters: {total_params}")


count_parameters(net)
```

```
Layer 0 : Weight: torch.Size([1000, 784]) + Bias: torch.Size([1000])
Layer 1 : Weight: torch.Size([2000, 1000]) + Bias: torch.Size([2000])
Layer 2 : Weight: torch.Size([10, 2000]) + Bias: torch.Size([10])
Total Parameters: 2807010
```
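For comparison, the rank-$r$ factors $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ add $r(d + k)$ parameters per layer. The sketch below applies that formula to the three weight matrices printed above, assuming (hypothetically) that LoRA is attached to all three layers and that the biases stay frozen.

```python
# Layer shapes (out_features d, in_features k) as printed by count_parameters above.
layer_shapes = {"linear1": (1000, 784), "linear2": (2000, 1000), "linear3": (10, 2000)}
full_params = 2_807_010  # total printed above (weights + biases)


def lora_params(d: int, k: int, r: int) -> int:
    """Trainable parameters of B (d x r) plus A (r x k)."""
    return r * (d + k)


for r in (1, 4):
    added = sum(lora_params(d, k, r) for d, k in layer_shapes.values())
    print(f"rank {r}: {added} LoRA parameters ({100 * added / full_params:.2f}% of the original)")
```

Under these assumptions, rank 1 adds 6,794 trainable parameters and rank 4 adds 27,176, both under 1% of the 2,807,010 parameters of the original network.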
