In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from avalanche.benchmarks import SplitMNIST, PermutedMNIST, RotatedMNIST
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive
from numpy.linalg import lstsq
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger,  TextLogger
from avalanche.training.plugins import EvaluationPlugin
import numpy as np

# Helper function to compute task vector
def compute_task_vector(model_pre, model_tuned):
    task_vector = []
    for p_pre, p_tuned in zip(model_pre.parameters(), model_tuned.parameters()):
        task_vector.append((p_tuned.data - p_pre.data).detach().cpu().numpy())
    return np.concatenate([p.flatten() for p in task_vector])

# Helper function to check if a vector is in the span of other vectors
def is_in_span(vector, span_vectors):
    if not span_vectors:
        return False, None
    span_matrix = np.stack(span_vectors, axis=1)
    coeffs, residuals, _, _ = lstsq(span_matrix, vector, rcond=None)
    in_span = np.allclose(span_matrix @ coeffs, vector, atol=1e-5)
    return in_span, coeffs

# Function to train the joint baseline model on all tasks at once
def train_joint_baseline(benchmark, device, epochs=1):
    print("\n### Training Joint Baseline Model on All Tasks ###")
    joint_model = SimpleMLP(num_classes=10).to(device)
    optimizer = optim.SGD(joint_model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(loss_metrics(epoch=True), loggers=[logger])

    trainer = Naive(joint_model, optimizer, criterion,
                    train_mb_size=128, device=device, evaluator=eval_plugin)
    
    # Combine all tasks into a single dataset
    for experience in benchmark.train_stream:
        trainer.train(experience, epochs=epochs)
    
    print("Joint Baseline Model Training Completed.")
    return joint_model

# Main function implementing the algorithm with regret computation
def continual_learning_task_vectors(dataset_name="SplitMNIST", n_experiences=5):
    # 1. Benchmark
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if dataset_name.lower() == "splitmnist":
        benchmark = SplitMNIST(n_experiences=n_experiences, return_task_id=True)
    elif dataset_name.lower() == "permutedmnist":
        benchmark = PermutedMNIST(n_experiences=n_experiences)
    elif dataset_name.lower() == "rotatedmnist":
        benchmark = RotatedMNIST(n_experiences=n_experiences)
    else:
        raise ValueError("Invalid dataset name.")

    # 2. Train the joint baseline model
    joint_model = train_joint_baseline(benchmark, device, epochs=10)

    # 3. Model, Optimizer, and Criterion
    model_pre = SimpleMLP(num_classes=10).to(device)
    optimizer = optim.SGD(model_pre.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # 4. Active Task Vector Set and Regret Tracking
    task_vectors_active = []
    total_regret = 0  # Tracks cumulative regret

    # 5. Loop Through Tasks
    for task_id, experience in enumerate(benchmark.train_stream):
        print(f"\n### Training Task {task_id+1} ###")

        # Fine-tune continual learning model
        model_tuned = SimpleMLP(num_classes=10).to(device)
        model_tuned.load_state_dict(model_pre.state_dict())  # Copy weights
        trainer = Naive(model_tuned, optimizer, criterion, train_mb_size=128, device=device)
        trainer.train(experience, epochs=10)

        # Compute task vector
        task_vector = compute_task_vector(model_pre, model_tuned)
        in_span, coefficients = is_in_span(task_vector, task_vectors_active)

        if in_span:
            print(f"Task vector {task_id+1} is in span of previous task vectors.")
        else:
            print(f"Task vector {task_id+1} is NOT in span. Adding to active set.")
            task_vectors_active.append(task_vector)

        # Evaluate losses for regret computation
        test_stream = benchmark.test_stream[task_id]

        # Evaluate losses for regret computation
        cl_loss = evaluate_loss(model_tuned, test_stream, criterion, device, task_id)
        joint_loss = evaluate_loss(joint_model, test_stream, criterion, device, task_id)

        # Compute and track regret
        regret = cl_loss - joint_loss
        avg_regret += regret/n_experiences
        print(f"Regret on Task {task_id+1}: {regret:.4f}")
        print(f"Average Regret: {avg_regret:.4f}")

        # Update pretrained model
        model_pre.load_state_dict(model_tuned.state_dict())

    print("\nContinual Learning and Regret Evaluation Completed.")

def evaluate_loss(model, test_stream, criterion, device, task_id):
    """
    Evaluate the loss of a model on a given test stream and dynamically retrieve the loss.
    """
    # Setup the evaluation plugin to track loss
    eval_plugin = EvaluationPlugin(
        loss_metrics(stream=True),  # Log loss across the entire test stream
        loggers=[TextLogger(open("/dev/null", "w"))]  # Suppress logger output
    )

    # Create a Naive trainer only for evaluation
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    evaluator = Naive(model, optimizer, criterion, device=device, evaluator=eval_plugin)

    # Run evaluation
    evaluator.eval(test_stream)

    # Dynamically retrieve the loss based on the task number
    task_key = f"Loss_Stream/eval_phase/test_stream/Task{task_id:03d}"
    eval_results = eval_plugin.get_last_metrics()
    
    if task_key in eval_results:
        average_loss = eval_results[task_key]
    else:
        raise ValueError(f"Loss key {task_key} not found in evaluation results.")

    return average_loss

if __name__ == "__main__":
    print("Choose an MNIST benchmark to run:")
    print("1. SplitMNIST")
    print("2. PermutedMNIST")
    print("3. RotatedMNIST")
    
    choice = input("Enter the number of your choice: ").strip()
    if choice == "1":
        dataset_name = "SplitMNIST"
    elif choice == "2":
        dataset_name = "PermutedMNIST"
    elif choice == "3":
        dataset_name = "RotatedMNIST"
    else:
        print("Invalid choice. Defaulting to SplitMNIST.")
        dataset_name = "SplitMNIST"
    
    continual_learning_task_vectors(dataset_name=dataset_name, n_experiences=5)

Choose an MNIST benchmark to run:
1. SplitMNIST
2. PermutedMNIST
3. RotatedMNIST





### Training Joint Baseline Model on All Tasks ###
-- >> Start of training phase << --
100%|██████████| 469/469 [03:20<00:00,  2.34it/s]
Epoch 0 ended.
	Loss_Epoch/train_phase/train_stream/Task000 = 0.8100
-- >> End of training phase << --
-- >> Start of training phase << --
 11%|█         | 52/469 [00:24<05:31,  1.26it/s]

KeyboardInterrupt: 