In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
import sys
import os
import pickle

# Project root on Google Drive; it is expected to contain the local modules
# imported later in this notebook (model, data_setup, utils). It can be
# overridden with the TFC_SR_PROJECT_PATH environment variable when running
# outside Colab; the default is unchanged.
PROJECT_PATH = os.environ.get('TFC_SR_PROJECT_PATH', '/content/drive/MyDrive/tfc-sr')

# Only append once, so re-running this cell does not pile up duplicate entries.
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)

In [None]:
# Experiment-wide hyperparameters and output locations shared by every cell.
CONFIG = dict(
    seed=42,
    num_tasks=5,             # number of experiences in the Split MNIST stream
    epochs_per_task=10,
    batch_size=64,
    lr=0.001,
    num_classes=10,
    results_path=os.path.join(PROJECT_PATH, 'results'),
    checkpoints_path=os.path.join(PROJECT_PATH, 'checkpoints'),
    # EWC / SI regularisation strengths.
    # NOTE(review): the EWC and SI cells in this notebook sweep their own
    # lambda lists and do not read these two entries.
    ewc_lambda=1.0,
    si_lambda=1.0,
)

In [None]:
# Ensure the results and checkpoint folders exist; exist_ok makes re-runs harmless.
for _dir_key in ('results_path', 'checkpoints_path'):
    os.makedirs(CONFIG[_dir_key], exist_ok=True)

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
!pip install avalanche-lib

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.training import EWC, SynapticIntelligence
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, forgetting_metrics
from avalanche.logging import InteractiveLogger, TextLogger
from avalanche.training.plugins import EvaluationPlugin


# --- 3. UNIFIED AVALANCHE BENCHMARK SETUP ---
# This benchmark will be used for all experiments to ensure consistency.
split_mnist_benchmark = SplitMNIST(n_experiences=5, seed=CONFIG['seed'])

In [None]:
from model import CNN
from data_setup import get_split_mnist_dataloaders
from utils import set_seed, save_results, plot_results, load_results, evaluate_on_seen_tasks

In [None]:
# Seed the RNGs through the project helper before the first model is built.
_notebook_seed = CONFIG['seed']
set_seed(_notebook_seed)

In [None]:
# --- EXPERIMENT: Standard CL (Benchmark) ---
# Naive sequential fine-tuning: train on each experience in turn with no replay
# or regularisation, then evaluate on all tasks seen so far.
# Re-seed here, like every other experiment cell, so this cell reproduces the
# same run when executed on its own.
set_seed(CONFIG['seed'])

model = CNN(num_classes=CONFIG['num_classes']).to(device)
optimizer = optim.Adam(model.parameters(), lr=CONFIG['lr'])
criterion = nn.CrossEntropyLoss()
baseline_accuracies = []
# Use the real stream length (as the ER / TFC-SR cells do) for progress output.
num_experiences = len(split_mnist_benchmark.train_stream)

# Main continual learning loop
for i, experience in enumerate(split_mnist_benchmark.train_stream):
    print(f"\n--- Training on Task {i+1}/{num_experiences} ---")

    train_dataset = experience.dataset
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    # Training loop
    model.train()
    for epoch in range(CONFIG['epochs_per_task']):
        # Batches are (inputs, targets, task_labels); task labels are unused here.
        for data, targets, _task_labels in train_loader:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        print(f"Task {i+1}, Epoch {epoch+1}/{CONFIG['epochs_per_task']}, Loss: {loss.item():.4f}")

    # Evaluation loop
    accuracy = evaluate_on_seen_tasks(model, split_mnist_benchmark, i, device, CONFIG['batch_size'])
    baseline_accuracies.append(accuracy)
    print(f"----- Accuracy after Task {i+1}: {accuracy:.2f}% -----")

# Save the final model checkpoint
final_model_path = os.path.join(CONFIG['checkpoints_path'], 'baseline_final_model.pth')
torch.save(model.state_dict(), final_model_path)
print(f"\nFinal baseline model saved to {final_model_path}")

# Save the results list
baseline_results_path = os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl')
save_results(baseline_accuracies, baseline_results_path)

# Plot the results
results_to_plot = {
    'Baseline': baseline_accuracies
}
plot_results(results_to_plot, title="Baseline Performance on Split MNIST")

In [None]:
from utils import ReservoirReplayBuffer

# Replay-buffer settings used by the replay-based experiments below.
CONFIG.update(
    buffer_capacity=200,                          # total samples stored across all tasks
    replay_batch_size=CONFIG['batch_size'] // 2,  # replay fills half of each mixed batch
)

In [None]:
# --- EXPERIMENT: Standard ER ---
# Experience replay with a reservoir buffer: every optimisation step mixes a
# half-batch from the current task with a half-batch sampled from the buffer.
set_seed(CONFIG['seed'])

model_er = CNN(num_classes=CONFIG['num_classes']).to(device)
optimizer_er = optim.Adam(model_er.parameters(), lr=CONFIG['lr'])
criterion_er = nn.CrossEntropyLoss()

replay_buffer_er = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])

er_accuracies = []

# Main continual learning loop
for task_id, experience in enumerate(split_mnist_benchmark.train_stream):
    print(f"\n--- Training on Task {task_id+1}/{len(split_mnist_benchmark.train_stream)} ---")

    # --- Step A: Populate the replay buffer with some examples from the new task ---
    # This happens before training on the task, so the buffer can already hold
    # samples of the current task when replaying.
    print(f"Populating replay buffer from Task {task_id+1}...")

    for data_point, target, _ in experience.dataset:
        replay_buffer_er.add(data_point, target)
    print(f"Replay buffer size: {len(replay_buffer_er)}")

    # Build the loader for THIS experience. Without it the loop below would
    # silently reuse whatever `train_loader` an earlier cell left in the kernel
    # (the baseline's last task), training every ER task on the same data.
    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    # --- Step B: Training loop with mixed batches ---
    model_er.train()
    for epoch in range(CONFIG['epochs_per_task']):
        loss = None  # stays None if no replay step runs in this epoch
        for new_data, new_targets, _ in train_loader:
            # Only proceed if we have something in the buffer to replay
            if len(replay_buffer_er) > CONFIG['replay_batch_size']:
                # 1. Sample a batch from the replay buffer
                old_data, old_targets = replay_buffer_er.sample(CONFIG['replay_batch_size'])

                # 2. Create the mixed batch: truncate the new batch to the replay
                # size for a 50/50 mix.
                # NOTE(review): this drops the other half of every new batch, so
                # each epoch only trains on about half of the task's samples.
                new_data = new_data[:CONFIG['replay_batch_size']]
                new_targets = new_targets[:CONFIG['replay_batch_size']]

                combined_data = torch.cat((new_data, old_data), dim=0).to(device)
                combined_targets = torch.cat((new_targets, old_targets), dim=0).to(device)

                # 3. Standard training step on the mixed batch
                optimizer_er.zero_grad()
                outputs = model_er(combined_data)
                loss = criterion_er(outputs, combined_targets)
                loss.backward()
                optimizer_er.step()

        loss_msg = f"{loss.item():.4f}" if loss is not None else "n/a"
        print(f"Task {task_id+1}, Epoch {epoch+1}, Last batch loss: {loss_msg}")

    # --- Step C: Evaluation loop ---
    accuracy = evaluate_on_seen_tasks(model_er, split_mnist_benchmark, task_id, device, CONFIG['batch_size'])
    er_accuracies.append(accuracy)
    print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

# Save checkpoint
er_model_path = os.path.join(CONFIG['checkpoints_path'], 'er_final_model.pth')
torch.save(model_er.state_dict(), er_model_path)
print(f"\nFinal ER model saved to {er_model_path}")

# Save results
er_results_path = os.path.join(CONFIG['results_path'], 'er_accuracies.pkl')
save_results(er_accuracies, er_results_path)

# Plot comparison
baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))
results_to_plot = {
    'Baseline': baseline_accuracies,
    'Standard ER': er_accuracies
}
plot_results(results_to_plot, title="Standard ER vs. Baseline on Split MNIST")

In [None]:
# --- EXPERIMENT: Task-Focused Consolidation with Spaced Repetition (TFC-SR) ---
# Mixed-batch replay as in ER, plus a spaced "memory check": the replay buffer
# is evaluated at growing epoch intervals while the model keeps mastering it.
set_seed(CONFIG['seed'])

from utils import create_buffer_validation_set, evaluate_replay_buffer

# --- 1. CONFIGURATION ---
# New parameters for TFC-SR
CONFIG['mastery_threshold'] = 95.0 # Accuracy threshold in %
CONFIG['initial_replay_gap'] = 1   # Start checking after epoch 1
CONFIG['replay_gap_multiplier'] = 1.5 # How much to increase the gap

# --- 2. RUN TFC-SR EXPERIMENT ---
print("\n===== Starting TFC-SR Experiment =====")

model_tfc = CNN(num_classes=CONFIG['num_classes']).to(device)
optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])
criterion_tfc = nn.CrossEntropyLoss()

replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])
tfc_accuracies = []

# --- Main Continual Learning Loop ---
for task_id, experience in enumerate(split_mnist_benchmark.train_stream):
    print(f"\n--- Training on Task {task_id+1}/{len(split_mnist_benchmark.train_stream)} ---")

    # Populate replay buffer
    for data_point, target, _ in experience.dataset:
        replay_buffer_tfc.add(data_point, target)
    print(f"Replay buffer size: {len(replay_buffer_tfc)}")

    # Initialize the replay schedule for this new task.
    # `replay_timer` is the 1-based epoch at which the next memory check runs.
    current_replay_gap = float(CONFIG['initial_replay_gap'])
    replay_timer = int(current_replay_gap)

    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    # --- Training loop for the current task ---
    model_tfc.train()
    for epoch in range(CONFIG['epochs_per_task']):
        loss = None  # stays None if no replay step runs in this epoch
        # --- Mixed-batch training ---
        for new_data, new_targets, _ in train_loader:
            if len(replay_buffer_tfc) > CONFIG['batch_size'] // 2:
                replay_batch_size = CONFIG['batch_size'] // 2
                old_data, old_targets = replay_buffer_tfc.sample(replay_batch_size)
                new_data = new_data[:replay_batch_size]
                new_targets = new_targets[:replay_batch_size]
                combined_data = torch.cat((new_data, old_data), dim=0).to(device)
                combined_targets = torch.cat((new_targets, old_targets), dim=0).to(device)

                optimizer_tfc.zero_grad()
                outputs = model_tfc(combined_data)
                loss = criterion_tfc(outputs, combined_targets)
                loss.backward()
                optimizer_tfc.step()

        loss_msg = f"{loss.item():.4f}" if loss is not None else "n/a"
        print(f"Task {task_id+1}, Epoch {epoch+1}, Loss: {loss_msg}", end="")

        # --- Adaptive Replay Scheduling Logic ---
        # On mastery the gap grows geometrically; on failure the check repeats
        # at the next epoch.
        # NOTE(review): the schedule only decides when the buffer is evaluated;
        # replay itself happens on every batch above regardless of the outcome,
        # so training matches the ER mixed-batch procedure. Confirm this is the
        # intended TFC-SR behaviour.
        if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:
            print(" <-- Memory Check!", end="")
            model_tfc.eval()

            replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)

            print(f" | Replay Perf: {replay_perf:.2f}%", end="")

            if replay_perf >= CONFIG['mastery_threshold']:
                current_replay_gap *= CONFIG['replay_gap_multiplier']
                replay_timer += round(current_replay_gap)
                print(f" | Mastery OK. Next check @ epoch {replay_timer}.")
            else:
                replay_timer += 1
                # replay_timer already holds the next check epoch, so no +1 here
                # (the old message was off by one).
                print(f" | Mastery FAIL. Next check @ epoch {replay_timer}.")

            model_tfc.train()
        else:
            print()

    # --- Final Evaluation for this task (same as before) ---
    accuracy = evaluate_on_seen_tasks(model_tfc, split_mnist_benchmark, task_id, device, CONFIG['batch_size'])
    tfc_accuracies.append(accuracy)
    print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

# --- 3. SAVE AND PLOT ---
# Save checkpoint and results
tfc_model_path = os.path.join(CONFIG['checkpoints_path'], 'tfc_sr_final_model.pth')
torch.save(model_tfc.state_dict(), tfc_model_path)
tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_accuracies.pkl')
save_results(tfc_accuracies, tfc_results_path)

# Plot comparison with previous results
baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))
er_accuracies = load_results(os.path.join(CONFIG['results_path'], 'er_accuracies.pkl'))
results_to_plot = {
    'Baseline': baseline_accuracies,
    'Standard ER': er_accuracies,
    'TFC-SR': tfc_accuracies
}
plot_results(results_to_plot, title="TFC-SR vs. Baselines on Split MNIST")

In [None]:
# --- EXPERIMENT: ELASTIC WEIGHT CONSOLIDATION (EWC) ---
# Sweeps the EWC regularisation strength, keeps the lambda with the best final
# average accuracy, saves that accuracy curve and plots it against the others.
set_seed(CONFIG['seed'])

print("\n" + "="*20 + " Starting EWC Experiment " + "="*20)

# --- Hyperparameter Search for EWC ---
ewc_lambdas_to_try = [1.0, 100.0, 1000.0, 10000.0]
all_ewc_results = {}

for lmbda in ewc_lambdas_to_try:
    print(f"\n--- Running EWC with lambda = {lmbda} ---")

    # Re-seed per trial so every lambda starts from the same RNG state.
    set_seed(CONFIG['seed'])

    # --- Setup strategy for this trial ---
    model_ewc = CNN(num_classes=CONFIG['num_classes']).to(device)
    optimizer_ewc = optim.Adam(model_ewc.parameters(), lr=CONFIG['lr'])

    # Pass everything by keyword: newer avalanche-lib releases deprecate or
    # reject positional strategy arguments, and keywords also work on older ones.
    ewc_strategy = EWC(
        model=model_ewc,
        optimizer=optimizer_ewc,
        criterion=nn.CrossEntropyLoss(),
        ewc_lambda=lmbda,
        train_mb_size=CONFIG['batch_size'],
        train_epochs=CONFIG['epochs_per_task'],
        device=device,
    )

    # List to store accuracies for this specific lambda run
    current_lambda_accuracies = []

    # --- Training and Evaluation Loop ---
    for task_id, experience in enumerate(split_mnist_benchmark.train_stream):
        print(f"--> Training on experience {task_id+1}")

        ewc_strategy.train(experience)

        accuracy = evaluate_on_seen_tasks(
            ewc_strategy.model,
            split_mnist_benchmark,
            task_id,
            device,
            CONFIG['batch_size']
        )
        current_lambda_accuracies.append(accuracy)
        print(f"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    # Store the results for this lambda
    all_ewc_results[lmbda] = current_lambda_accuracies

# --- Find the best EWC result and save it ---
# NOTE(review): lambda is selected on the accuracy returned by
# evaluate_on_seen_tasks. If that helper scores the benchmark's test stream
# (TODO confirm), this is model selection on test data and the reported EWC
# number is optimistically biased; a held-out validation split would avoid it.
best_lambda_ewc = max(all_ewc_results, key=lambda k: all_ewc_results[k][-1])
CONFIG['best_ewc_lambda'] = best_lambda_ewc
best_ewc_accuracies = all_ewc_results[best_lambda_ewc]

print(f"\nBest EWC lambda was {best_lambda_ewc} with final accuracy: {best_ewc_accuracies[-1]:.2f}%")

# --- SAVE AND PLOT ---
ewc_results_path = os.path.join(CONFIG['results_path'], 'ewc_accuracies.pkl')
save_results(best_ewc_accuracies, ewc_results_path)
print(f"\nEWC results saved to {ewc_results_path}")

baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))
er_accuracies = load_results(os.path.join(CONFIG['results_path'], 'er_accuracies.pkl'))
tfc_accuracies = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_accuracies.pkl'))

results_to_plot = {
    'Baseline': baseline_accuracies,
    'Standard ER': er_accuracies,
    f'EWC (λ={best_lambda_ewc})': best_ewc_accuracies,
    'TFC-SR': tfc_accuracies
}
plot_results(results_to_plot, title="All Methods vs. Baselines on Split MNIST")

In [None]:
# --- EXPERIMENT: SYNAPTIC INTELLIGENCE (SI) ---
# Sweeps the SI regularisation strength, keeps the lambda with the best final
# average accuracy, saves that accuracy curve and plots every method together.

print("\n" + "="*20 + " Starting SI Experiment " + "="*20)

# --- Hyperparameter Search for SI ---
si_lambdas_to_try = [0.1, 1.0, 10.0, 100.0]
all_si_results = {}

for lmbda in si_lambdas_to_try:
    print(f"\n--- Running SI with lambda = {lmbda} ---")

    # Re-seed per trial so every lambda starts from the same RNG state.
    set_seed(CONFIG['seed'])

    # --- Setup strategy for this trial ---
    model_si = CNN(num_classes=CONFIG['num_classes']).to(device)
    optimizer_si = optim.Adam(model_si.parameters(), lr=CONFIG['lr'])

    # Instantiate the SynapticIntelligence strategy. Arguments are passed by
    # keyword because newer avalanche-lib releases deprecate or reject
    # positional strategy arguments; keywords also work on older ones.
    si_strategy = SynapticIntelligence(
        model=model_si,
        optimizer=optimizer_si,
        criterion=nn.CrossEntropyLoss(),
        si_lambda=lmbda,
        train_mb_size=CONFIG['batch_size'],
        train_epochs=CONFIG['epochs_per_task'],
        device=device,
    )

    current_lambda_accuracies = []

    # --- Training and Evaluation Loop ---
    for task_id, experience in enumerate(split_mnist_benchmark.train_stream):
        print(f"--> Training on experience {task_id+1}")

        si_strategy.train(experience)

        # --- EVALUATION STEP ---
        accuracy = evaluate_on_seen_tasks(
            si_strategy.model,
            split_mnist_benchmark,
            task_id,
            device,
            CONFIG['batch_size']
        )
        current_lambda_accuracies.append(accuracy)
        print(f"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    # Store the results for this lambda
    all_si_results[lmbda] = current_lambda_accuracies

# --- Find the best SI result and save it ---
# NOTE(review): same caveat as the EWC sweep; if evaluate_on_seen_tasks scores
# the test stream (TODO confirm), lambda is being selected on test data.
best_lambda_si = max(all_si_results, key=lambda k: all_si_results[k][-1])
CONFIG['best_si_lambda'] = best_lambda_si
best_si_accuracies = all_si_results[best_lambda_si]

print(f"\nBest SI lambda was {best_lambda_si} with final accuracy: {best_si_accuracies[-1]:.2f}%")

si_results_path = os.path.join(CONFIG['results_path'], 'si_accuracies.pkl')
save_results(best_si_accuracies, si_results_path)


# --- PLOT ALL RESULTS TOGETHER ---
# (Load all previous results and plot everything)
baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))
er_accuracies = load_results(os.path.join(CONFIG['results_path'], 'er_accuracies.pkl'))
ewc_accuracies = load_results(os.path.join(CONFIG['results_path'], 'ewc_accuracies.pkl'))
tfc_accuracies = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_accuracies.pkl'))

# The EWC lambda only lives in CONFIG if the EWC cell ran in this kernel session;
# fall back to a plain label instead of raising KeyError after a restart.
best_ewc_lambda = CONFIG.get('best_ewc_lambda')
ewc_label = f'EWC (Best λ={best_ewc_lambda})' if best_ewc_lambda is not None else 'EWC (Best λ)'

results_to_plot = {
    'Baseline': baseline_accuracies,
    'Standard ER': er_accuracies,
    ewc_label: ewc_accuracies,
    f'SI (Best λ={best_lambda_si})': best_si_accuracies,
    'TFC-SR': tfc_accuracies
}
plot_results(results_to_plot, title="All Methods vs. Baselines on Split MNIST")