In [None]:
import pickle

import numpy as np
import torch
import torch.nn as nn
from avalanche.benchmarks import RotatedMNIST
from avalanche.evaluation.metrics import accuracy_metrics
from avalanche.logging import InteractiveLogger
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.supervised import Naive
from numpy.linalg import lstsq
from torch.utils.data import DataLoader

# Helper function: Load precomputed task vectors
def load_task_vectors(filepath):
    """Load previously saved task vectors from a pickle file.

    Args:
        filepath: Path of the pickle file to read.

    Returns:
        Whatever object was pickled into ``filepath``.

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the file content is not a valid pickle.
    """
    # SECURITY: unpickling executes arbitrary code embedded in the file, so
    # only load task-vector files that come from a trusted source.
    with open(filepath, "rb") as f:
        return pickle.load(f)

# Helper function: Compute task vector
def compute_task_vector(pretrained_model, finetuned_model):
    """Build the flattened weight difference between two models.

    The parameters of both models are walked in lockstep; for each pair the
    difference ``finetuned - pretrained`` is taken and moved to a CPU NumPy
    array.

    Args:
        pretrained_model: Model providing the reference parameters.
        finetuned_model: Model providing the updated parameters.

    Returns:
        A tuple ``(flattened_task_vector, param_shapes)`` where the first item
        is a 1-D NumPy array holding all per-parameter differences back to
        back, and the second is the list of parameter shapes in the same
        order (needed to unflatten the vector later).
    """
    param_shapes = []
    deltas = []

    paired_params = zip(pretrained_model.parameters(), finetuned_model.parameters())
    for base_param, tuned_param in paired_params:
        # Remember the shape so the flat vector can be cut back into tensors.
        param_shapes.append(base_param.shape)
        difference = tuned_param.data - base_param.data
        deltas.append(difference.detach().cpu().numpy())

    # Lay every per-parameter difference end to end in a single 1-D array.
    flattened_task_vector = np.concatenate([delta.flatten() for delta in deltas])
    return flattened_task_vector, param_shapes

# Helper function: Check if task vector is in the span of active task vectors
def is_in_span(vector, span_vectors):
    """Test whether ``vector`` is a linear combination of ``span_vectors``.

    A least-squares problem is solved for the coefficients that best
    reconstruct ``vector`` from the given vectors; the vector counts as
    "in span" when that reconstruction matches it under ``np.allclose`` with
    ``atol=1e-5``.

    Args:
        vector: 1-D array to test.
        span_vectors: Sequence (list, tuple or array iterated along its first
            axis) of 1-D arrays, each with the same length as ``vector``.
            May be empty or ``None``.

    Returns:
        A tuple ``(in_span, coeffs)``. ``coeffs`` holds the least-squares
        coefficients, one per span vector. For an empty or ``None`` span the
        result is ``(False, None)``.
    """
    # Use len() instead of truthiness: ``not array`` raises ValueError for a
    # NumPy array with more than one element.
    if span_vectors is None or len(span_vectors) == 0:
        return False, None

    span_matrix = np.stack(span_vectors, axis=1)  # Stack vectors as columns
    coeffs, *_ = lstsq(span_matrix, vector, rcond=None)
    in_span = np.allclose(span_matrix @ coeffs, vector, atol=1e-5)
    return in_span, coeffs

# Class for Localize and Stitch
class LocalizeAndStitch:
    """Apply sparse ("localized") task-vector updates to a model.

    ``localize_task`` builds a 0/1 mask selecting the largest-magnitude
    entries of a flattened task vector, and ``apply_task_vector`` adds the
    masked vector to the parameters of ``model`` in place.
    """

    def __init__(self, model, pretrained_model, task_vectors_active, sparsity=0.01):
        """Store the models, the active task vectors and the sparsity level.

        Args:
            model: Model whose parameters ``apply_task_vector`` updates.
            pretrained_model: Reference model; stored but not used by the
                methods of this class.
            task_vectors_active: Collection of active task vectors; stored
                but not used by the methods of this class.
            sparsity: Fraction of task-vector entries kept by
                ``localize_task``.
        """
        self.model = model
        self.pretrained_model = pretrained_model
        self.task_vectors_active = task_vectors_active
        self.sparsity = sparsity  # Sparsity level for localization

    def localize_task(self, task_vector, param_shapes=None):
        """Build a mask marking the top-magnitude entries of ``task_vector``.

        Args:
            task_vector: Flattened (1-D) task vector, array or tensor.
            param_shapes: Accepted so callers may pass the shapes alongside
                the vector; the mask is computed over the flat vector and
                does not use them.

        Returns:
            A tensor shaped like ``task_vector`` holding 1 at the
            ``int(sparsity * numel)`` largest-magnitude positions and 0
            elsewhere.
        """
        # as_tensor avoids the copy-construct warning when a tensor is passed.
        abs_vector = torch.abs(torch.as_tensor(task_vector))
        total = abs_vector.numel()
        # Clamp so an out-of-range sparsity cannot make topk fail.
        k = max(0, min(int(self.sparsity * total), total))
        topk_indices = abs_vector.topk(k).indices
        mask = torch.zeros_like(abs_vector)
        mask[topk_indices] = 1
        return mask

    def apply_task_vector(self, task_vector, mask, param_shapes=None):
        """Add the masked task vector to the model parameters in place.

        Args:
            task_vector: Flattened task vector, array or tensor, laid out in
                the order of ``self.model.parameters()``.
            mask: Flat mask of the same length as ``task_vector``.
            param_shapes: Shapes used to cut the flat vector into
                per-parameter pieces. Defaults to the shapes of the model's
                own parameters.
        """
        params = list(self.model.parameters())
        if param_shapes is None:
            param_shapes = [p.shape for p in params]

        flat_vector = torch.as_tensor(task_vector).reshape(-1)
        flat_mask = torch.as_tensor(mask).reshape(-1)

        offset = 0
        with torch.no_grad():
            for p, shape in zip(params, param_shapes):
                # Integer dtype keeps the count an int even for shape ().
                numel = int(np.prod(shape, dtype=np.int64))
                piece = slice(offset, offset + numel)
                # Unflatten the masked slice to match the parameter shape.
                update = (flat_vector[piece] * flat_mask[piece]).reshape(tuple(shape))
                p.add_(update.to(device=p.device, dtype=p.dtype))
                offset += numel

# Main workflow for continual learning
def continual_learning_with_localize_and_stitch():
    """Run the continual-learning experiment on three RotatedMNIST tasks.

    For each experience a fresh copy of the current "pretrained" weights is
    fine-tuned for one epoch, the resulting task vector is compared against
    the active task vectors (and added to them when it is not in their span),
    and a sparse version of the vector is added to the base model. After all
    tasks the base model is evaluated on the benchmark's test stream.

    Reads ``rotated_task_vectors.pkl`` from the working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Benchmark: RotatedMNIST for new tasks [115°, 145°, 175°]
    rotation_angles = [115, 145, 175]
    # Pass the angles explicitly; otherwise the benchmark draws rotations from
    # the seed and the angles printed below would not describe the data.
    rotated_benchmark = RotatedMNIST(
        n_experiences=len(rotation_angles),
        seed=1234,
        rotations_list=rotation_angles,
    )

    # Load precomputed task vectors (from [0°, 15°, ..., 90°])
    # NOTE(review): presumably a list of flattened vectors matching
    # SimpleMLP's parameter count (it is appended to below) — confirm.
    task_vectors_active = load_task_vectors("rotated_task_vectors.pkl")

    # Model initialization
    model_base = SimpleMLP(num_classes=10).to(device)
    optimizer = torch.optim.SGD(model_base.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Pretrained model: starts as an exact copy of the base model
    model_pretrained = SimpleMLP(num_classes=10).to(device)
    model_pretrained.load_state_dict(model_base.state_dict())

    # Evaluation plugin
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(epoch=True, stream=True),
        loggers=[InteractiveLogger()]
    )

    # Localize and Stitch handler
    las_handler = LocalizeAndStitch(model_base, model_pretrained, task_vectors_active)

    # Training loop over new tasks
    for task_id, experience in enumerate(rotated_benchmark.train_stream):
        print(f"\n### Training on Task {task_id+1} (Rotation: {rotation_angles[task_id]}°) ###")

        # Fine-tune a copy of the current pretrained weights on this task
        model_finetuned = SimpleMLP(num_classes=10).to(device)
        model_finetuned.load_state_dict(model_pretrained.state_dict())

        # Bind the optimizer to the model that is actually being trained
        # instead of reusing the one built over model_base's parameters.
        task_optimizer = torch.optim.SGD(model_finetuned.parameters(), lr=0.01)

        trainer = Naive(
            model_finetuned,
            task_optimizer,
            criterion,
            train_mb_size=128,
            device=device
        )
        trainer.train(experience, epochs=1)

        # Compute the task vector for the current task
        task_vector, param_shapes = compute_task_vector(model_pretrained, model_finetuned)

        # Check if the task vector is in the span of active task vectors
        in_span, _ = is_in_span(task_vector, task_vectors_active)
        if in_span:
            print(f"Task vector {task_id+1} is in span of active task vectors.")
        else:
            print(f"Task vector {task_id+1} is NOT in span. Adding to active set.")
            task_vectors_active.append(task_vector)

        # Localize task vector and apply it to the model. The mask is built
        # from the flat vector alone; the shapes are needed only to unflatten
        # it when updating the parameters.
        sparse_mask = las_handler.localize_task(task_vector)
        las_handler.apply_task_vector(task_vector, sparse_mask, param_shapes)

        # Update the pretrained model so the next task vector is measured
        # relative to the weights fine-tuned on this task
        model_pretrained.load_state_dict(model_finetuned.state_dict())

    # Evaluate on all tasks
    print("\n### Evaluating Model ###")
    evaluator = Naive(
        model_base,
        optimizer,
        criterion,
        device=device,
        evaluator=eval_plugin
    )
    evaluator.eval(rotated_benchmark.test_stream)

# Script entry point: run the full continual-learning workflow only when this
# file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    continual_learning_with_localize_and_stitch()




### Training on Task 1 (Rotation: 115°) ###
-- >> Start of training phase << --
100%|██████████| 469/469 [04:11<00:00,  1.87it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.8249
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7804
-- >> End of training phase << --
Task vector 1 is NOT in span. Adding to active set.


  mask_slice = torch.tensor(mask[offset:offset + numel]).view(shape)



### Training on Task 2 (Rotation: 145°) ###
-- >> Start of training phase << --
100%|██████████| 469/469 [04:08<00:00,  1.89it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.7832
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7656
-- >> End of training phase << --
Task vector 2 is NOT in span. Adding to active set.

### Training on Task 3 (Rotation: 175°) ###
-- >> Start of training phase << --
100%|██████████| 469/469 [04:08<00:00,  1.89it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.5685
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8268
-- >> End of training phase << --
Task vector 3 is NOT in span. Adding to active set.

### Evaluating Model ###
-- >> Start of eval phase << --
-- Starting eval on experience 0 (Task 0) from test stream --
100%|██████████| 10000/10000 [00:35<00:00, 283.30it/s]
> Eval on experience 0 (Task 0) from test stream ended.
-- Starting eval on experience 1 (Task 0) from test stream --
100%|██████████| 10