In [156]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.utils.parametrize as parametrize

In [157]:
# CIFAR10 normalization values: per-channel mean and std passed to Normalize.
DATA_DIR = './data'
BATCH_SIZE = 64
SEED = 0

# Seed before the model is initialised and the loaders start shuffling so that
# Restart & Run All reproduces the same run (CUDA kernels may still add some
# nondeterminism).
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616))
])

# Load CIFAR10 dataset
cifar_trainset = datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(cifar_trainset, batch_size=BATCH_SIZE, shuffle=True)

cifar_testset = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(cifar_testset, batch_size=BATCH_SIZE, shuffle=False)

# Define device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [158]:
# CIFAR10 input (32*32*3 = 3072)
class classifier(nn.Module):
    """Three-layer ReLU MLP mapping a flattened image to class logits.

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
        input_size: number of input features after flattening
            (default 32*32*3 = 3072, one CIFAR-10 image).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, hidden_size_1=1000, hidden_size_2=2000, input_size=32*32*3, num_classes=10):
        super().__init__()
        self.input_size = input_size
        self.linear1 = nn.Linear(input_size, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to ``(-1, input_size)`` and return unnormalised logits."""
        # reshape (not view) so non-contiguous inputs, e.g. permuted batches, also work.
        x = img.reshape(-1, self.input_size)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

# Build the baseline network, then move its parameters onto the selected device.
net = classifier()
net = net.to(device)

In [159]:
# Training loop
def train_model(train_loader, net, epochs=5, total_iterations_limit=None, lr=0.001):
    """Train ``net`` with Adam and cross-entropy, showing a running mean loss.

    Args:
        train_loader: iterable of ``(x, y)`` batches.
        net: model whose input is the batch flattened to ``(-1, 32*32*3)``.
        epochs: number of passes over ``train_loader``.
        total_iterations_limit: optional cap on optimizer steps summed over all
            epochs; training returns as soon as the cap is reached (a cap of 0
            takes no steps).
        lr: Adam learning rate.

    Batches are moved to the notebook-global ``device``. Returns None.
    """
    cross_el = nn.CrossEntropyLoss()
    # Only trainable tensors go to the optimizer (frozen non-LoRA weights are skipped).
    trainable = [p for p in net.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)

    total_iterations = 0
    for epoch in range(epochs):
        if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
            return
        net.train()
        loss_sum = 0
        num_iterations = 0

        # The bar total is the number of steps this epoch will actually take:
        # the loader length, capped by the iterations still allowed.
        bar_total = None  # None lets tqdm use len(train_loader)
        if total_iterations_limit is not None:
            remaining = total_iterations_limit - total_iterations
            try:
                bar_total = min(remaining, len(train_loader))
            except TypeError:  # loader without __len__
                bar_total = remaining

        # Context manager closes the bar even on the early return below.
        with tqdm(train_loader, desc=f'Epoch {epoch+1}', total=bar_total) as data_iterator:
            for x, y in data_iterator:
                num_iterations += 1
                total_iterations += 1
                x, y = x.to(device), y.to(device)

                optimizer.zero_grad()
                output = net(x.view(-1, 32*32*3))
                loss = cross_el(output, y)
                loss_sum += loss.item()
                data_iterator.set_postfix(loss=loss_sum / num_iterations)
                loss.backward()
                optimizer.step()

                if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                    return

# Train baseline
train_model(train_loader, net, epochs=2)

Epoch 1: 100%|██████████| 782/782 [00:17<00:00, 45.83it/s, loss=1.69]
Epoch 2: 100%|██████████| 782/782 [00:17<00:00, 44.01it/s, loss=1.49]


In [160]:
# Save original weights: independent copies of every baseline parameter, keyed
# by parameter name. detach() comes before clone() so the copy is not recorded
# in the autograd graph.
original_weights = {
    name: param.detach().clone() for name, param in net.named_parameters()
}

In [161]:
# Test loop
def test_model(model=None, loader=None):
    """Evaluate a classifier and print its accuracy and per-class error counts.

    Args:
        model: network to evaluate; defaults to the notebook-global ``net``.
        loader: iterable of ``(x, y)`` batches; defaults to the notebook-global
            ``test_loader``.

    Prints ``Accuracy: <fraction>`` rounded to 3 decimals, then, for each class
    index, how many samples of that true class were misclassified. Batches are
    moved to the notebook-global ``device``. The model is evaluated in eval
    mode and its previous train/eval mode is restored afterwards. Returns None.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model = net if model is None else model
    loader = test_loader if loader is None else loader

    was_training = model.training
    model.eval()
    correct, total, num_classes = 0, 0, 0
    wrong_labels = []
    try:
        with torch.no_grad():
            for x, y in tqdm(loader, desc='Testing'):
                x, y = x.to(device), y.to(device)
                output = model(x.view(-1, 32*32*3))
                num_classes = max(num_classes, output.shape[1])
                hits = output.argmax(dim=1) == y
                correct += int(hits.sum())
                total += y.numel()
                # True labels of the misclassified samples in this batch.
                wrong_labels.append(y[~hits].cpu())
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError('test loader yielded no samples')
    # Histogram of misclassified true labels = wrong count per class.
    wrong_counts = torch.bincount(torch.cat(wrong_labels), minlength=num_classes).tolist()
    print(f'Accuracy: {round(correct/total, 3)}')
    for i, count in enumerate(wrong_counts):
        print(f'wrong counts for class {i}: {count}')

test_model()

Testing: 100%|██████████| 157/157 [00:03<00:00, 48.03it/s]

Accuracy: 0.471
wrong counts for class 0: 534
wrong counts for class 1: 383
wrong counts for class 2: 746
wrong counts for class 3: 749
wrong counts for class 4: 439
wrong counts for class 5: 643
wrong counts for class 6: 623
wrong counts for class 7: 433
wrong counts for class 8: 289
wrong counts for class 9: 449





In [162]:
# Count parameters of the baseline network (weight + bias of each linear layer)
baseline_layers = [net.linear1, net.linear2, net.linear3]
total_parameters_original = sum(
    lin.weight.nelement() + lin.bias.nelement() for lin in baseline_layers
)
for layer_no, lin in enumerate(baseline_layers, start=1):
    print(f'Layer {layer_no}: W: {lin.weight.shape} + B: {lin.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 3072]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 5,095,010


In [163]:
# LoRA Parametrization
class LoRAParametrization(nn.Module):
    """Low-rank additive update of a 2-D weight, for ``torch.nn.utils.parametrize``.

    The effective weight is ``W + (lora_B @ lora_A) * alpha / rank`` while
    ``enabled`` is True, and ``W`` unchanged otherwise.

    Naming note: ``features_in`` / ``features_out`` are dim 0 / dim 1 of the
    wrapped weight. For an ``nn.Linear`` weight of shape
    ``(out_features, in_features)`` that makes ``features_in`` the layer's
    output size and ``features_out`` its input size. The names are kept for
    backward compatibility.

    Args:
        features_in: size of dim 0 of the wrapped weight (rows of ``lora_B``).
        features_out: size of dim 1 of the wrapped weight (columns of ``lora_A``).
        rank: inner dimension of the low-rank factors.
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.
        device: device on which the LoRA factors are allocated.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Allocate directly on the target device (no CPU temporary + copy).
        # A is Gaussian and B is zero, so B @ A == 0 and the wrapped layer
        # initially computes exactly its original weight.
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        """Return the wrapped weight, plus the scaled LoRA delta when enabled.

        ``original_weights`` is the tensor being parametrized, not the
        notebook-level dict of the same name.
        """
        if not self.enabled:
            return original_weights
        delta = torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
        return original_weights + delta * self.scale

def linear_layer_parameterization(layer, device=None, rank=1, lora_alpha=1):
    """Build a LoRAParametrization sized for ``layer.weight``.

    Args:
        layer: module with a 2-D ``weight`` (e.g. ``nn.Linear``).
        device: where to allocate the LoRA factors; defaults to the device of
            ``layer.weight``.
        rank: LoRA rank.
        lora_alpha: LoRA scaling numerator (scale = lora_alpha / rank).
    """
    # nn.Linear stores its weight as (out_features, in_features); the
    # parametrization takes dim 0 then dim 1 of that weight.
    out_features, in_features = layer.weight.shape
    if device is None:
        device = layer.weight.device
    return LoRAParametrization(
        out_features, in_features, rank=rank, alpha=lora_alpha, device=device
    )

def enable_disable_lora(enabled=True, layers=None):
    """Toggle the first ``weight`` parametrization of each layer on or off.

    Args:
        enabled: value written to ``layer.parametrizations["weight"][0].enabled``.
        layers: modules to toggle; defaults to the three linear layers of the
            notebook-global ``net``.
    """
    if layers is None:
        layers = [net.linear1, net.linear2, net.linear3]
    for layer in layers:
        layer.parametrizations["weight"][0].enabled = enabled



In [164]:
# Register parametrizations: one LoRA module per linear layer. The guard keeps
# this cell idempotent, so re-running it does not stack a second LoRA on a layer.
lora_layers = [net.linear1, net.linear2, net.linear3]
for layer in lora_layers:
    if not parametrize.is_parametrized(layer, "weight"):
        parametrize.register_parametrization(
            layer, "weight", linear_layer_parameterization(layer, device)
        )

# Count parameters with LoRA
total_parameters_lora, total_parameters_non_lora = 0, 0
for index, layer in enumerate(lora_layers):
    lora = layer.parametrizations["weight"][0]
    total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    # layer.weight is the parametrized weight; it has the original weight's shape.
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + '
          f'Lora_A: {lora.lora_A.shape} + '
          f'Lora_B: {lora.lora_B.shape}')

assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_increment:.3f}%')

Layer 1: W: torch.Size([1000, 3072]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 3072]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 5,095,010
Total number of parameters (original + LoRA): 5,104,092
Parameters introduced by LoRA: 9,082
Parameters incremment: 0.178%


In [165]:
# Freeze non-LoRA params so only lora_A / lora_B receive gradient updates.
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Fine-tune only on a single CIFAR10 class. CIFAR-10 label 5 is "dog";
# set FINETUNE_LABEL = 3 to fine-tune on "cat" instead.
FINETUNE_LABEL = 5
finetune_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
all_targets = torch.tensor(finetune_trainset.targets)
keep_mask = all_targets == FINETUNE_LABEL  # True for the samples we keep
finetune_trainset.data = finetune_trainset.data[keep_mask.numpy()]
# Keep targets a plain list of ints, as in the unfiltered dataset.
finetune_trainset.targets = all_targets[keep_mask].tolist()
# New names so the full-dataset cifar_trainset / train_loader from the
# baseline cells are not clobbered.
finetune_loader = torch.utils.data.DataLoader(finetune_trainset, batch_size=64, shuffle=True)

# Train LoRA on that single class only (at most 100 optimizer steps)
train_model(finetune_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:  79%|███████▉  | 79/100 [00:01<00:00, 40.45it/s, loss=0.334]


In [166]:
# Check frozen params unchanged, and that enabling/disabling LoRA yields the
# expected effective weight, for every parametrized layer.
# torch.equal also compares shapes, unlike torch.all(a == b), which broadcasts.
for layer_name in ('linear1', 'linear2', 'linear3'):
    layer = getattr(net, layer_name)
    saved_weight = original_weights[f'{layer_name}.weight']
    assert torch.equal(layer.parametrizations.weight.original, saved_weight)

    lora = layer.parametrizations.weight[0]
    # Enabled: W + (B @ A) * scale. Disabled: the untouched original W.
    enable_disable_lora(enabled=True)
    assert torch.equal(
        layer.weight,
        layer.parametrizations.weight.original + (lora.lora_B @ lora.lora_A) * lora.scale,
    )
    enable_disable_lora(enabled=False)
    assert torch.equal(layer.weight, saved_weight)

print("\nWith LoRA finetuning")
# Test LoRA enabled
enable_disable_lora(enabled=True)
test_model()

print("\nWithout LoRA finetuning")
# Test LoRA disabled
enable_disable_lora(enabled=False)
test_model()


With LoRA finetuning


Testing: 100%|██████████| 157/157 [00:04<00:00, 37.05it/s]


Accuracy: 0.139
wrong counts for class 0: 987
wrong counts for class 1: 923
wrong counts for class 2: 1000
wrong counts for class 3: 998
wrong counts for class 4: 994
wrong counts for class 5: 9
wrong counts for class 6: 995
wrong counts for class 7: 939
wrong counts for class 8: 876
wrong counts for class 9: 888

Without LoRA finetuning


Testing: 100%|██████████| 157/157 [00:03<00:00, 48.67it/s]

Accuracy: 0.471
wrong counts for class 0: 534
wrong counts for class 1: 383
wrong counts for class 2: 746
wrong counts for class 3: 749
wrong counts for class 4: 439
wrong counts for class 5: 643
wrong counts for class 6: 623
wrong counts for class 7: 433
wrong counts for class 8: 289
wrong counts for class 9: 449



