# Implementing DoRA: Weight-Decomposed Low-Rank Adaptation

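In this notebook we implement [Weight-Decomposed Low-Rank Adaptation (DoRA)](https://arxiv.org/abs/2402.09353), a parameter-efficient fine-tuning method that splits each pretrained weight matrix into a magnitude and a direction component and applies a LoRA-style low-rank update to the direction. Its authors propose it as a technique that consistently outperforms LoRA.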

To follow along conceptually, you can refer to the [writeup](https://medium.com/p/f814ba519af4), which covers the theoretical concepts and the motivation behind LoRA and DoRA.

## Importing libraries

In [None]:
import time
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch


# When a GPU is present, restrict cuDNN to deterministic algorithms so that
# repeated runs of this notebook are more reproducible.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

## Settings and dataset

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64

# MNIST splits stored under data/; only the training split requests a download.
train_dataset = datasets.MNIST(
    root='data/', train=True, transform=transforms.ToTensor(), download=True
)
test_dataset = datasets.MNIST(
    root='data/', train=False, transform=transforms.ToTensor()
)

# Shuffle the training batches; keep the test batches in a fixed order.
train_loader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

# Peek at a single batch to confirm the tensor layout.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)

Image batch dimensions: torch.Size([64, 1, 28, 28])
Image label dimensions: torch.Size([64])


### Multilayer Perceptron Model

In [None]:
# Training hyperparameters.
random_seed = 123
learning_rate = 0.005
num_epochs = 10

# Network dimensions: 784 = 28 * 28 flattened input pixels, two hidden
# layers, and one output logit per class.
num_features = 784
num_hidden_1 = 128
num_hidden_2 = 256
num_classes = 10


class MultilayerPerceptron(nn.Module):
    """Fully connected classifier with two ReLU-activated hidden layers.

    The whole network lives in a single ``nn.Sequential`` stored as
    ``self.layers``, so individual layers can be reached by index.
    """

    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()

        # Layers are instantiated in input-to-output order.
        stack = [
            nn.Linear(num_features, num_hidden_1),
            nn.ReLU(),
            nn.Linear(num_hidden_1, num_hidden_2),
            nn.ReLU(),
            nn.Linear(num_hidden_2, num_classes),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the output of the final linear layer (class logits) for ``x``."""
        return self.layers(x)


# Seed right before construction so the initial weights are reproducible.
torch.manual_seed(random_seed)
model_pretrained = MultilayerPerceptron(
    num_features, num_hidden_1, num_hidden_2, num_classes
).to(DEVICE)

# Adam over every parameter of the freshly initialised network.
optimizer_pretrained = torch.optim.Adam(
    params=model_pretrained.parameters(), lr=learning_rate
)

In [None]:
def compute_accuracy(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader`` in percent.

    The model is evaluated in eval mode with gradients disabled. Whatever
    mode (train/eval) the model was in before the call is restored afterwards,
    so calling this in the middle of a training loop does not silently leave
    the model in eval mode.

    Args:
        model: Module mapping a ``(batch, 784)`` tensor to per-class logits.
        data_loader: Iterable of ``(features, targets)`` batches. Each features
            batch is reshaped to ``(-1, 28*28)`` before the forward pass.
        device: Device the batches are moved to.

    Returns:
        A 0-dim float tensor with the accuracy as a percentage.

    Raises:
        ValueError: If ``data_loader`` yields no examples.
    """
    was_training = model.training
    model.eval()
    correct_pred, num_examples = 0, 0
    try:
        with torch.no_grad():
            for features, targets in data_loader:  # Processing batches
                features = features.view(-1, 28*28).to(device)
                targets = targets.to(device)
                logits = model(features)
                # Predicted class = index of the largest logit in each row.
                _, predicted_labels = torch.max(logits, 1)
                num_examples += targets.size(0)
                correct_pred += (predicted_labels == targets).sum()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    if num_examples == 0:
        raise ValueError("data_loader yielded no examples; accuracy is undefined")
    return correct_pred.float()/num_examples * 100


def train(num_epochs, model, optimizer, train_loader, device):
    """Train ``model`` with cross-entropy loss and report progress.

    Every 400 batches the current loss is printed; after each epoch the
    accuracy on ``train_loader`` and the elapsed wall-clock time are printed.

    Args:
        num_epochs: Number of full passes over ``train_loader``.
        model: Module mapping a ``(batch, 784)`` tensor to per-class logits.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        train_loader: Iterable of ``(features, targets)`` batches; supports
            ``len()`` (used in the progress message).
        device: Device the batches are moved to.
    """
    start_time = time.time()
    for epoch in range(num_epochs):
        # The end-of-epoch evaluation may leave the model in eval mode, so
        # training mode is re-enabled at the start of every epoch rather than
        # only once before the loop.
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            features = features.view(-1, 28*28).to(device)
            targets = targets.to(device)

            logits = model(features)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # Update parameters

            if not batch_idx % 400:
                print('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f'
                      % (epoch+1, num_epochs, batch_idx,
                          len(train_loader), loss.item()))

        # Evaluation needs no gradients.
        with torch.no_grad():
            print('Epoch: %03d/%03d training accuracy: %.2f%%' % (
                  epoch+1, num_epochs,
                  compute_accuracy(model, train_loader, device)))

        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

### Initial Training

In [None]:
# Train the plain MLP on the training split, then report its accuracy on the
# held-out test split.
train(num_epochs, model_pretrained, optimizer_pretrained, train_loader, DEVICE)
print(f'Test accuracy: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')

Epoch: 001/010 | Batch 000/938 | Loss: 2.2971
Epoch: 001/010 | Batch 400/938 | Loss: 0.1774
Epoch: 001/010 | Batch 800/938 | Loss: 0.1849
Epoch: 001/010 training accuracy: 94.83%
Time elapsed: 0.37 min
Epoch: 002/010 | Batch 000/938 | Loss: 0.0912
Epoch: 002/010 | Batch 400/938 | Loss: 0.0571
Epoch: 002/010 | Batch 800/938 | Loss: 0.0569
Epoch: 002/010 training accuracy: 97.30%
Time elapsed: 0.67 min
Epoch: 003/010 | Batch 000/938 | Loss: 0.0802
Epoch: 003/010 | Batch 400/938 | Loss: 0.0549
Epoch: 003/010 | Batch 800/938 | Loss: 0.0249
Epoch: 003/010 training accuracy: 97.94%
Time elapsed: 0.98 min
Epoch: 004/010 | Batch 000/938 | Loss: 0.0687
Epoch: 004/010 | Batch 400/938 | Loss: 0.1166
Epoch: 004/010 | Batch 800/938 | Loss: 0.1479
Epoch: 004/010 training accuracy: 98.34%
Time elapsed: 1.28 min
Epoch: 005/010 | Batch 000/938 | Loss: 0.0716
Epoch: 005/010 | Batch 400/938 | Loss: 0.1389
Epoch: 005/010 | Batch 800/938 | Loss: 0.0586
Epoch: 005/010 training accuracy: 98.12%
Time elapsed:

### Multilayer Perceptron with LoRA and DoRA

In [None]:
class LoRALayer(nn.Module):
    """Low-rank update ``alpha * (x @ A @ B)`` with trainable factors A and B.

    ``A`` has shape ``(in_dim, rank)`` and ``B`` has shape ``(rank, out_dim)``.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # A starts as Gaussian noise scaled by 1/sqrt(rank); B starts at zero,
        # so the product A @ B is zero right after construction.
        init_scale = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(init_scale * torch.randn(in_dim, rank))
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        """Project ``x`` through A then B and scale the result by ``alpha``."""
        low_rank = x @ self.A @ self.B
        return self.alpha * low_rank


class LinearWithLoRA(nn.Module):
    """Wrap an ``nn.Linear`` and add a scaled low-rank update to its weight.

    The forward pass uses the merged weight ``W + alpha * (A @ B).T`` together
    with the wrapped layer's bias.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            in_dim=linear.in_features,
            out_dim=linear.out_features,
            rank=rank,
            alpha=alpha,
        )

    def forward(self, x):
        """Apply the wrapped linear layer with the LoRA update merged in."""
        # A @ B is (in_features, out_features); transpose it to line up with
        # nn.Linear's (out_features, in_features) weight layout.
        delta = (self.lora.A @ self.lora.B).T
        merged_weight = self.linear.weight + self.lora.alpha * delta
        return F.linear(x, merged_weight, self.linear.bias)


class LinearWithDoRA(nn.Module):
    """Wrap an ``nn.Linear`` with a DoRA-style magnitude/direction update.

    The LoRA-adjusted weight is normalised along ``dim=0`` to give a direction,
    which is then rescaled by the trainable magnitude ``m``. ``m`` is
    initialised from the norms of the wrapped layer's weight and has shape
    ``(1, in_features)``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            in_dim=linear.in_features,
            out_dim=linear.out_features,
            rank=rank,
            alpha=alpha,
        )

        # Trainable magnitude, seeded with the L2 norms (over dim 0) of the
        # wrapped layer's weight.
        initial_magnitude = self.linear.weight.norm(p=2, dim=0, keepdim=True)
        self.m = nn.Parameter(initial_magnitude)

    def forward(self, x):
        """Apply the linear layer using the recomposed DoRA weight."""
        # Weight plus the scaled low-rank update, in (out, in) layout.
        delta = (self.lora.A @ self.lora.B).T
        adapted = self.linear.weight + self.lora.alpha * delta
        # Direction: divide by the L2 norm taken along dim 0.
        direction = adapted / adapted.norm(p=2, dim=0, keepdim=True)
        # Recompose: magnitude times direction.
        return F.linear(x, self.m * direction, self.linear.bias)

In [None]:
# Small sanity check: a tiny linear layer and a single random input whose
# output is used as the reference for the LoRA/DoRA wrappers.
torch.manual_seed(123)

layer = nn.Linear(10, 2)
x = torch.randn((1, 10))

print("Original output:", layer(x))

Original output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)


In [None]:
# LoRA's B matrix is initialised to zeros, so right after wrapping the output
# is expected to equal the original layer's output.
layer_lora_2 = LinearWithLoRA(layer, rank=2, alpha=4)
print("LoRA output:", layer_lora_2(x))

LoRA output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)


In [None]:
# With B at zero and m initialised from the weight norms, the DoRA wrapper is
# likewise expected to reproduce the original layer's output.
layer_dora_2 = LinearWithDoRA(layer, rank=2, alpha=4)

print("DoRA output:", layer_dora_2(x))

DoRA output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)


In [None]:
# Show the module structure of the trained baseline model.
model_pretrained

MultilayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)

In [None]:
import copy

# Independent deep copies of the trained model, one per fine-tuning method,
# so that model_pretrained itself is left unmodified.
model_lora = copy.deepcopy(model_pretrained)
model_dora = copy.deepcopy(model_pretrained)

In [None]:
# Wrap each linear layer of the copy (positions 0, 2 and 4 of the Sequential)
# in a LoRA adapter.
for layer_idx in (0, 2, 4):
    model_lora.layers[layer_idx] = LinearWithLoRA(
        model_lora.layers[layer_idx], rank=4, alpha=8
    )

model_lora.to(DEVICE)
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
model_lora

MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLoRA(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRA(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLoRA(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)

In [None]:
# Wrap each linear layer of the copy (positions 0, 2 and 4 of the Sequential)
# in a DoRA adapter.
for layer_idx in (0, 2, 4):
    model_dora.layers[layer_idx] = LinearWithDoRA(
        model_dora.layers[layer_idx], rank=4, alpha=8
    )

model_dora.to(DEVICE)
optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)
model_dora

MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithDoRA(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithDoRA(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithDoRA(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)

In [None]:
# Before any fine-tuning the wrapped models should score the same as the
# original, since the adapters start out as a no-op.
print(f'Test accuracy original model: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')
print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')
print(f'Test accuracy DoRA model: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')

Test accuracy original model: 97.59%
Test accuracy LoRA model: 97.59%
Test accuracy DoRA model: 97.59%


### Finetuning With LoRA

In [None]:
def freeze_linear_layers(model):
    """Disable gradients for every ``nn.Linear`` nested below ``model``.

    Only descendants are inspected: ``model`` itself is never frozen, and
    parameters that do not belong to an ``nn.Linear`` are left untouched.
    """
    for submodule in model.children():
        if not isinstance(submodule, nn.Linear):
            # Container or wrapper module: keep searching further down.
            freeze_linear_layers(submodule)
            continue
        # Turns off requires_grad on all parameters of this linear layer.
        submodule.requires_grad_(False)

In [None]:
# Freeze the original linear weights and biases, then list every parameter's
# requires_grad flag to confirm that only the LoRA factors remain trainable.
freeze_linear_layers(model_lora)

for name, param in model_lora.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [None]:
# Create a fresh Adam optimizer for the fine-tuning run, train the LoRA model,
# and report its accuracy on the test split.
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)
print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')

Epoch: 001/010 | Batch 000/938 | Loss: 0.0032
Epoch: 001/010 | Batch 400/938 | Loss: 0.0098
Epoch: 001/010 | Batch 800/938 | Loss: 0.2822
Epoch: 001/010 training accuracy: 98.81%
Time elapsed: 0.30 min
Epoch: 002/010 | Batch 000/938 | Loss: 0.1740
Epoch: 002/010 | Batch 400/938 | Loss: 0.0479
Epoch: 002/010 | Batch 800/938 | Loss: 0.0421
Epoch: 002/010 training accuracy: 98.90%
Time elapsed: 0.62 min
Epoch: 003/010 | Batch 000/938 | Loss: 0.0000
Epoch: 003/010 | Batch 400/938 | Loss: 0.0330
Epoch: 003/010 | Batch 800/938 | Loss: 0.0224
Epoch: 003/010 training accuracy: 99.24%
Time elapsed: 0.92 min
Epoch: 004/010 | Batch 000/938 | Loss: 0.1600
Epoch: 004/010 | Batch 400/938 | Loss: 0.0220
Epoch: 004/010 | Batch 800/938 | Loss: 0.0485
Epoch: 004/010 training accuracy: 99.21%
Time elapsed: 1.23 min
Epoch: 005/010 | Batch 000/938 | Loss: 0.0037
Epoch: 005/010 | Batch 400/938 | Loss: 0.0007
Epoch: 005/010 | Batch 800/938 | Loss: 0.0959
Epoch: 005/010 training accuracy: 98.77%
Time elapsed:

### Finetuning With DoRA

In [None]:
# Freeze the original linear weights and biases, then list every parameter's
# requires_grad flag to confirm that the magnitude vectors (m) and the LoRA
# factors remain trainable.
freeze_linear_layers(model_dora)

for name, param in model_dora.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.m: True
layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.m: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.m: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [None]:
# Create a fresh Adam optimizer for the fine-tuning run, train the DoRA model,
# and report its accuracy on the test split.
optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)
train(num_epochs, model_dora, optimizer_dora, train_loader, DEVICE)
print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')

Epoch: 001/010 | Batch 000/938 | Loss: 0.0016
Epoch: 001/010 | Batch 400/938 | Loss: 0.0490
Epoch: 001/010 | Batch 800/938 | Loss: 0.0824
Epoch: 001/010 training accuracy: 99.02%
Time elapsed: 0.35 min
Epoch: 002/010 | Batch 000/938 | Loss: 0.0242
Epoch: 002/010 | Batch 400/938 | Loss: 0.0046
Epoch: 002/010 | Batch 800/938 | Loss: 0.0760
Epoch: 002/010 training accuracy: 99.36%
Time elapsed: 0.71 min
Epoch: 003/010 | Batch 000/938 | Loss: 0.0583
Epoch: 003/010 | Batch 400/938 | Loss: 0.0264
Epoch: 003/010 | Batch 800/938 | Loss: 0.0269
Epoch: 003/010 training accuracy: 99.36%
Time elapsed: 1.06 min
Epoch: 004/010 | Batch 000/938 | Loss: 0.0051
Epoch: 004/010 | Batch 400/938 | Loss: 0.0032
Epoch: 004/010 | Batch 800/938 | Loss: 0.0023
Epoch: 004/010 training accuracy: 99.33%
Time elapsed: 1.42 min
Epoch: 005/010 | Batch 000/938 | Loss: 0.0195
Epoch: 005/010 | Batch 400/938 | Loss: 0.0304
Epoch: 005/010 | Batch 800/938 | Loss: 0.0529
Epoch: 005/010 training accuracy: 99.38%
Time elapsed: