# Weight Decomposed Low-Rank Adaptation (DoRA)
---
Low-rank adaptation(**LoRA**) is a machine learning technique that modifies a pretrained model (e.g., an LLM) to better suit a specific, often smaller, dataset by adjusting only a small, low-rank subset of the model's parameters.

This approach is important because it allows for efficient finetuning of large models on task-specific data significantly reducing the computational cost and time required for finetuning.

In this notebook, we are going to talk about [Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353), which is a new alterative to **LoRA**, which may outperform LoRA by a large margin. We are going to implement both **LoRA** and **DoRA** in PyTorch from scratch in this notebook.

**DoRA** can be seen as an improvement or extension of **LoRA** that is built on top of it, and we can now easily adapt some of our previous code to implement **DoRA**. **DoRA** can be described in two steps, where the first step is to decompose a pretrained weight matrix into a magnitude *vector(m)* and a directional *matrix(V)*. The second step is applyting **LoRA** to the directional matrix *V* and training the magnitude vector *m* separately.
The decomposition into magnitude and directional components is inspired by the mathematical principle that any vector can be represented as the product of its magnitude(a scalar value indicating its length) and its direction (a unit vector indicating its orientation in space).

- Set up work environment

In [47]:
# !pip install torch
# !pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 -f https://download.pytorch.org/whl/torch_stable.html
# !pip install numpy==1.26.4

In [48]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from cuda import DEVICE

- Check CUDA availability

In [49]:
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic=True
    DEVICE=torch.device('cuda')
else:
    DEVICE=torch.device('cpu')

print(DEVICE)

cuda


- Implement **LoRA** and **DoRA** Layers 

In [50]:
class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev=1/torch.sqrt(torch.tensor(rank).float())
        self.A=nn.Parameter(torch.randn(in_dim, rank)*std_dev)
        self.B=nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha=alpha
    
    def forward(self, x):
        # @ means matrix multiplication
        x=self.alpha*(x @ self.A @ self.B)
        return x
    
class LinearWithDoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear=linear
        self.lora=LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
        self.m=nn.Parameter(torch.ones(1, linear.out_features))
        
    def forward(self, x):
        linear_output=self.linear(x)
        lora_output=self.lora(x)
        lora_output_norm=lora_output/lora_output.norm(p=2, dim=1, keepdim=True)
        dora_modification=self.m * lora_output_norm
        dora_output=self.lora(x)
        return linear_output+dora_output
    
# this code is equivalent to LinearWithDoRA
class LinearWithDoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear=linear
        self.lora=LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
        self.m=nn.Parameter(self.linear.weight.norm(p=2, dim=0, keepdim=True))
        
    def forward(self, x):
        lora=self.lora.A @self.lora.B
        numerator=self.linear.weight+self.lora.alpha*lora.T
        denominator=numerator.norm(p=2, dim=0, keepdim=True)
        directional_component=numerator/denominator
        new_weight=self.m*directional_component
        return F.linear(x, new_weight, self.linear.bias)

- Swaping existing Linear layers

In [51]:
# hyperparameters
random_seed=123

torch.manual_seed(random_seed)

layer=nn.Linear(10,2)
x=torch.randn(1,10)

layer_dora_1=LinearWithDoRA(layer, rank=2, alpha=4)
print(layer_dora_1(x), '\n ------------')
layer_dora_2=LinearWithDoRAMerged(layer, rank=2, alpha=4)
print(layer_dora_2(x))

tensor([[0.6639, 0.4487]], grad_fn=<AddBackward0>) 
 ------------
tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)


In [52]:
# hyperparameters
learning_rate=0.005
num_epochs=15

# architecture
num_features=784
num_hidden_1=128
num_hidden_2=256
num_classes=10

- Define Multilayer Perceptron Model (Without **DoRA**)

In [53]:
class MultilayerPerceptron(nn.Module):
    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        
        self.layers=nn.Sequential(
            nn.Linear(num_features, num_hidden_1),
            nn.ReLU(),
            nn.Linear(num_hidden_1, num_hidden_2),
            nn.ReLU(),
            nn.Linear(num_hidden_2, num_classes)
        )
    
    def forward(self, x):
        x=self.layers(x)
        return x

model=MultilayerPerceptron(
    num_features=num_features,
    num_hidden_1=num_hidden_1,
    num_hidden_2=num_hidden_2,
    num_classes=num_classes
)

In [54]:
model.to(DEVICE)
optimizer_pretrained=torch.optim.Adam(model.parameters(), lr=learning_rate)

- Prepare and Load the Dataset

In [55]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

BATCH_SIZE=64

# note transforms.ToTensor() scales input images to 0-1 range
train_dataset=datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset=datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())

train_loader=DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader=DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# checking the dataset
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break



Image batch dimensions: torch.Size([64, 1, 28, 28])
Image label dimensions: torch.Size([64])


- Define Evaluation and Training functions

In [56]:
import time

def compute_accuracy(model, data_loader, device):
    model.eval()
    correct_pred, num_examples=0,0
    with torch.no_grad():
        for features, targets in data_loader:
            features=features.view(-1, 28*28).to(device)
            targets=targets.to(device)
            logits=model(features)
            _, predicted_labels=torch.max(logits,1)
            num_examples+=targets.size(0)
            correct_pred+=(predicted_labels==targets).sum()
        return correct_pred.float()/num_examples*100

def train(num_epochs, model, optimizer, train_loader, device):
    start_time=time.time()
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features=features.view(-1, 28*28).to(device)
            targets=targets.to(device)
            
            # Forward and backpropgation
            logits=model(features)
            loss=F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            
            loss.backward()
            
            # update model parameters
            optimizer.step()
            
            #logging
            if not batch_idx %400:
                print('Epoch %03d/%03d | Batch %03d/%03d | Loss: %.4f' % (epoch+1, num_epochs, batch_idx, len(train_loader), loss))
        with torch.set_grad_enabled(False):
            print('Epoch: %03d/%03d training accuracy: %0.2f%%' % (epoch+1, num_epochs, compute_accuracy(model, train_loader, device)))
        print('Time elapsed: %.2f min' % ((time.time()- start_time)/60))
    print('Total Training TIme: %.2f min' % ((time.time()-start_time)/60))
    

train(num_epochs, model, optimizer_pretrained, train_loader, DEVICE)
print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')


Epoch 001/015 | Batch 000/938 | Loss: 2.3110
Epoch 001/015 | Batch 400/938 | Loss: 0.2177
Epoch 001/015 | Batch 800/938 | Loss: 0.3389
Epoch: 001/015 training accuracy: 95.14%
Time elapsed: 0.09 min
Epoch 002/015 | Batch 000/938 | Loss: 0.1139
Epoch 002/015 | Batch 400/938 | Loss: 0.1833
Epoch 002/015 | Batch 800/938 | Loss: 0.2406
Epoch: 002/015 training accuracy: 96.94%
Time elapsed: 0.18 min
Epoch 003/015 | Batch 000/938 | Loss: 0.0661
Epoch 003/015 | Batch 400/938 | Loss: 0.0905
Epoch 003/015 | Batch 800/938 | Loss: 0.2603
Epoch: 003/015 training accuracy: 97.14%
Time elapsed: 0.26 min
Epoch 004/015 | Batch 000/938 | Loss: 0.0786
Epoch 004/015 | Batch 400/938 | Loss: 0.1385
Epoch 004/015 | Batch 800/938 | Loss: 0.3343
Epoch: 004/015 training accuracy: 97.71%
Time elapsed: 0.34 min
Epoch 005/015 | Batch 000/938 | Loss: 0.0386
Epoch 005/015 | Batch 400/938 | Loss: 0.1358
Epoch 005/015 | Batch 800/938 | Loss: 0.0768
Epoch: 005/015 training accuracy: 96.85%
Time elapsed: 0.43 min
Epoch

In [57]:
import copy

model_dora=copy.deepcopy(model)
print(model_dora, '\n -----------')

model_dora.layers[0]=LinearWithDoRAMerged(model_dora.layers[0], rank=4, alpha=8)
model_dora.layers[2]=LinearWithDoRAMerged(model_dora.layers[2], rank=4, alpha=8)
model_dora.layers[4]=LinearWithDoRAMerged(model_dora.layers[4], rank=4, alpha=8)

model_dora.to(DEVICE)
optimizer_dora=torch.optim.Adam(model_dora.parameters(), lr=learning_rate)

print(model_dora, '\n -----------')

MultilayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
) 
 -----------
MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithDoRAMerged(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithDoRAMerged(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithDoRAMerged(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
) 
 -----------


- Freeze the orignal weights

In [58]:


def freeze_linear_layers(model):
    for child in model.children():
        if isinstance(child, nn.Linear):
            for param in child.parameters():
                param.requires_grad=False
        else:
            # recursively freeze linear layers in children modules
            freeze_linear_layers(child)

freeze_linear_layers(model_dora)

# check if linear layers are frozen
for name, param in model_dora.named_parameters():
    print(f'{name}: {param.requires_grad}')

print(20*'-')

layers.0.m: True
layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.m: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.m: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True
--------------------


In [59]:
optimizer_dora=torch.optim.Adam(model_dora.parameters(), lr=learning_rate)
train(num_epochs, model_dora, optimizer_dora, train_loader, DEVICE)
print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')

print('++++')

print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')
print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')



Epoch 001/015 | Batch 000/938 | Loss: 0.0001
Epoch 001/015 | Batch 400/938 | Loss: 0.0127
Epoch 001/015 | Batch 800/938 | Loss: 0.0228
Epoch: 001/015 training accuracy: 98.93%
Time elapsed: 0.13 min
Epoch 002/015 | Batch 000/938 | Loss: 0.0343
Epoch 002/015 | Batch 400/938 | Loss: 0.0036
Epoch 002/015 | Batch 800/938 | Loss: 0.0421
Epoch: 002/015 training accuracy: 98.96%
Time elapsed: 0.26 min
Epoch 003/015 | Batch 000/938 | Loss: 0.0086
Epoch 003/015 | Batch 400/938 | Loss: 0.0700
Epoch 003/015 | Batch 800/938 | Loss: 0.0146
Epoch: 003/015 training accuracy: 98.93%
Time elapsed: 0.38 min
Epoch 004/015 | Batch 000/938 | Loss: 0.0319
Epoch 004/015 | Batch 400/938 | Loss: 0.0062
Epoch 004/015 | Batch 800/938 | Loss: 0.0480
Epoch: 004/015 training accuracy: 99.23%
Time elapsed: 0.51 min
Epoch 005/015 | Batch 000/938 | Loss: 0.0155
Epoch 005/015 | Batch 400/938 | Loss: 0.0190
Epoch 005/015 | Batch 800/938 | Loss: 0.1159
Epoch: 005/015 training accuracy: 99.06%
Time elapsed: 0.63 min
Epoch