<a href="https://colab.research.google.com/github/weagan/Convolutional-Neural-Networks/blob/main/forgetting_continual_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

## 1. Standard LoRA (Stage 1)
class SimpleLoRALinear(nn.Module):
    def __init__(self, weight, rank=16):
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.clone())
        self.lora_A = nn.Parameter(torch.randn(rank, self.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(self.out_features, rank))
        self.scaling = 0.05

    def forward(self, x):
        delta_w = self.lora_B @ self.lora_A
        return x @ (self.weight + delta_w * self.scaling).T

## 2. Subspace LoRA (Stage 2 & 3)
class SubspaceLoRALinear(nn.Module):
    def __init__(self, weight, rank=16):
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.clone())
        self.rank = rank
        self.register_buffer("basis", torch.zeros(rank, self.out_features * self.in_features))
        self.coefficients = nn.Parameter(torch.zeros(rank))
        self.scaling = 0.05

    def forward(self, x):
        delta_w = (self.coefficients @ self.basis).view(self.out_features, self.in_features)
        return x @ (self.weight + delta_w * self.scaling).T

## 3. Manager
class ContinualSubspaceManager:
    def __init__(self, model_name="distilbert-base-uncased", rank=16):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.rank = rank
        self.task_coefficients = {}

        for param in self.model.parameters():
            param.requires_grad = False

    def inject_standard_lora(self):
        for name, module in self.model.named_modules():
            if any(tgt in name for tgt in ["attention.out_lin", "attention.v_lin"]):
                parent = self._get_parent(name)
                setattr(parent, name.split('.')[-1], SimpleLoRALinear(module.weight, self.rank))

    def convert_to_subspace(self):
        with torch.no_grad():
            for name, module in list(self.model.named_modules()):
                if isinstance(module, SimpleLoRALinear):
                    delta_w = module.lora_B @ module.lora_A
                    U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
                    new_module = SubspaceLoRALinear(module.weight, self.rank).to(self.device)
                    for r in range(self.rank):
                        new_module.basis[r] = torch.outer(U[:, r], Vh[r, :]).flatten()
                    setattr(self._get_parent(name), name.split('.')[-1], new_module)

    def _get_parent(self, name):
        parent = self.model
        for part in name.split('.')[:-1]:
            parent = getattr(parent, part)
        return parent

    def train_task(self, task_name, loader, is_subspace=True, epochs=1):
        print(f"\n🔥 Training {task_name} ({'Subspace' if is_subspace else 'Standard LoRA'})")
        keyword = "coefficients" if is_subspace else "lora_"
        params = [p for n, p in self.model.named_parameters() if keyword in n]
        optimizer = optim.AdamW(params, lr=1e-3)

        self.model.train()
        for epoch in range(epochs):
            for batch in tqdm(loader, leave=False):
                optimizer.zero_grad()
                labels = batch.pop("labels").to(self.device)
                inputs = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**inputs, labels=labels)
                outputs.loss.backward()
                optimizer.step()

        if is_subspace:
            self.task_coefficients[task_name] = {n: p.clone().detach() for n, p in self.model.named_parameters() if "coefficients" in n}

    def evaluate_task(self, loader):
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batch in loader:
                labels = batch.pop("labels").to(self.device)
                inputs = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**inputs)
                correct += (outputs.logits.argmax(-1) == labels).sum().item()
                total += labels.size(0)
        return correct / total

    def load_task_weights(self, task_name):
        if task_name in self.task_coefficients:
            with torch.no_grad():
                for n, p in self.model.named_parameters():
                    if n in self.task_coefficients[task_name]:
                        p.copy_(self.task_coefficients[task_name][n])

## 4. Execution Logic
def get_loader(task, tokenizer):
    ds = load_dataset("glue", task)
    def tokenize_fn(ex):
        t = (ex["sentence"],) if "sentence" in ex else (ex["sentence1"], ex["sentence2"])
        res = tokenizer(*t, truncation=True, padding=False)
        res["labels"] = ex["label"]
        return res
    tok = ds.map(tokenize_fn, batched=True, remove_columns=ds["train"].column_names)
    return DataLoader(tok["train"], batch_size=16, shuffle=True, collate_fn=DataCollatorWithPadding(tokenizer)), \
           DataLoader(tok["validation"], batch_size=16, collate_fn=DataCollatorWithPadding(tokenizer))

manager = ContinualSubspaceManager()
tasks = ["cola", "mrpc", "sst2"]
loaders = {t: get_loader(t, manager.tokenizer) for t in tasks}
forgetting_table = np.zeros((3, 3))

# Stage 1: CoLA Warmup
manager.inject_standard_lora()
manager.train_task("cola", loaders["cola"][0], is_subspace=False)
manager.convert_to_subspace() # Creates the shared Basis

# Stage 2 & 3: Train others in subspace and evaluate
for i, task in enumerate(tasks):
    if i > 0: # CoLA is already trained
        manager.train_task(task, loaders[task][0], is_subspace=True)

    for j in range(i + 1):
        prev_task = tasks[j]
        manager.load_task_weights(prev_task)
        acc = manager.evaluate_task(loaders[prev_task][1])
        forgetting_table[i, j] = acc

## 5. Result Display
print("\n--- FORGETTING TABLE (Accuracy) ---")
print(f"{'State':<12} | {'CoLA':<6} | {'MRPC':<6} | {'SST2':<6}")
for i, row in enumerate(forgetting_table):
    row_str = " | ".join([f"{v:.3f}" if v > 0 else "-----" for v in row])
    print(f"After {tasks[i]:<7} | {row_str}")

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/100 [00:00<?, ?it/s]

DistilBertForSequenceClassification LOAD REPORT from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_transform.weight  | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
pre_classifier.weight   | MISSING    | 
classifier.weight       | MISSING    | 
classifier.bias         | MISSING    | 
pre_classifier.bias     | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

cola/train-00000-of-00001.parquet:   0%|          | 0.00/251k [00:00<?, ?B/s]

cola/validation-00000-of-00001.parquet:   0%|          | 0.00/37.6k [00:00<?, ?B/s]

cola/test-00000-of-00001.parquet:   0%|          | 0.00/37.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8551 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1043 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1063 [00:00<?, ? examples/s]

Map:   0%|          | 0/8551 [00:00<?, ? examples/s]

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]

Map:   0%|          | 0/1063 [00:00<?, ? examples/s]

mrpc/train-00000-of-00001.parquet:   0%|          | 0.00/649k [00:00<?, ?B/s]

mrpc/validation-00000-of-00001.parquet:   0%|          | 0.00/75.7k [00:00<?, ?B/s]

mrpc/test-00000-of-00001.parquet:   0%|          | 0.00/308k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Map:   0%|          | 0/3668 [00:00<?, ? examples/s]

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

sst2/train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

sst2/validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

sst2/test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]


🔥 Training cola (Standard LoRA)





🔥 Training mrpc (Subspace)





🔥 Training sst2 (Subspace)





--- FORGETTING TABLE (Accuracy) ---
State        | CoLA   | MRPC   | SST2  
After cola    | 0.314 | ----- | -----
After mrpc    | 0.321 | 0.333 | -----
After sst2    | 0.480 | 0.333 | 0.557
