In [3]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torchvision.models import resnet18
from avalanche.benchmarks import RotatedMNIST
from avalanche.training.supervised import Naive
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, forgetting_metrics
from avalanche.logging import InteractiveLogger, TextLogger
from avalanche.training.plugins import EvaluationPlugin
from torch.utils.data import DataLoader, ConcatDataset, Subset
from avalanche.training.templates import SupervisedTemplate
from avalanche.benchmarks.utils import AvalancheDataset
import random
import pickle
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
from torchvision.transforms import Compose, ToTensor, Normalize
EPOCHS = 4  # fine-tuning epochs per task in the continual-learning loop
INDEP_THRESHOLD = 2  # divided by sqrt(dim) to form the independence threshold
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
def modify_resnet18_for_mnist():
    """Return a torchvision ResNet-18 (loaded with ``pretrained=True``) adapted to MNIST.

    The first convolution is swapped for a new 1-input-channel, 64-output-channel
    7x7 conv, and the final linear layer is swapped for a new 10-way classifier.
    Both replacement layers are freshly (randomly) initialised.
    """
    network = resnet18(pretrained=True)
    stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    network.conv1 = stem
    head = nn.Linear(network.fc.in_features, 10)
    network.fc = head
    return network


def generate_active_task_vectors(n_experiences=7, rotations_list=(0, 15, 30, 45, 60, 75, 90),
                                 device=None, train_epochs=5):
    """Build the "active" task-vector set from a fixed list of MNIST rotations.

    A single base model is created with ``modify_resnet18_for_mnist``. For every
    experience of a RotatedMNIST benchmark, an independent copy of that base is
    fine-tuned with an Avalanche ``Naive`` strategy (SGD, lr=0.01, momentum=0.9).
    The flattened parameter difference (tuned - base) is recorded.

    Args:
        n_experiences: number of experiences. Presumably must equal
            ``len(rotations_list)`` (checked by RotatedMNIST — TODO confirm).
        rotations_list: rotation angle for each experience.
        device: torch device to train on. Defaults to ``cuda:0`` when available, else CPU.
        train_epochs: fine-tuning epochs per experience.

    Returns:
        List of 1-D numpy arrays, one task vector per experience. Each array
        concatenates the parameter deltas in ``model.parameters()`` order.
    """
    import copy

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mnist_transform = Compose([Normalize(mean=(0.5,), std=(0.5,))])
    rotated_benchmark = RotatedMNIST(
        n_experiences=n_experiences, seed=42, rotations_list=list(rotations_list),
        dataset_root="./data",
        train_transform=mnist_transform,
        eval_transform=mnist_transform)

    model_base = modify_resnet18_for_mnist().to(device)

    task_vectors = []
    for experience in rotated_benchmark.train_stream:
        # Every experience starts from an identical copy of the base, so all
        # vectors are measured from the same origin.
        model_tuned = copy.deepcopy(model_base)
        trainer = Naive(
            model_tuned,
            optimizer=torch.optim.SGD(model_tuned.parameters(), lr=0.01, momentum=0.9),
            criterion=torch.nn.CrossEntropyLoss(),
            train_mb_size=128,
            device=device,
            train_epochs=train_epochs,
            evaluator=EvaluationPlugin(
                accuracy_metrics(epoch=True, stream=True),
                loggers=[InteractiveLogger()]))
        trainer.train(experience)
        deltas = [
            (p_tuned.detach() - p_base.detach()).cpu().numpy().ravel()
            for p_base, p_tuned in zip(model_base.parameters(), model_tuned.parameters())
        ]
        task_vectors.append(np.concatenate(deltas))

    return task_vectors

In [5]:
def train_joint_model_on_mnist(benchmark, device, epochs=5):
    """Train one model jointly on the union of all training experiences.

    Samples from every experience in ``benchmark.train_stream`` are pooled into
    one dataset. They are reshuffled every epoch, so mini-batches mix rotations.

    Args:
        benchmark: object whose ``train_stream`` yields experiences exposing
            ``.dataset``. Items are indexed as ``batch[0]`` (inputs) and
            ``batch[1]`` (targets).
        device: torch device used for training.
        epochs: number of passes over the pooled data.

    Returns:
        The trained model (left in train mode).
    """
    print("\n### Training Joint Model on Shuffled MNIST Tasks ###")
    joint_model = modify_resnet18_for_mnist().to(device)
    optimizer = optim.SGD(joint_model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    all_train_data = ConcatDataset([experience.dataset for experience in benchmark.train_stream])
    # shuffle=True draws a fresh permutation each epoch. A single up-front
    # permutation would replay the same batch order in every epoch.
    train_loader = DataLoader(all_train_data, batch_size=128, shuffle=True)

    joint_model.train()
    for _ in range(epochs):
        for batch in train_loader:
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            optimizer.zero_grad()
            loss = criterion(joint_model(inputs), targets)
            loss.backward()
            optimizer.step()

    return joint_model

In [None]:
def compute_task_vector(pretrained_model, finetuned_model):
    """Flatten the parameter difference ``finetuned - pretrained`` into one numpy vector.

    Parameters are paired in ``parameters()`` order. Each delta is moved to CPU,
    converted to numpy and flattened, then all of them are concatenated.

    Returns:
        ``(flat_vector, shapes)``. ``shapes`` lists each pretrained parameter's
        shape in the same order, so the vector can be sliced back per parameter.
    """
    pairs = list(zip(pretrained_model.parameters(), finetuned_model.parameters()))
    shapes = [base.shape for base, _ in pairs]
    deltas = [(tuned.data - base.data).detach().cpu().numpy().flatten() for base, tuned in pairs]
    return np.concatenate(deltas), shapes

def is_approximately_independent(vector, active_vectors, device=None, indep_threshold=None):
    """Test whether ``vector`` lies mostly outside the span of ``active_vectors``.

    The threshold is ``indep_threshold / sqrt(dim)``. The vector is projected onto
    the column space of the stacked active vectors, using the left singular vectors
    whose singular values are numerically non-zero. It is "independent" when
    ``||projection|| / ||vector|| < threshold``.

    Args:
        vector: 1-D numpy array or tensor.
        active_vectors: sequence of 1-D arrays/tensors of the same length (or a 2-D
            tensor, one vector per row). ``None`` or empty means an empty span.
        device: torch device for the computation. Defaults to CUDA when available,
            else CPU. (The previous ``'cuda'`` default crashed on CPU-only machines.)
        indep_threshold: numerator of the threshold. Defaults to the module-level
            ``INDEP_THRESHOLD``.

    Returns:
        ``(is_independent, normalized_projection_magnitude, threshold, projection)``.
        ``is_independent`` is a 0-d bool tensor, the two magnitudes are floats, and
        ``projection`` is a numpy array.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if indep_threshold is None:
        indep_threshold = INDEP_THRESHOLD

    vector_t = torch.as_tensor(vector, dtype=torch.float32).flatten().to(device)
    dim = vector_t.shape[0]
    threshold = indep_threshold / torch.sqrt(torch.tensor(dim, dtype=torch.float32, device=device))

    if active_vectors is None or len(active_vectors) == 0:
        # An empty active set spans nothing, so the projection is zero.
        projection = torch.zeros_like(vector_t)
    else:
        if isinstance(active_vectors, torch.Tensor):
            active = active_vectors.to(device=device, dtype=torch.float32)
        else:
            # np.stack avoids torch's slow path for a list of ndarrays.
            active = torch.as_tensor(
                np.stack([np.asarray(v, dtype=np.float32).ravel() for v in active_vectors]),
                device=device)
        U, S, _ = torch.linalg.svd(active.T, full_matrices=False)
        # Drop columns of U for (near-)zero singular values. They are arbitrary
        # directions outside the span and would inflate the projection when the
        # active vectors are linearly dependent.
        tol = S.max() * max(active.shape) * torch.finfo(torch.float32).eps
        U = U[:, S > tol]
        projection = U @ (U.T @ vector_t)

    total_norm = torch.norm(vector_t)
    normalized_projection_magnitude = torch.norm(projection) / total_norm
    is_independent = normalized_projection_magnitude < threshold
    return is_independent, normalized_projection_magnitude.item(), threshold.item(), projection.cpu().numpy()

def localize_and_stitch(model, pretrained_model, task_vector, param_shapes, all_masks, sparsity=0.01):
    """Localize the largest entries of a task vector and stitch them into ``model`` in place.

    The top ``int(sparsity * len(task_vector))`` entries by absolute value form a
    binary mask. Each selected entry is added to the matching parameter, scaled by
    ``1 / count``. ``count`` is how many masks (this new one plus every mask in
    ``all_masks``) select that position, so positions claimed by several tasks are
    shared rather than summed.

    Args:
        model: module whose parameters are modified in place and returned.
        pretrained_model: not used; kept so existing call sites keep working.
        task_vector: flat numpy array or tensor laid out in ``model.parameters()`` order.
        param_shapes: per-parameter shapes used to slice ``task_vector``, in the same order.
        all_masks: masks from earlier calls (tensors or arrays with the same length as
            ``task_vector``). May be empty or None.
        sparsity: fraction of entries to keep. 0 selects nothing.

    Returns:
        ``(model, mask)``, where ``mask`` is this call's binary mask as a CPU tensor.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep every tensor on one device. Mixing a CPU mask with a CUDA task vector fails on GPU.
    flat_task_vector = torch.as_tensor(task_vector).flatten().to(device)
    abs_vector = flat_task_vector.abs()

    k = min(int(sparsity * abs_vector.numel()), abs_vector.numel())
    mask = torch.zeros_like(abs_vector)
    if k > 0:
        mask[abs_vector.topk(k).indices] = 1

    # Count how many masks select each position. Every earlier mask is included,
    # whether it was stored as a tensor (what this function returns) or an array.
    counts = mask.clone()
    for previous in all_masks or []:
        counts += torch.as_tensor(previous).to(device=device, dtype=mask.dtype).flatten()

    stitched_mask = torch.where(
        mask > 0,
        torch.where(counts > 0, 1.0 / counts, torch.zeros_like(counts)),
        torch.zeros_like(counts),
    )

    offset = 0
    with torch.no_grad():
        for param, shape in zip(model.parameters(), param_shapes):
            numel = int(np.prod(shape))  # int(): np.prod(()) returns a float
            delta = flat_task_vector[offset:offset + numel] * stitched_mask[offset:offset + numel]
            param.add_(delta.view(shape).to(device=param.device, dtype=param.dtype))
            offset += numel
    return model, mask.detach().cpu()

def continual_learning_with_localize_and_stitch(rotated_benchmark, random_seed=1234, use_localize_and_stitch=True, task_vectors_active=None, all_masks=None):
    """Fine-tune on each experience in turn and optionally stitch task vectors into a shared model.

    For every training experience, a fresh copy of the initial model is fine-tuned
    for ``EPOCHS`` epochs, and its task vector (fine-tuned minus initial) is computed.

    With ``use_localize_and_stitch``, the vector is tested against the active set:
      * Independent vectors are appended to the active set and stitched with sparsity 0.0.
      * Otherwise, the vector's projection onto the active span is stitched with sparsity 0.01.
    The stitched model is then evaluated. Without stitching, the per-task
    fine-tuned model is evaluated.

    Args:
        rotated_benchmark: benchmark exposing ``train_stream`` and ``test_stream``.
        random_seed: seeds torch, so calls with the same seed start from the same
            randomly initialised conv1/fc layers.
        use_localize_and_stitch: enable the independence test and stitching.
        task_vectors_active: list of flat task vectors spanning the active subspace.
            Extended in place with independent vectors; ``None`` starts empty.
        all_masks: list of masks from earlier stitches, extended in place.
            ``None`` gives a fresh list per call. The previous shared ``[]`` default
            accumulated masks across calls.

    Returns:
        ``(accuracies, num_dependent_tasks)``. After each task, ``accuracies``
        receives the ``task_id``-th ``Top1_Acc_Exp/eval_phase/test_stream`` metric,
        presumably the current experience (relies on metric insertion order).
    """
    import copy

    torch.manual_seed(random_seed)
    if task_vectors_active is None:
        task_vectors_active = []
    if all_masks is None:
        all_masks = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    model_base = modify_resnet18_for_mnist().to(device)
    # Frozen reference point; every task vector is measured from here.
    model_pretrained = copy.deepcopy(model_base)
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True),
        loggers=[InteractiveLogger()])

    current_experience_accuracies = []
    num_dependent_tasks = 0
    for task_id, experience in enumerate(rotated_benchmark.train_stream):
        print(f"\n### Training on Task {task_id+1} ###")

        model_finetuned = copy.deepcopy(model_pretrained)
        # The optimizer must own the parameters of the model being trained. It is
        # rebuilt per task, so momentum does not carry over between tasks.
        optimizer = torch.optim.SGD(model_finetuned.parameters(), lr=0.01, momentum=0.9)
        trainer = Naive(model_finetuned, optimizer, criterion, train_mb_size=128, device=device, evaluator=eval_plugin, train_epochs=EPOCHS)
        trainer.train(experience)
        task_vector, param_shapes = compute_task_vector(model_pretrained, model_finetuned)

        if use_localize_and_stitch:
            independent, projection_magnitude, _, projection = is_approximately_independent(
                task_vector, task_vectors_active, device=device)
            if independent:
                task_vectors_active.append(task_vector)
                # NOTE(review): sparsity=0.0 selects no entries, so an independent
                # task only joins the active set and is not stitched into
                # model_base — confirm this is intended.
                stitch_vector, sparsity = task_vector, 0.0
            else:
                num_dependent_tasks += 1
                stitch_vector, sparsity = projection, 0.01
            model_base, new_mask = localize_and_stitch(
                model=model_base,
                pretrained_model=model_pretrained,
                task_vector=stitch_vector,
                param_shapes=param_shapes,
                all_masks=all_masks,
                sparsity=sparsity)
            all_masks.append(new_mask)
            # Evaluate the stitched model. Evaluating model_finetuned here would
            # make both arms of the comparison report the same kind of model.
            eval_trainer = Naive(model_base, torch.optim.SGD(model_base.parameters(), lr=0.01), criterion,
                                 train_mb_size=128, device=device, evaluator=eval_plugin)
        else:
            eval_trainer = trainer

        eval_trainer.eval(rotated_benchmark.test_stream)
        metrics = eval_plugin.get_last_metrics()
        acc = [v for k, v in metrics.items() if "Top1_Acc_Exp/eval_phase/test_stream" in k][task_id]
        current_experience_accuracies.append(acc)
        print(f"\n### Accuracies {acc} ###")
    return current_experience_accuracies, num_dependent_tasks

In [7]:
def generate_rotations_list(total_var, num_experiences, theta_0=None):
    """Random-walk a sequence of rotation angles.

    Starts at ``theta_0``, or at ``random.randint(0, 180)`` when it is None. Each
    subsequent angle adds a random step drawn from ``[1, total_var]`` and wraps
    modulo 180. At least one angle (the start) is always returned.
    """
    start = random.randint(0, 180) if theta_0 is None else theta_0
    angles = [start]
    while len(angles) < num_experiences:
        step = random.randint(1, total_var)
        angles.append((angles[-1] + step) % 180)
    return angles

def _evaluate_joint_model(joint_model, benchmark, device):
    """Return ``joint_model``'s per-experience Top-1 accuracies on ``benchmark.test_stream``.

    Uses the test stream, the same data the continual-learning accuracies are read from.
    """
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(experience=True, epoch=True),
        loggers=[InteractiveLogger()],
    )
    trainer = Naive(
        joint_model,
        # Only needed to construct Naive; this helper calls eval() only.
        optimizer=torch.optim.SGD(joint_model.parameters(), lr=0.01),
        criterion=torch.nn.CrossEntropyLoss(),
        train_mb_size=128,
        device=device,
        evaluator=eval_plugin,
    )
    trainer.eval(benchmark.test_stream)
    joint_model_metrics = eval_plugin.get_last_metrics()
    return [v for k, v in joint_model_metrics.items() if "Top1_Acc_Exp/eval_phase/test_stream" in k]


def perform_ablation_study():
    """Run the rotated-MNIST ablation over test lengths, total variations and expressivity sets.

    For every expressivity scenario, an active task-vector set is generated once.
    Then, for each (T, TV) pair, the study:
      1. builds a fresh RotatedMNIST benchmark with random-walk rotations;
      2. trains a joint model as a reference;
      3. runs continual learning with and without localize-and-stitch.
    Progress and results are written to ``ablation_study_results.txt``. The list
    of result dicts is pickled to ``ablation_study_results.pkl`` and returned.

    NOTE(review): the active task vectors are generated with a Normalize(0.5, 0.5)
    transform and dataset_root="./data". The test benchmarks here use
    RotatedMNIST's defaults — confirm the preprocessing is meant to differ.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_test_experiences_list = [5, 10]
    total_variations = [5, 10, 30, 60]
    expressivity_scenarios = [
        [10, 20, 30, 40]
    ]

    results = []
    results_txt_path = "ablation_study_results.txt"

    with open(results_txt_path, "w") as txt_file:
        txt_file.write("Ablation Study Results\n")
        txt_file.write("=" * 50 + "\n")
        txt_file.flush()

        for expressivity in expressivity_scenarios:
            # Generate active task vector set based on expressivity
            print(f"Generating active task vector set for expressivity: {expressivity}")
            txt_file.write(f"Generating active task vector set for expressivity: {expressivity}\n")
            txt_file.flush()
            active_rotation_angles = expressivity
            task_vectors_active = generate_active_task_vectors(
                n_experiences=len(active_rotation_angles),
                rotations_list=active_rotation_angles,
            )

            for num_test_experiences in num_test_experiences_list:
                for total_var in total_variations:
                    test_description = f"\n### Testing with T={num_test_experiences}, TV={total_var}, Expressivity={expressivity} ###"
                    print(test_description)
                    txt_file.write(test_description + "\n")
                    txt_file.flush()

                    # Generate rotation angles and test experiences
                    rotation_angles = generate_rotations_list(total_var, num_test_experiences)
                    rotated_benchmark = RotatedMNIST(
                        n_experiences=num_test_experiences,
                        seed=42,
                        rotations_list=rotation_angles,
                        return_task_id=True,
                    )

                    # Train a joint model on the test experiences, then score it
                    # on the test stream, like the continual-learning runs.
                    joint_model = train_joint_model_on_mnist(rotated_benchmark, device, epochs=5)
                    joint_model.eval()
                    joint_accuracies = _evaluate_joint_model(joint_model, rotated_benchmark, device)

                    print("Joint Accuracies:", joint_accuracies)

                    # Pass a copy of the active set and fresh mask lists, so one
                    # configuration cannot leak its vectors or masks into the next.
                    accs_localize, num_dependent_tasks_localize = continual_learning_with_localize_and_stitch(
                        task_vectors_active=list(task_vectors_active),
                        rotated_benchmark=rotated_benchmark,
                        random_seed=42,
                        use_localize_and_stitch=True,
                        all_masks=[],
                    )

                    # Run continual learning without Localize and Stitch
                    accs_finetune, num_dependent_tasks_finetune = continual_learning_with_localize_and_stitch(
                        rotated_benchmark=rotated_benchmark,
                        random_seed=42,
                        use_localize_and_stitch=False,
                        all_masks=[],
                    )

                    # Collect results
                    result_entry = {
                        "T": num_test_experiences,
                        "TV": total_var,
                        "Expressivity": expressivity,
                        "Dependent_Tasks_Localize": num_dependent_tasks_localize,
                        "Dependent_Tasks_Finetune": num_dependent_tasks_finetune,
                        "Accs_Localize": accs_localize,
                        "Accs_Finetune": accs_finetune,
                        "Accuracy_Joint": joint_accuracies
                    }
                    results.append(result_entry)

                    # Write results to the text file
                    txt_file.write(f"Dependent Tasks (Localize & Stitch): {num_dependent_tasks_localize}\n")
                    txt_file.write(f"Dependent Tasks (Fine-Tune): {num_dependent_tasks_finetune}\n")
                    txt_file.write(f"Accuracy (Localize & Stitch): {accs_localize}\n")
                    txt_file.write(f"Accuracy (Fine-Tune): {accs_finetune}\n")
                    txt_file.write(f"Accuracy (Joint): {joint_accuracies}\n")
                    txt_file.write("-" * 50 + "\n")
                    txt_file.flush()

    # Save results to a file for further analysis
    results_path = "ablation_study_results.pkl"
    with open(results_path, "wb") as f:
        pickle.dump(results, f)
    print(f"\n### Ablation study results saved to {results_path} ###")

    return results

In [None]:
def main():
    """Run the ablation study and print a short summary for every configuration."""
    for entry in perform_ablation_study():
        summary_lines = (
            f"T={entry['T']}, TV={entry['TV']}, Expressivity={entry['Expressivity']}",
            f"Accs Localize: {entry['Accs_Localize']}",
            f"Accs Finetune: {entry['Accs_Finetune']}",
            f"Dependent Tasks Localize: {entry['Dependent_Tasks_Localize']}",
            f"Dependent Tasks Finetune: {entry['Dependent_Tasks_Finetune']}\n",
        )
        for line in summary_lines:
            print(line)


if __name__ == "__main__":
    main()

Generating active task vector set for expressivity: [10, 20, 30, 40]
-- >> Start of training phase << --
  1%|          | 5/469 [07:53<12:16:56, 95.29s/it]