<a href="https://colab.research.google.com/github/thebhulawat/DPO/blob/main/MNIST_Lora_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [50]:
from itertools import islice

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

In [51]:
# Seed PyTorch's global random number generator so runs are repeatable.
# The Generator returned by manual_seed is bound to `_` so the cell shows no output.
_ = torch.manual_seed(42)

In [52]:
# Normalisation constants applied to every image after conversion to a tensor
# (presumably the MNIST mean / standard deviation -- TODO confirm).
NORM_MEAN = 0.1307
NORM_STD = 0.3081
# Number of samples per batch for both the training and the test loader.
BATCH_SIZE = 10

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((NORM_MEAN,), (NORM_STD,)),
])

# Training split (downloaded into ./data on first use) and its shuffled loader.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=BATCH_SIZE, shuffle=True)

# Test split and its (also shuffled) loader.
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=BATCH_SIZE, shuffle=True)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [66]:
class Net(nn.Module):
  """Fully connected classifier with two hidden ReLU layers.

  The input is flattened to 28*28 features and mapped through
  ``linear1 -> ReLU -> linear2 -> ReLU -> linear3`` to 10 output scores.

  Args:
    hidden_size_1: width of the first hidden layer.
    hidden_size_2: width of the second hidden layer.
  """

  def __init__(self, hidden_size_1 = 1000, hidden_size_2 = 2000):
    super().__init__()
    input_size = 28 * 28
    num_classes = 10
    self.linear1 = nn.Linear(input_size, hidden_size_1)
    self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2, num_classes)
    self.relu = nn.ReLU()

  def forward(self, img):
    """Return the raw (un-normalised) class scores for a batch of images."""
    # Collapse every sample into a row of 784 values before the first linear layer.
    flat = img.view(-1, 28 * 28)
    hidden = self.relu(self.linear1(flat))
    hidden = self.relu(self.linear2(hidden))
    return self.linear3(hidden)

# Build the network with its default hidden sizes and move it to the selected device.
net = Net().to(device)

In [67]:
def train(train_loader, net, epochs = 5, total_iterations_limit = None):
  """Train ``net`` with Adam (lr=0.001) on a cross-entropy objective.

  Batches are moved to the module-level ``device`` before the forward pass and
  a tqdm bar shows the running average loss of the current epoch.

  Args:
    train_loader: iterable yielding ``(x, y)`` batches.
    net: module optimised in place; it is put into training mode at the start
      of every epoch.
    epochs: maximum number of passes over ``train_loader``.
    total_iterations_limit: optional cap on the number of batches processed
      across all epochs. ``None`` means no cap; a value of 0 or less
      processes no batches at all.

  Returns:
    None.
  """
  cross_el = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
  total_iterations = 0
  for epoch in range(epochs):
    batches = train_loader
    bar_total = None  # None lets tqdm infer the total from len(train_loader)
    if total_iterations_limit is not None:
      remaining = total_iterations_limit - total_iterations
      if remaining <= 0:
        return
      # Bound the epoch up front instead of returning from inside the tqdm
      # loop, so the progress bar finishes at its real total.
      batches = islice(train_loader, remaining)
      try:
        bar_total = min(len(train_loader), remaining)
      except TypeError:
        # Loader without a length: the remaining budget is the best estimate.
        bar_total = remaining

    net.train()
    loss_sum = 0
    num_iterations = 0
    data_iterator = tqdm(batches, desc = f'Epoch {epoch + 1}', total = bar_total)
    for data in data_iterator:
      num_iterations += 1
      total_iterations += 1
      x,y = data
      x = x.to(device)
      y = y.to(device)
      optimizer.zero_grad()
      output = net(x.view(-1, 28*28))
      loss = cross_el(output, y)
      loss_sum += loss.item()
      avg_loss = loss_sum / num_iterations
      data_iterator.set_postfix(loss = avg_loss)
      loss.backward()
      optimizer.step()

    if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
      return

# Train the network for a single epoch on the batches of `train_loader`.
train(train_loader, net, epochs = 1)


Epoch 1: 100%|██████████| 592/592 [00:04<00:00, 128.95it/s, loss=0.00414]


In [68]:
# Snapshot a detached copy of every parameter so later cells can verify that
# the pre-trained values were not modified.
original_weights = {
    param_name: param_value.clone().detach()
    for param_name, param_value in net.named_parameters()
}

# Preview up to the first 10 stored tensors, showing the first two rows/entries of each.
preview = list(original_weights.items())[:10]
for name, weight in preview:
    print(f"Layer: {name}, Weight Sample: {weight[:2]}")

Layer: linear1.weight, Weight Sample: tensor([[-0.0325, -0.0392, -0.0175,  ..., -0.0289, -0.0327, -0.0400],
        [ 0.0364,  0.0295, -0.0269,  ...,  0.0309,  0.0356, -0.0030]],
       device='cuda:0')
Layer: linear1.bias, Weight Sample: tensor([0.0140, 0.0215], device='cuda:0')
Layer: linear2.weight, Weight Sample: tensor([[-0.0052,  0.0025, -0.0201,  ...,  0.0357,  0.0350, -0.0062],
        [ 0.0308,  0.0264, -0.0237,  ...,  0.0220,  0.0358,  0.0084]],
       device='cuda:0')
Layer: linear2.bias, Weight Sample: tensor([ 2.7604e-05, -1.0695e-02], device='cuda:0')
Layer: linear3.weight, Weight Sample: tensor([[ 0.0005,  0.0050, -0.0277,  ..., -0.0145, -0.0197,  0.0007],
        [-0.0217,  0.0138, -0.0176,  ..., -0.0266, -0.0002,  0.0093]],
       device='cuda:0')
Layer: linear3.bias, Weight Sample: tensor([-0.0034, -0.0042], device='cuda:0')


In [69]:
def test():
    """Evaluate the global ``net`` on ``test_loader`` and print the results.

    Prints the overall accuracy (rounded to 3 decimals) followed by, for each
    digit 0-9, how many samples of that digit were misclassified. The network
    is evaluated in eval mode without gradient tracking; its previous
    train/eval mode is restored afterwards.

    Returns:
        None.
    """
    correct = 0
    total = 0

    # Misclassification counter per true label, accumulated on the CPU.
    wrong_counts = torch.zeros(10, dtype=torch.long)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for data in tqdm(test_loader, desc='Testing'):
                x, y = data
                x = x.to(device)
                y = y.to(device)
                # Score the whole batch at once instead of looping sample by
                # sample (which forces a device sync for every comparison).
                predictions = net(x.view(-1, 784)).argmax(dim=1)
                is_wrong = predictions != y
                batch_size = y.numel()
                total += batch_size
                correct += batch_size - int(is_wrong.sum())
                wrong_counts += torch.bincount(y[is_wrong], minlength=10).cpu()
    finally:
        net.train(was_training)

    # An empty loader would otherwise raise ZeroDivisionError.
    accuracy = correct / total if total else float('nan')
    print(f'Accuracy: {round(accuracy, 3)}')
    for i, count in enumerate(wrong_counts.tolist()):
        print(f'wrong counts for the digit {i}: {count}')

# Evaluate `net` on `test_loader` and print accuracy plus per-digit error counts.
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 305.99it/s]

Accuracy: 0.096
wrong counts for the digit 0: 980
wrong counts for the digit 1: 1135
wrong counts for the digit 2: 1032
wrong counts for the digit 3: 1010
wrong counts for the digit 4: 982
wrong counts for the digit 5: 892
wrong counts for the digit 6: 0
wrong counts for the digit 7: 1028
wrong counts for the digit 8: 974
wrong counts for the digit 9: 1009





In [70]:
# Report the weight and bias shape of each linear layer, then the total
# number of scalar parameters they hold together.
linear_layers = [net.linear1, net.linear2, net.linear3]

for index, layer in enumerate(linear_layers):
    print(f'Layer {index + 1}: W: {layer.weight.shape} + B: {layer.bias.shape}')

total_parameters_original = sum(
    linear.weight.nelement() + linear.bias.nelement() for linear in linear_layers
)

print(f'Total number of parameters: {total_parameters_original}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2807010


In [71]:
class LoRAParametrization(nn.Module):
    """Low-rank additive update ``W + (B @ A) * (alpha / rank)`` for a weight matrix.

    Args:
        features_in: number of rows of ``lora_B``. Callers in this notebook pass
            the first dimension of the wrapped weight's shape.
        features_out: number of columns of ``lora_A``. Callers in this notebook
            pass the second dimension of the wrapped weight's shape.
        rank: inner dimension shared by ``lora_B`` and ``lora_A``.
        alpha: numerator of the scaling factor ``alpha / rank``.
        device: device on which both matrices are created.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()

        # Section 4.1 of the paper: A gets a random Gaussian initialisation and
        # B starts at zero, so the update B @ A is zero before any training.
        a_init = torch.zeros((rank, features_out), device=device)
        b_init = torch.zeros((features_in, rank), device=device)
        self.lora_A = nn.Parameter(a_init)
        self.lora_B = nn.Parameter(b_init)
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper: the update is scaled by alpha / r, which
        # reduces the need to retune hyperparameters when r is varied.
        self.scale = alpha / rank
        # When False, forward() returns the wrapped weight untouched.
        self.enabled = True

    def forward(self, original_weights):
        """Return the adapted weight, or ``original_weights`` itself when disabled."""
        if not self.enabled:
            return original_weights

        low_rank_update = (self.lora_B @ self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_update * self.scale

In [72]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    """Build a ``LoRAParametrization`` sized for ``layer.weight``.

    Only the weight matrix is adapted; the bias is left alone. Section 4.2 of
    the LoRA paper likewise leaves the study of biases to future work.

    Args:
        layer: module whose 2-D ``weight`` shape determines the LoRA matrices.
        device: device on which the LoRA matrices are created.
        rank: rank of the low-rank update.
        lora_alpha: numerator of the ``alpha / rank`` scaling factor.

    Returns:
        A new, unregistered ``LoRAParametrization`` instance.
    """
    # The two dimensions of the weight are forwarded in order as the
    # `features_in` / `features_out` arguments of LoRAParametrization.
    weight_rows, weight_cols = layer.weight.shape
    lora = LoRAParametrization(
        weight_rows, weight_cols, rank=rank, alpha=lora_alpha, device=device
    )
    return lora

# Attach a LoRA parametrization to the weight of each of the three linear
# layers, in order, so that reading `layer.weight` goes through LoRA.
for linear_layer in (net.linear1, net.linear2, net.linear3):
    parametrize.register_parametrization(
        linear_layer, "weight", linear_layer_parameterization(linear_layer, device)
    )


def enable_disable_lora(enabled=True):
    """Switch the LoRA update of all three linear layers of ``net`` on or off.

    Args:
        enabled: value written to the ``enabled`` flag of the first weight
            parametrization of ``net.linear1``, ``net.linear2`` and ``net.linear3``.
    """
    lora_modules = (
        net.linear1.parametrizations["weight"][0],
        net.linear2.parametrizations["weight"][0],
        net.linear3.parametrizations["weight"][0],
    )
    for lora_module in lora_modules:
        lora_module.enabled = enabled

In [73]:
# Compare the number of parameters added by LoRA with the size of the
# original network, and check the original count did not change.
lora_layers = [net.linear1, net.linear2, net.linear3]

for index, layer in enumerate(lora_layers):
    adapter = layer.parametrizations["weight"][0]
    print(f'Layer {index + 1}: W: {layer.weight.shape} + A: {adapter.lora_A.shape} + B: {adapter.lora_B.shape}')

total_parameters_lora = sum(
    wrapped.parametrizations["weight"][0].lora_A.nelement()
    + wrapped.parametrizations["weight"][0].lora_B.nelement()
    for wrapped in lora_layers
)
total_parameters_non_lora = sum(
    wrapped.weight.nelement() + wrapped.bias.nelement() for wrapped in lora_layers
)

# Registering the parametrizations must not alter the size of weight or bias.
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
# Relative size of the LoRA parameters, as a percentage of the original count.
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')


Layer 1: W: torch.Size([1000, 784]) + A: torch.Size([1, 784]) + B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + A: torch.Size([1, 1000]) + B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + A: torch.Size([1, 2000]) + B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [74]:
# Freeze every parameter whose name does not contain 'lora', so that only the
# LoRA matrices receive gradient updates during fine-tuning.
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Digit whose samples are used to fine-tune the LoRA matrices.
FINETUNE_DIGIT = 6

# Load a fresh copy of the MNIST training set and keep only the samples of
# FINETUNE_DIGIT. Dedicated names are used so the full-dataset
# `mnist_trainset` / `train_loader` stay intact; overwriting them would make
# the initial training cell run on this single digit if it were executed again.
mnist_finetune_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
keep_mask = mnist_finetune_set.targets == FINETUNE_DIGIT
mnist_finetune_set.data = mnist_finetune_set.data[keep_mask]
mnist_finetune_set.targets = mnist_finetune_set.targets[keep_mask]
# Create a dataloader for the fine-tuning run
finetune_loader = torch.utils.data.DataLoader(mnist_finetune_set, batch_size=10, shuffle=True)

# Train the network with LoRA only on FINETUNE_DIGIT and only for 100 batches
# (hoping that it would improve the performance on that digit)
train(finetune_loader, net, epochs=1, total_iterations_limit=100)



Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 132.12it/s, loss=0]


In [75]:
checked_layers = ('linear1', 'linear2', 'linear3')

# The frozen original weights must still equal the snapshot taken earlier.
for layer_name in checked_layers:
    frozen_weight = getattr(net, layer_name).parametrizations.weight.original
    assert torch.all(frozen_weight == original_weights[f'{layer_name}.weight'])

enable_disable_lora(enabled=True)

# With LoRA enabled, `weight` is the original weight plus the scaled
# low-rank product lora_B @ lora_A.
lora1 = net.linear1.parametrizations.weight[0]
expected_linear1_weight = (
    net.linear1.parametrizations.weight.original
    + (lora1.lora_B @ lora1.lora_A) * lora1.scale
)
assert torch.equal(net.linear1.weight, expected_linear1_weight)

# With LoRA disabled, `weight` is exactly the snapshotted original weight.
enable_disable_lora(enabled=False)
for layer_name in checked_layers:
    assert torch.equal(getattr(net, layer_name).weight, original_weights[f'{layer_name}.weight'])


In [76]:
# Evaluate with the LoRA update switched on (weights are W + scaled B @ A).
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 247.75it/s]

Accuracy: 0.096
wrong counts for the digit 0: 980
wrong counts for the digit 1: 1135
wrong counts for the digit 2: 1032
wrong counts for the digit 3: 1010
wrong counts for the digit 4: 982
wrong counts for the digit 5: 892
wrong counts for the digit 6: 0
wrong counts for the digit 7: 1028
wrong counts for the digit 8: 974
wrong counts for the digit 9: 1009





In [77]:
# Evaluate with the LoRA update switched off (the original weights are used as-is).
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 254.85it/s]

Accuracy: 0.096
wrong counts for the digit 0: 980
wrong counts for the digit 1: 1135
wrong counts for the digit 2: 1032
wrong counts for the digit 3: 1010
wrong counts for the digit 4: 982
wrong counts for the digit 5: 892
wrong counts for the digit 6: 0
wrong counts for the digit 7: 1028
wrong counts for the digit 8: 974
wrong counts for the digit 9: 1009



