In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import torch.optim as optimizer
import timm
import matplotlib.pyplot as plt
import numpy as np

In [43]:
# Preprocessing pipeline applied to every FashionMNIST image:
#   1. upscale to 224 px (the input size the ViT expects),
#   2. replicate the single grey channel into 3 channels,
#   3. convert to a tensor,
#   4. normalise each channel with the ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Train / test splits share the same on-disk root and the same transform.
train_ds = datasets.FashionMNIST(
    root="./fashion_data", train=True, transform=transform, download=True
)
test_ds = datasets.FashionMNIST(
    root="./fashion_data", train=False, transform=transform, download=True
)

# Only the training stream is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=64)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=264)

# `train_loader.dataset` is `train_ds`, so read the label names from it directly.
num_classes = len(train_ds.classes)

In [None]:
# checks: a notebook cell only displays its *last* expression, so return the
# raw training-tensor shape and the class count together as one tuple.
train_loader.dataset.data.shape, num_classes

152

## Load pretrained ViT model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

vit_model = models.vit_b_16(weights='DEFAULT')
# eval() only switches the existing submodules to inference behaviour
# (e.g. disables dropout); it does NOT freeze any weights. Freezing is
# done explicitly via requires_grad below.
vit_model.eval()

# Freeze all parameters
for param in vit_model.parameters():
    param.requires_grad = False

n_features = vit_model.heads.head.in_features
print(f'there are {n_features} number of input parameters in the last layer')

# Replace classification head. The freshly created layers are trainable
# (requires_grad=True by default), unlike the frozen backbone above.
vit_model.heads.head = nn.Sequential(
    nn.Linear(n_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes),
)

for i in [-1,-2]: # Train last 2 layers besides the final classification layer.
    for param in vit_model.encoder.layers[i].parameters():
        param.requires_grad = True

# Move the model to the target device only AFTER the head has been replaced:
# the new nn.Sequential is created on the CPU, so calling .to(device) earlier
# would leave the head behind and cause a device mismatch when running on CUDA.
vit_model.to(device)

# # Safety check
# print("Trainable parameters in the model:")
# for name, param in vit_model.named_parameters():
#     if param.requires_grad:
#         print(f"  {name}")

Using device: cpu
there are 768 number of output parameters in the last layer. So unfreeze it to learn the 10 in fashion mnist
Trainable parameters in the model:
  encoder.layers.encoder_layer_10.ln_1.weight
  encoder.layers.encoder_layer_10.ln_1.bias
  encoder.layers.encoder_layer_10.self_attention.in_proj_weight
  encoder.layers.encoder_layer_10.self_attention.in_proj_bias
  encoder.layers.encoder_layer_10.self_attention.out_proj.weight
  encoder.layers.encoder_layer_10.self_attention.out_proj.bias
  encoder.layers.encoder_layer_10.ln_2.weight
  encoder.layers.encoder_layer_10.ln_2.bias
  encoder.layers.encoder_layer_10.mlp.0.weight
  encoder.layers.encoder_layer_10.mlp.0.bias
  encoder.layers.encoder_layer_10.mlp.3.weight
  encoder.layers.encoder_layer_10.mlp.3.bias
  encoder.layers.encoder_layer_11.ln_1.weight
  encoder.layers.encoder_layer_11.ln_1.bias
  encoder.layers.encoder_layer_11.self_attention.in_proj_weight
  encoder.layers.encoder_layer_11.self_attention.in_proj_bias
  en

## Simple BiLoRA layer Class

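$$W_t = W_0 + \Delta W_t$$
where
$$\Delta W_t = U B_t V^T$$

Here $W_0$ is the base weight and $B_t$ is a sparse matrix holding the learnable coefficients of task $t$. The implementation below computes $\Delta W_t$ as the real part of a 2-D inverse FFT of $B_t$, scaled by `alpha`.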

In [None]:
class BiLoRALinear(nn.Module):
    """Linear layer whose weight is a base matrix plus a per-task spectral update.

    For task ``t`` the effective weight is ``base_weight + delta_w_t``, where
    ``delta_w_t`` is the real part of the 2-D inverse FFT of a sparse
    ``(out_dim, in_dim)`` matrix holding ``n_frq`` learnable coefficients,
    multiplied by ``alpha``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        n_frq: Number of learnable frequency coefficients per task. Must not
            exceed ``in_dim * out_dim``.
        alpha: Scalar applied to the inverse-FFT result.
        n_tasks: Number of tasks; one coefficient vector and one set of
            frequency positions is created per task.

    Raises:
        ValueError: If ``n_frq`` is larger than ``in_dim * out_dim``.
    """

    def __init__(self, in_dim, out_dim, n_frq=100, alpha=100.0, n_tasks=3):
        super().__init__()
        # Fail early: with more coefficients than matrix entries the position
        # list would be shorter than the coefficient vector and the index
        # assignment in `_get_delta_w` would fail at forward time.
        if n_frq > in_dim * out_dim:
            raise ValueError(
                f"n_frq ({n_frq}) cannot exceed in_dim * out_dim ({in_dim * out_dim})"
            )

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_frq = n_frq
        self.alpha = alpha
        self.n_tasks = n_tasks

        # Base weight and bias.
        # NOTE(review): the original comment called this the "frozen backbone",
        # but nn.Parameter defaults to requires_grad=True, so both tensors are
        # trainable unless the caller freezes them. Behaviour is left as-is.
        self.base_weight = nn.Parameter(torch.randn(out_dim, in_dim) * 0.02)
        self.base_bias = nn.Parameter(torch.zeros(out_dim))

        # Per-task frequency coefficients
        self.coef = nn.ParameterList([
            nn.Parameter(torch.randn(n_frq), requires_grad=True)
            for _ in range(n_tasks)
        ])
        # Per-task (2, n_frq) row/column positions of those coefficients.
        self.indices = [self._select_positions(t) for t in range(n_tasks)]

    def _select_positions(self, t, seed=777):
        """Pick ``n_frq`` distinct (row, col) positions for task ``t``.

        The choice is deterministic in ``seed + t``. A private generator is
        used instead of ``torch.manual_seed`` so that building a layer does not
        reseed the global RNG and alter every random draw made afterwards.

        Returns:
            Long tensor of shape ``(2, n_frq)``: row indices, then column indices.
        """
        gen = torch.Generator().manual_seed(seed + t)
        idx = torch.randperm(self.out_dim * self.in_dim, generator=gen)[:self.n_frq]
        # Unravel the flat index into (row, col) of an (out_dim, in_dim) matrix.
        return torch.stack([idx // self.in_dim, idx % self.in_dim])

    def _get_delta_w(self, task):
        """Build the weight update for ``task`` from its frequency coefficients."""
        coef = self.coef[task]
        rows, cols = self.indices[task]
        # frequency-domain delta; match the coefficient dtype so the layer keeps
        # working after .double() / .half() (index assignment needs equal dtypes).
        Freq = torch.zeros(
            self.out_dim, self.in_dim,
            dtype=coef.dtype, device=self.base_weight.device,
        )
        Freq[rows, cols] = coef
        delta_w = torch.fft.ifft2(Freq).real * self.alpha
        return delta_w

    def forward(self, x, task_id=0):
        """Apply the linear map using the effective weight of ``task_id``."""
        delta_w = self._get_delta_w(task_id)
        w_eff = self.base_weight + delta_w
        return F.linear(x, w_eff, self.base_bias)
