In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from avalanche.benchmarks import SplitMNIST, PermutedMNIST, RotatedMNIST
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive
from numpy.linalg import lstsq
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger,  TextLogger
from avalanche.training.plugins import EvaluationPlugin
import numpy as np

# Generate RotatedMNIST pseudo-tasks (e.g., 7 tasks with different rotations)
rotated_benchmark = RotatedMNIST(n_experiences=7, seed=1234, rotations_list=[0, 15, 30, 45, 60, 75, 90])

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f8479448340>>
Traceback (most recent call last):
  File "/home/samuel_schapiro/anaconda3/envs/py3.9.13/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
  from .autonotebook import tqdm as notebook_tqdm


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Standard MNIST training
split_mnist = SplitMNIST(n_experiences=1, seed=1234)
model_base = SimpleMLP(num_classes=10).to(device)

trainer = Naive(
    model_base,
    optimizer=torch.optim.SGD(model_base.parameters(), lr=0.01),
    criterion=torch.nn.CrossEntropyLoss(),
    train_mb_size=128,
    device=device
)

# Pretrain on MNIST
trainer.train(split_mnist.train_stream[0], epochs=5)



-- >> Start of training phase << --
100%|██████████| 469/469 [03:06<00:00,  2.51it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.8239
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7810
-- >> End of training phase << --


{'Top1_Acc_Epoch/train_phase/train_stream/Task000': 0.7809666666666667,
 'Loss_Epoch/train_phase/train_stream/Task000': 0.8238722893238067}

In [11]:
task_vectors = []  # To store pre-populated task vectors

for experience in rotated_benchmark.train_stream:
    model_tuned = SimpleMLP(num_classes=10).to(device)
    model_tuned.load_state_dict(model_base.state_dict())  # Start from base model

    # Fine-tune on rotated task
    trainer = Naive(
        model_tuned,
        optimizer=torch.optim.SGD(model_tuned.parameters(), lr=0.01),
        criterion=torch.nn.CrossEntropyLoss(),
        train_mb_size=128,
        device=device
    )
    trainer.train(experience, epochs=5)

    # Compute task vector
    task_vector = []
    for p_base, p_tuned in zip(model_base.parameters(), model_tuned.parameters()):
        task_vector.append((p_tuned.data - p_base.data).detach().cpu().numpy())
    task_vector = np.concatenate([p.flatten() for p in task_vector])

    # Store the task vector
    task_vectors.append(task_vector)

-- >> Start of training phase << --
100%|██████████| 469/469 [03:35<00:00,  2.18it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.4171
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8795
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 469/469 [03:33<00:00,  2.19it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.4825
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8587
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 469/469 [03:42<00:00,  2.11it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.5858
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.8214
-- >> End of training phase << --
-- >> Start of training phase << --
100%|██████████| 469/469 [03:53<00:00,  2.01it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.6689
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.7912
-- >> End of training phase << --
-- >> Start 

In [12]:
import pickle

# Save pre-populated task vectors to a file
with open("rotated_task_vectors.pkl", "wb") as f:
    pickle.dump(task_vectors, f)

In [None]:
# Load pre-populated task vectors
with open("rotated_task_vectors.pkl", "rb") as f:
    task_vectors_active = pickle.load(f)