# LoRA Implementation with PyTorch

In [2]:
import torch
import torch.utils
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt 
import torch.nn.utils.parametrize as parametrize
from tqdm import tqdm



Make the model deterministic

In [3]:
_ = torch.manual_seed(0)

Train a network to classift MNIST digits then fine-tune the network on a specific digit in which it did not learn well.

In [4]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

In [5]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download= True, transform= transform)

# Create a dataloader for training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load MNIST Test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

In [6]:
torch.cuda.is_available()

False

In [7]:
# Define Device ---Check to see if possible to train on gpu instead of CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Use Metal Performance Shaders
else:
    device = torch.device("cpu")  # Fallback to CPU

In [7]:
device

device(type='mps')

Create a neural network to classify digits

In [8]:
class ClassifyDigits(nn.Module):
    def __init__(self, hidden_size_1= 1000, hidden_size_2= 2000):
        super(ClassifyDigits, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        
    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
    
        return x
        
net = ClassifyDigits().to(device)

Train the network for only 1 EPOCH to simulate complete pre-training on the data

In [9]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr= 0.001)
    
    total_iterations = 0
    
    for epoch in range(epochs):
        net.train()
        
        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
            
        for data in data_iterator:
            num_iterations +=1 
            total_iterations +=1
            
            x, y = data
            x = x.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
            output = net(x)
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss= avg_loss)
            loss.backward()
            optimizer.step()
            
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return         

In [10]:
train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:38<00:00, 156.97it/s, loss=0.238]


Keep a copy of the original weights to prove fine-tuning with LoRA doesn't impact original weights.

In [11]:
original_weights = {}
for name, param, in net.named_parameters():
    original_weights[name] = param.clone().detach()

Test the performence of the pretrained network. The network performs poorly on digit 9. Let's fine tune on digit 9 to improve performence. 

In [12]:
def test():
    correct = 0
    total = 0
    
    wrong_counts = [0 for i in range(10)]
    
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Test'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

In [13]:
test()

Test: 100%|██████████| 1000/1000 [00:04<00:00, 222.43it/s]

Accuracy: 0.955
wrong counts for the digit 0: 36
wrong counts for the digit 1: 23
wrong counts for the digit 2: 48
wrong counts for the digit 3: 69
wrong counts for the digit 4: 18
wrong counts for the digit 5: 15
wrong counts for the digit 6: 86
wrong counts for the digit 7: 47
wrong counts for the digit 8: 41
wrong counts for the digit 9: 71





Let's visualize how many parameters are in the original network, before introducing the LoRA matrices. 

In [14]:
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')

print(f'Total number of parameters: {total_parameters_original}') ## i want to see what the output looks like first

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2807010


Define LoRA Parameterization in the paper.

In [34]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        """
        LoRA Parameterization for low-rank adaptation of linear layers.

        Args:
            features_in (int): Number of input features.
            features_out (int): Number of output features.
            rank (int): Rank for the low-rank decomposition.
            alpha (float): Scaling factor for the LoRA update.
            device (str): Device to place the parameters ('cpu' or 'cuda').
        """
        super().__init__()
        
        # Initialize low-rank matrices A and B
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        
        # Gaussian initialization for both A and B
        nn.init.normal_(self.lora_A, mean=0, std=0.01)  # Smaller std for stability
        nn.init.normal_(self.lora_B, mean=0, std=0.01)
        
        self.scale = alpha / rank  # Scaling factor
        self.enabled = True  # Toggle for enabling/disabling LoRA
    
    def forward(self, original_weights):
        """
        Applies the LoRA parameterization to the original weights.

        Args:
            original_weights (Tensor): The original weights of the layer.

        Returns:
            Tensor: Updated weights with LoRA adjustments.
        """
        if self.enabled:
            # Compute LoRA update and add to original weights
            lora_update = torch.matmul(self.lora_B, self.lora_A)  # (features_in, features_out)
            return original_weights + lora_update.view(original_weights.shape) * self.scale
        else:
            return original_weights

Add parameterization to out network

In [16]:
def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )


In [17]:
parametrize.register_parametrization(
    net.linear1, 'weight', linear_layer_parameterization(net.linear1, device)
)

parametrize.register_parametrization(
    net.linear2, 'weight', linear_layer_parameterization(net.linear2, device)
)

parametrize.register_parametrization(
    net.linear3, 'weight', linear_layer_parameterization(net.linear3, device)
)

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParametrization()
    )
  )
)

In [35]:
def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        # Corrected 'parametrizations' spelling
        layer.parametrizations['weight'][0].enabled = enabled

In [23]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += (
        layer.parametrizations['weight'][0].lora_A.nelement() 
        + layer.parametrizations['weight'][0].lora_B.nelement()
    )
    
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    
    print(
        f"Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + "
        f"Lora_A: {layer.parametrizations['weight'][0].lora_A.shape} + "
        f"Lora_B: {layer.parametrizations['weight'][0].lora_B.shape}"
    )

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])


In [25]:
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremenet: {parameters_increment:.3f}%')

Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremenet: 0.242%


Freeze all non-LoRA parameters. Fine-tune on digit 9 for 100 batches.

In [26]:
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-lora parameters {name}')
        param.requires_grad = False

Freezing non-lora parameters linear1.bias
Freezing non-lora parameters linear1.parametrizations.weight.original
Freezing non-lora parameters linear2.bias
Freezing non-lora parameters linear2.parametrizations.weight.original
Freezing non-lora parameters linear3.bias
Freezing non-lora parameters linear3.parametrizations.weight.original


In [27]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download= True, transform= transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
train(train_loader, net, epochs=1, total_iterations_limit=100)

Epoch 1:  99%|█████████▉| 99/100 [00:02<00:00, 46.27it/s, loss=0.0823]


Verify that fine-tuning has not altered the original weights, but only the ones introduced by LoRA.

In [30]:
# Check frozen parameters are still unchanged by fine-tuning
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

In [36]:
enable_disable_lora(enabled=True)
# The new linear1.weights is obtained by the forward function of our LoRA parametization
# The original weights have been moved to net.linear1.parametrizations.weight.original
assert torch.equal(net.linear1.weight,
                   net.linear1.parametrizations.weight.original
                   + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A).view(net.linear1.weight.shape)
)

# If we disable LoRA, linear1.weight is the original
enable_disable_lora(enabled=False)
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

Test the network with LoRA ENABLED. The digit 9 accuracy should improve.

In [37]:
enable_disable_lora(enabled=True)
test()

Test: 100%|██████████| 1000/1000 [00:05<00:00, 192.76it/s]

Accuracy: 0.929
wrong counts for the digit 0: 68
wrong counts for the digit 1: 34
wrong counts for the digit 2: 65
wrong counts for the digit 3: 88
wrong counts for the digit 4: 91
wrong counts for the digit 5: 20
wrong counts for the digit 6: 112
wrong counts for the digit 7: 100
wrong counts for the digit 8: 114
wrong counts for the digit 9: 13





Test with LoRA DISABLED

In [38]:
enable_disable_lora(enabled=False)
test()

Test: 100%|██████████| 1000/1000 [00:04<00:00, 223.02it/s]

Accuracy: 0.955
wrong counts for the digit 0: 36
wrong counts for the digit 1: 23
wrong counts for the digit 2: 48
wrong counts for the digit 3: 69
wrong counts for the digit 4: 18
wrong counts for the digit 5: 15
wrong counts for the digit 6: 86
wrong counts for the digit 7: 47
wrong counts for the digit 8: 41
wrong counts for the digit 9: 71



