# Low Rank Adaptation
---

In [1]:
# !pip install torch
# !pip install torch==2.0.1+cu118 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install numpy==1.26.4

In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic=True
    DEVICE=torch.device('cuda')
else:
    DEVICE=torch.device('cpu')

print(DEVICE)

cuda


In [25]:
if torch.cuda.is_available():
    """
    NVIDIA CUDA Deep Neural Network (cuDNN) is a GPU-accelerated library of primitives for deep neural networks
    """
    torch.backends.cudnn.deterministic=True

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1/torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank)*std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha
        
    def forward(self, x):
        x = self.alpha*(x@self.A@self.B) 
        # Here, @ denotes matrix multiplication
        return x
    
class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear=linear
        self.lora=LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
    
    def forward(self, x):
        return self.linear(x)+self.lora(x)

### Applying LoRA to Linear Layer
Let's apply LoRA to the Linear layer, we see that the results are the same since we haven't trained the LoRA weights yet. In other words, everything works as expected:

In [6]:
# import torch.nn.functional as F

# # This LoRA code is equivalent to LinearWithLoRA
# class LinearWithLoRAMerged(nn.Module):
#     def __init__(self, linear, rank, alpha):
#         super().__init__()
#         self.linear = linear
#         self.lora = LoRALayer(
#             linear.in_features, linear.out_features, rank, alpha
#         )
    
#     def forward(self, x):
#         lora=self.lora.A @ self.lora.B # combine LoRA metrices
#         # then combine LoRA original weights
#         combined_weight = self.linear.weight + self.lora.alpha*lora.T
#         return F.linear(x, combined_weight, self.linear.bias)

In [7]:
# layer_lora_2=LinearWithLoRAMerged(layer, rank=2, alpha=4)
# print(layer_lora_2(x))

In [22]:
from configs import *

class MultilayerPerceptron(nn.Module):
    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        self.layers=nn.Sequential(
            nn.Linear(num_features, num_hidden_1),
            nn.ReLU(),
            nn.Linear(num_hidden_1, num_hidden_2),
            nn.ReLU(),
            nn.Linear(num_hidden_2, num_classes)
        )
    
    def forward(self, x):
        x=self.layers(x)
        return x
    
model=MultilayerPerceptron(
    num_features = num_features,
    num_hidden_1 = num_hidden_1,
    num_hidden_2 = num_hidden_2,
    num_classes = num_classes
)

model.to(DEVICE)
optimizer_pretrained=torch.optim.Adam(model.parameters(), lr=learning_rate)
print(DEVICE, '\n -----')
print(model, '\n -----')
print(optimizer_pretrained, '\n -----')

cuda 
 -----
MultilayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
) 
 -----
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.005
    maximize: False
    weight_decay: 0
) 
 -----


In [11]:
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from configs import *


# Note: transforms.ToTensor() scales input images to 0-1 range
train_dataset=datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset=datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())
train_loader=DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader=DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

Image batch dimensions: torch.Size([64, 1, 28, 28])
Image label dimensions: torch.Size([64])


In [12]:
def compute_accuracy(model, data_loader, device):
    model.eval()
    correct_pred, num_examples=0,0
    with torch.no_grad():
        for features, targets in data_loader:
            features=features.view(-1, 28*28).to(device)
            targets=targets.to(device)
            logits=model(features)
            _, predicted_labels=torch.max(logits,1)
            num_examples+=targets.size(0)
            correct_pred+=(predicted_labels==targets).sum()
        return correct_pred.float()/num_examples*100

In [23]:
import time

def train(num_epochs, model, optimizer, train_loader, device):
    start_time=time.time()
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features=features.view(-1, 28*28).to(device)
            targets=targets.to(device)
            
            # forward and back propagation
            logits=model(features)
            loss=F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            
            loss.backward()
            
            # update model parameters
            optimizer.step()
            
            # logging
            if not batch_idx %400:
                print('Epoch: %03d/%03d|Batch %03d/%03d| Loss: %.4f' % (epoch+1, num_epochs, batch_idx, len(train_loader), loss))
        
        with torch.set_grad_enabled(False):
            print('Epoch: %03d/%03d training accuracy: %.2f%%' % (epoch+1, num_epochs, compute_accuracy(model, train_loader, device)))
        
        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
                  
                  
train(num_epochs, model, optimizer_pretrained, train_loader, DEVICE)
print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')

Epoch: 001/010|Batch 000/938| Loss: 2.3158
Epoch: 001/010|Batch 400/938| Loss: 0.2281
Epoch: 001/010|Batch 800/938| Loss: 0.0836
Epoch: 001/010 training accuracy: 97.13%
Time elapsed: 0.19 min
Epoch: 002/010|Batch 000/938| Loss: 0.0481
Epoch: 002/010|Batch 400/938| Loss: 0.0714
Epoch: 002/010|Batch 800/938| Loss: 0.1637
Epoch: 002/010 training accuracy: 96.48%
Time elapsed: 0.37 min
Epoch: 003/010|Batch 000/938| Loss: 0.2207
Epoch: 003/010|Batch 400/938| Loss: 0.2021
Epoch: 003/010|Batch 800/938| Loss: 0.0160
Epoch: 003/010 training accuracy: 97.78%
Time elapsed: 0.57 min
Epoch: 004/010|Batch 000/938| Loss: 0.0869
Epoch: 004/010|Batch 400/938| Loss: 0.0343
Epoch: 004/010|Batch 800/938| Loss: 0.1604
Epoch: 004/010 training accuracy: 97.79%
Time elapsed: 0.77 min
Epoch: 005/010|Batch 000/938| Loss: 0.0474
Epoch: 005/010|Batch 400/938| Loss: 0.0264
Epoch: 005/010|Batch 800/938| Loss: 0.0101
Epoch: 005/010 training accuracy: 98.30%
Time elapsed: 0.97 min
Epoch: 006/010|Batch 000/938| Loss:

##### Injecting LoRA Layers

In [26]:
import copy

model_lora=copy.deepcopy(model)

model_lora.layers[0]=LinearWithLoRA(model_lora.layers[0], rank=4, alpha=8)
model_lora.layers[2]=LinearWithLoRA(model_lora.layers[2], rank=4, alpha=8)
model_lora.layers[4]=LinearWithLoRA(model_lora.layers[4], rank=4, alpha=8)
model_lora.to(DEVICE)
optimizer_lora=torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
print(model_lora)

MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLoRA(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRA(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLoRA(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)


In [27]:
print(f'Test accuracy orig model:{compute_accuracy(model, test_loader, DEVICE):.2f}%')
print(f'Test accuracy LoRA model:{compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')

Test accuracy orig model:97.34%
Test accuracy LoRA model:97.34%


In [28]:
def freeze_linear_layers(model):
    for child in model.children():
        if isinstance(child, nn.Linear):
            for param in child.parameters():
                param.requires_grad=False
        else:
            # recursively freeze linear layers in children modules
            freeze_linear_layers(child)

freeze_linear_layers(model_lora)
for name, param in model_lora.named_parameters():
    print(f'{name}:{param.requires_grad}')

layers.0.linear.weight:False
layers.0.linear.bias:False
layers.0.lora.A:True
layers.0.lora.B:True
layers.2.linear.weight:False
layers.2.linear.bias:False
layers.2.lora.A:True
layers.2.lora.B:True
layers.4.linear.weight:False
layers.4.linear.bias:False
layers.4.lora.A:True
layers.4.lora.B:True


In [29]:
optimizer_lora=torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)
print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')

Epoch: 001/010|Batch 000/938| Loss: 0.0506
Epoch: 001/010|Batch 400/938| Loss: 0.0474
Epoch: 001/010|Batch 800/938| Loss: 0.0007
Epoch: 001/010 training accuracy: 98.83%
Time elapsed: 0.22 min
Epoch: 002/010|Batch 000/938| Loss: 0.0000
Epoch: 002/010|Batch 400/938| Loss: 0.0300
Epoch: 002/010|Batch 800/938| Loss: 0.2433
Epoch: 002/010 training accuracy: 98.74%
Time elapsed: 0.51 min
Epoch: 003/010|Batch 000/938| Loss: 0.0163
Epoch: 003/010|Batch 400/938| Loss: 0.0380
Epoch: 003/010|Batch 800/938| Loss: 0.0108
Epoch: 003/010 training accuracy: 99.03%
Time elapsed: 0.79 min
Epoch: 004/010|Batch 000/938| Loss: 0.0447
Epoch: 004/010|Batch 400/938| Loss: 0.0099
Epoch: 004/010|Batch 800/938| Loss: 0.0630
Epoch: 004/010 training accuracy: 99.23%
Time elapsed: 1.02 min
Epoch: 005/010|Batch 000/938| Loss: 0.0014
Epoch: 005/010|Batch 400/938| Loss: 0.0800
Epoch: 005/010|Batch 800/938| Loss: 0.0177
Epoch: 005/010 training accuracy: 99.14%
Time elapsed: 1.25 min
Epoch: 006/010|Batch 000/938| Loss:

In [30]:
print(f'Test accuracy orig model:{compute_accuracy(model, test_loader, DEVICE):.2f}%')
print(f'Test accuracy LoRA model:{compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')

Test accuracy orig model:97.34%
Test accuracy LoRA model:97.52%


##### Illustrative Example

In [1]:
import numpy as np

# Original matrix W
W = np.random.rand(10, 10)

# Low-rank approximation matrices A and B
rank = 2
A = np.random.rand(10, rank)
B = np.random.rand(rank, 10)

# Approximate W with A * B
W_approx = np.dot(A, B)

# Print the matrices
print("Original Matrix W:")
print(W)
print("\nMatrix A:")
print(A)
print("\nMatrix B:")
print(B)
print("\nApproximated Matrix W_approx (A * B):")
print(W_approx)

# Number of parameters
params_W = W.size
params_A_B = A.size + B.size

print(f"\nNumber of parameters in W: {params_W}")
print(f"Number of parameters in A and B: {params_A_B}")

Original Matrix W:
[[0.27623515 0.63215998 0.92459435 0.12553922 0.1324817  0.62949987
  0.67642814 0.20044365 0.15545272 0.31994525]
 [0.31620771 0.73463356 0.40055041 0.87588128 0.78227843 0.90167607
  0.32959735 0.28369558 0.42492724 0.133431  ]
 [0.87525289 0.32988044 0.83585347 0.93829371 0.09390059 0.25781041
  0.51288144 0.91023738 0.93678052 0.40134471]
 [0.10761006 0.26557536 0.89544327 0.32928028 0.46488922 0.97024327
  0.52649051 0.00352671 0.82629554 0.43071364]
 [0.52791184 0.59928369 0.46913586 0.68322513 0.67940471 0.08802975
  0.80061053 0.97934069 0.46685688 0.69116909]
 [0.647555   0.25602217 0.25908842 0.8408801  0.90313406 0.43487419
  0.48209494 0.22609687 0.05766792 0.06151783]
 [0.99683345 0.76136488 0.65484794 0.43741365 0.75294596 0.20422666
  0.32325106 0.07004434 0.73373486 0.81046958]
 [0.15282686 0.09544934 0.19786853 0.60568179 0.36603849 0.42335153
  0.17983766 0.71277869 0.92063546 0.78758077]
 [0.34864647 0.54255032 0.58589173 0.94692057 0.28321406 0.92