In [1]:
# Standard Library Imports
import json
import os
import random
from pathlib import Path


# Third-Party Libraries
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms

# Avalanche: Continual Learning Framework
## Benchmarks
from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.benchmarks.datasets.torchvision_wrapper import CIFAR10
from avalanche.benchmarks.scenarios import CLExperience
from avalanche.benchmarks.utils.flat_data import ConstantSequence

## Models
from avalanche.models import (
    MultiHeadClassifier,
    MultiTaskModule,
    MTSimpleMLP,
    MTSimpleCNN,
    PNN,
)

## Training Strategies
from avalanche.training.supervised import Naive, EWC, LwF

## Plugins and Logging
from avalanche.logging import InteractiveLogger, TextLogger
from avalanche.training.plugins import EvaluationPlugin, LRSchedulerPlugin

## Evaluation Metrics
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    forgetting_metrics,
    loss_metrics,
    timing_metrics,
    cpu_usage_metrics,
    confusion_matrix_metrics,
    disk_usage_metrics,
)

In [2]:
# Output configuration for saving the datasets/models/results/log files.
# SAVE toggles every on-disk side effect in this notebook.
SAVE = False

# Single source of truth for the project root. Set the DL_PROJECT_ROOT
# environment variable to relocate outputs instead of editing this path.
ROOT = Path(os.environ.get("DL_PROJECT_ROOT", "/home/uregina/DL_Project"))

if SAVE:
    # Make the project root the current working directory.
    os.chdir(ROOT)
    print(os.getcwd())

    # Dataset-level output directory (the model cell later creates a
    # per-model/per-run subdirectory beneath it).
    DATASET_NAME = "SplitCIFAR10"
    DATA_ROOT = ROOT / DATASET_NAME
    DATA_ROOT.mkdir(parents=True, exist_ok=True)

In [3]:
# Run-wide configuration: compute device, RNG seed, dataset and hyperparameters.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
seed = 0

# `seed` is forwarded to SplitCIFAR10 below, but model weight initialisation
# (and, presumably, minibatch shuffling) draws from the global RNGs, so seed
# those too; otherwise repeated runs are not reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG on all devices, CUDA included

DATASET_NAME = "SplitCIFAR10"
# Number of classes per dataset; sizes the confusion-matrix metric.
NUM_CLASSES = {
    "SplitCIFAR10": 10,
}

# Training hyperparameters (no scheduler or augmentation is configured here).
HPARAM = {
    "batch_size": 128,   # train and eval minibatch size
    "num_epoch": 3,      # training epochs per experience
    "start_lr": 0.01,    # learning rate given to Adam
    "alpha": 0.9,        # passed to LwF as `alpha`
    "temperature": 5,    # passed to LwF as `temperature`
}

In [4]:
# print to stdout
interactive_logger = InteractiveLogger()

benchmark = SplitCIFAR10(
    n_experiences = 5,          
    return_task_id = True,
    seed=seed
)

eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=False, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=False, epoch=True, experience=True, stream=True),
    timing_metrics(epoch=True, epoch_running=True),
    forgetting_metrics(experience=True, stream=True),
    cpu_usage_metrics(experience=True),
    confusion_matrix_metrics(
        num_classes=NUM_CLASSES[DATASET_NAME], save_image=False, stream=True
    ),
    disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loggers=interactive_logger,
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Model under test and run identifier (repeated runs use RUN = '0', '1', '2').
MODEL_NAME = 'MTSimpleCNN'
RUN = '0'
model = MTSimpleCNN()

optimizer = Adam(model.parameters(), lr=HPARAM["start_lr"])

# Learning-without-Forgetting strategy. `alpha` and `temperature` are the
# LwF-specific parameters; the rest configure the generic training loop.
cl_strategy = LwF(
    model=model,
    optimizer=optimizer,
    criterion=CrossEntropyLoss(),
    alpha=HPARAM["alpha"],
    temperature=HPARAM["temperature"],
    train_mb_size=HPARAM["batch_size"],
    eval_mb_size=HPARAM["batch_size"],
    train_epochs=HPARAM["num_epoch"],
    device=device,
    evaluator=eval_plugin,
)

# Per-run output directory: <ROOT>/<dataset>/<model>/<run>.
if SAVE:
    DATA_ROOT = ROOT.joinpath(DATASET_NAME, MODEL_NAME, RUN)
    DATA_ROOT.mkdir(parents=True, exist_ok=True)

In [6]:
# Train on the experiences in order. After each one, evaluate on the entire
# test stream and store the returned metrics keyed by experience index.
print("Starting experiment...")
results_dict = {}
for index, experience in enumerate(benchmark.train_stream):
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)

    res = cl_strategy.train(experience)
    print("Training completed")

    print("Computing accuracy on the whole test")
    stream_metrics = cl_strategy.eval(benchmark.test_stream)
    results_dict[index] = stream_metrics

print("Experiment completed")

Starting experiment...
Start of experience:  0
Current Classes:  [1, 4]
-- >> Start of training phase << --
 43%|████▎     | 34/79 [00:03<00:03, 13.62it/s]

KeyboardInterrupt: 

In [None]:
# Write a human-readable report of the per-experience evaluation metrics,
# together with the run configuration, into this run's output directory.
if SAVE:
    file_name = f"{MODEL_NAME}_{DATASET_NAME}_{RUN}_results.txt"
    file_path = ROOT / DATASET_NAME / MODEL_NAME / RUN / file_name
    # Don't depend on an earlier cell having created the run directory.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(f"Model: {MODEL_NAME}\n")
        file.write(f"Dataset: {DATASET_NAME}\n")
        file.write(f"Run: {RUN}\n")
        # Record the configuration so the report alone describes the run.
        file.write(f"Seed: {seed}\n")
        file.write(f"Hyperparameters: {json.dumps(HPARAM, sort_keys=True)}\n")
        file.write("\nResults Dictionary:\n")
        file.write("--------------------------------------------------\n")
        if not results_dict:
            # e.g. training was interrupted before the first evaluation finished.
            file.write("(no results recorded)\n")
        for key, value in results_dict.items():
            file.write(f"Experience {key}:\n")
            for metric, metric_value in value.items():
                # Tensors (e.g. confusion matrices) are converted to plain lists.
                if isinstance(metric_value, torch.Tensor):
                    metric_value = metric_value.tolist()
                file.write(f"  {metric}: {metric_value}\n")
            file.write("--------------------------------------------------\n")

 43%|████▎     | 34/79 [00:19<00:03, 13.62it/s]