In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
import sys
import os
import pickle

# Root folder of the project. The results/ and checkpoints/ paths in CONFIG are derived from it.
# NOTE(review): the path points into Google Drive, but the drive.mount() call in the
# first cell is commented out — confirm the drive is mounted before running.
PROJECT_PATH = '/content/drive/MyDrive/tfc-sr'
# Make the project folder importable so the local `data_setup` and `utils` modules resolve.
sys.path.append(PROJECT_PATH)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

In [None]:
# Shared experiment settings. Later cells read these values and also add
# experiment-specific keys (buffer capacity, replay batch size, thresholds, ...).
CONFIG = dict(
    seed=42,
    num_tasks=10,
    epochs_per_task=20,
    batch_size=64,
    lr=0.001,
    num_classes=100,
    results_path=os.path.join(PROJECT_PATH, 'results'),
    checkpoints_path=os.path.join(PROJECT_PATH, 'checkpoints'),
)

In [None]:
!pip install avalanche-lib

from avalanche.benchmarks.classic import SplitMNIST
from avalanche.training import EWC, SynapticIntelligence
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, forgetting_metrics
from avalanche.logging import InteractiveLogger, TextLogger
from avalanche.training.plugins import EvaluationPlugin


In [None]:
# Snapshot the installed package versions for reproducibility.
# The file name is relative, so it is written to the kernel's current working directory
# (not to PROJECT_PATH).
!pip freeze > requirements.txt

In [None]:
from data_setup import get_split_mnist_dataloaders
from utils import set_seed, save_results, plot_results, load_results, evaluate_on_seen_tasks, ReservoirReplayBuffer

# Call the project's `set_seed` helper with the configured seed
# (presumably seeds the Python / NumPy / PyTorch RNGs — confirm in utils.py).
set_seed(CONFIG['seed'])

In [None]:
from avalanche.benchmarks.classic import SplitCIFAR100

# Create a benchmark with 10 tasks, each containing 10 new classes.
# No transforms are passed here, so the benchmark's defaults apply. The same name is
# re-bound in the baseline LR-tuning cell below with explicit normalization transforms,
# and that later instance is the one the experiments iterate over.
split_cifar100_benchmark = SplitCIFAR100(n_experiences=10, seed=CONFIG['seed'])

In [None]:
from torchvision.models import resnet18

def get_resnet18_for_cifar(num_classes=100):
    """
    Build an untrained ResNet-18 whose stem and classifier are adapted to CIFAR.

    The stock first convolution is swapped for a 3x3, stride-1 convolution, the
    initial max-pool is replaced by an identity, and the final linear layer is
    resized to `num_classes` outputs.

    Args:
        num_classes: Number of output units of the final linear layer.

    Returns:
        The modified ResNet-18 module with randomly initialised weights.
    """
    net = resnet18(weights=None)  # no pretrained weights: train from scratch

    # Stem: small-kernel convolution and no down-sampling pool.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    # Head: keep the classifier's input width, change only its output width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

In [None]:
# --- EXPERIMENT: BASELINE ON SPLIT CIFAR-100 (WITH LR TUNING) ---
# Plain sequential fine-tuning (no replay, no regularisation): for every candidate
# learning rate a fresh ResNet-18 is trained on the tasks one after another, and the
# learning rate with the highest accuracy after the last task is kept.
import torchvision.transforms as transforms

print("\n" + "="*20 + " Starting Baseline Experiment on Split CIFAR-100 " + "="*20)

# --- Hyperparameter Search Setup ---
learning_rates_to_try = [0.01, 0.001, 0.0001]
all_baseline_results = {}  # learning rate -> list of accuracies (one entry per task)

# --- Define the benchmark ONCE ---
# This ensures all LR trials use the exact same data splits and order.
# This re-binds `split_cifar100_benchmark`; the later experiment cells iterate over this instance.
# Train and eval use the same ToTensor + Normalize pipeline (no augmentation).
split_cifar100_benchmark = SplitCIFAR100(
    n_experiences=10,
    seed=CONFIG['seed'],
    # Standard normalization for CIFAR-100
    train_transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ]),
    eval_transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])
)

for lr in learning_rates_to_try:
    print(f"\n--- Running Baseline with learning rate = {lr} ---")

    set_seed(CONFIG['seed']) # Reset seed for each trial for fairness

    # Fresh model and optimizer per trial; the same optimizer instance is kept across all tasks.
    model_baseline = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
    optimizer = optim.Adam(model_baseline.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    current_lr_accuracies = []

    # --- Main Continual Learning Loop ---
    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
        print(f"--> Training on Task {task_id+1}/{len(split_cifar100_benchmark.train_stream)}")

        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

        # --- Training loop ---
        model_baseline.train()
        for epoch in range(CONFIG['epochs_per_task']):
            running_loss = 0.0
            # Each batch is a triple; the third element (presumably Avalanche's task label) is ignored.
            for data, targets, _ in train_loader:
                data, targets = data.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs = model_baseline(data)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            # Mean of the per-batch losses over this epoch.
            print(f"Task {task_id+1}, Epoch {epoch+1}, Avg Loss: {running_loss / len(train_loader):.4f}")

        # --- Evaluation Step ---
        # Project helper; presumably returns the accuracy (in %) over the tasks seen so far — confirm in utils.py.
        accuracy = evaluate_on_seen_tasks(model_baseline, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
        current_lr_accuracies.append(accuracy)
        print(f"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    all_baseline_results[lr] = current_lr_accuracies

# --- Find the best learning rate and save its results ---
# We choose the LR that gave the best accuracy after the final task
best_lr = max(all_baseline_results, key=lambda k: all_baseline_results[k][-1])
best_baseline_accuracies = all_baseline_results[best_lr]

print(f"\nBest Baseline learning rate was {best_lr} with final accuracy: {best_baseline_accuracies[-1]:.2f}%")

# Save only the results from the BEST run
# (a bare list of accuracies; the final-baseline cell below saves a dictionary to a different file).
baseline_cifar_results_path = os.path.join(CONFIG['results_path'], 'baseline_cifar_accuracies.pkl')
save_results(best_baseline_accuracies, baseline_cifar_results_path)

# --- Plot all the trial runs to visualize the tuning process ---
plot_results({f'LR={lr}': acc for lr, acc in all_baseline_results.items()},
             title="Baseline LR Tuning on Split CIFAR-100")

# --- Plot just the best result ---
plot_results({'Baseline (Best LR)': best_baseline_accuracies},
             title="Best Baseline Performance on Split CIFAR-100")

In [None]:
# Best Baseline learning rate was 0.001 with final accuracy: 7.27%

In [None]:
# Re-run of the baseline with the learning rate selected by the sweep above. The sweep
# stored a bare list of accuracies; this run stores a results dictionary instead, so it
# has the same structure as the results of the replay-based experiments.

# --- EXPERIMENT 1: BASELINE (SEQUENTIAL FINE-TUNING) ON SPLIT CIFAR-100 ---

print("\n" + "="*20 + " Starting Final Baseline Experiment on Split CIFAR-100 " + "="*20)

# --- Use the best learning rate found from tuning ---
CONFIG['lr'] = 0.001

# Create the output folders before the long training run starts, so that saving the
# checkpoint and the results at the end cannot fail on a missing directory.
os.makedirs(CONFIG['checkpoints_path'], exist_ok=True)
os.makedirs(CONFIG['results_path'], exist_ok=True)

set_seed(CONFIG['seed'])

model_baseline = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
optimizer = optim.Adam(model_baseline.parameters(), lr=CONFIG['lr'])
criterion = nn.CrossEntropyLoss()

# --- Use a dictionary to store results  ---
baseline_results = {
    'accuracies': [],          # one accuracy value per task, appended after each task
    'total_replay_batches': 0  # Allows us to better compare the replay methods. 0 for the baseline.
}

# --- Main Continual Learning Loop ---
for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
    print(f"\n--- Training on Task {task_id+1}/{len(split_cifar100_benchmark.train_stream)} ---")

    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    # --- Training loop ---
    model_baseline.train()
    for epoch in range(CONFIG['epochs_per_task']):
        running_loss = 0.0
        # Each batch is a triple; the third element is not used here.
        for data, targets, _ in train_loader:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model_baseline(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean of the per-batch losses over this epoch.
        print(f"Task {task_id+1}, Epoch {epoch+1}, Avg Loss: {running_loss / len(train_loader):.4f}")

    # --- Evaluation Step ---
    # Project helper; presumably returns the accuracy (in %) over the tasks seen so far — confirm in utils.py.
    accuracy = evaluate_on_seen_tasks(model_baseline, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
    baseline_results['accuracies'].append(accuracy)
    print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

# --- Save Results and Checkpoint ---
final_model_path = os.path.join(CONFIG['checkpoints_path'], 'baseline_cifar_final_model.pth')
torch.save(model_baseline.state_dict(), final_model_path)
print(f"\nFinal baseline model (CIFAR-100) saved to {final_model_path}")

# Save the entire results dictionary
baseline_cifar_results_path = os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl')
save_results(baseline_results, baseline_cifar_results_path)

# --- Plot the Single Result ---
plot_results(
    {'Baseline (CIFAR-100)': baseline_results['accuracies']},
    title="Baseline Performance on Split CIFAR-100"
)

In [None]:
# --- EXPERIMENT: STANDARD EXPERIENCE REPLAY (ER) ON SPLIT CIFAR-100 ---
# Every training step mixes half a batch of current-task samples with half a batch
# sampled from a fixed-capacity replay buffer.

print("\n" + "="*20 + " Starting Standard ER Experiment on Split CIFAR-100 " + "="*20)

CONFIG['buffer_capacity'] = 1000
CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2

set_seed(CONFIG['seed'])

model_er = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
optimizer_er = optim.Adam(model_er.parameters(), lr=CONFIG['lr'])
criterion_er = nn.CrossEntropyLoss()

# Project helper; presumably implements reservoir sampling (judging by its name) — confirm in utils.py.
replay_buffer_er = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])

er_results = {
    'accuracies': [],           # one accuracy value per task
    'total_replay_batches': 0   # number of optimisation steps that included replayed samples
}

# --- Main Continual Learning Loop ---
for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
    print(f"\n--- Training on Task {task_id+1}/{len(split_cifar100_benchmark.train_stream)} ---")

    # Populate replay buffer
    # The current task's samples are offered to the buffer BEFORE training on that task,
    # so the "old" half of a mixed batch can contain current-task samples as well.
    print(f"Populating replay buffer from Task {task_id+1}...")
    for data_point, target, _ in experience.dataset:
        replay_buffer_er.add(data_point, target)
    print(f"Replay buffer size: {len(replay_buffer_er)}")

    # Create the dataloader for the current task
    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    # --- Training loop with mixed batches ---
    model_er.train()
    for epoch in range(CONFIG['epochs_per_task']):
        running_loss = 0.0
        for new_data, new_targets, _ in train_loader:

            # NOTE(review): when the buffer holds fewer than replay_batch_size items the
            # batch is skipped entirely (no update on the new data either), while the
            # average below still divides by len(train_loader).
            if len(replay_buffer_er) >= CONFIG['replay_batch_size']:
                old_data, old_targets = replay_buffer_er.sample(CONFIG['replay_batch_size'])
                # Keep only the first replay_batch_size new samples; the rest of the loaded batch is discarded.
                new_data = new_data[:CONFIG['replay_batch_size']]
                new_targets = new_targets[:CONFIG['replay_batch_size']]

                combined_data = torch.cat((new_data, old_data), dim=0).to(device)
                combined_targets = torch.cat((new_targets, old_targets), dim=0).to(device)

                optimizer_er.zero_grad()
                outputs = model_er(combined_data)
                loss = criterion_er(outputs, combined_targets)
                loss.backward()
                optimizer_er.step()
                running_loss += loss.item()

                er_results['total_replay_batches'] += 1

        print(f"Task {task_id+1}, Epoch {epoch+1}, Avg Loss: {running_loss / len(train_loader):.4f}")

    # --- Evaluation ---
    # Project helper; presumably returns the accuracy (in %) over the tasks seen so far — confirm in utils.py.
    accuracy = evaluate_on_seen_tasks(model_er, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
    er_results['accuracies'].append(accuracy)
    print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

# --- Save and Plot ---
er_cifar_results_path = os.path.join(CONFIG['results_path'], 'er_cifar_results.pkl')
save_results(er_results, er_cifar_results_path)

# Load the dictionaries for plotting
# (this re-binds `baseline_results` to the copy stored on disk by the final-baseline cell).
baseline_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))

baseline_accuracy = baseline_results['accuracies']
er_results_loaded = load_results(er_cifar_results_path)

# Extract the accuracy lists for the plot function
results_to_plot = {
    'Baseline (CIFAR-100)': baseline_accuracy,
    'Standard ER (CIFAR-100)': er_results_loaded['accuracies']
}
plot_results(results_to_plot, title="Standard ER vs. Baseline on Split CIFAR-100")

# print the efficiency metric
print(f"\nStandard ER performed {er_results_loaded['total_replay_batches']} replay batches.")

In [None]:
# --- EXPERIMENT: TFC-SR SENSITIVITY ANALYSIS ON SPLIT CIFAR-100 ---
# For each mastery threshold a fresh model is trained with mixed new/replay batches.
# After selected epochs the model is evaluated on the replay buffer ("memory check");
# the gap to the next check grows when the threshold is met and is one epoch otherwise.
from utils import evaluate_replay_buffer

print("\n" + "="*20 + " Starting TFC-SR Hyperparameter Tuning on Split CIFAR-100 " + "="*20)

# --- Hyperparameter Search Setup ---
thresholds_to_try = [10.0, 20.0, 30.0, 50.0, 70.0, 90.0]
all_tfc_results = {}  # threshold -> results dictionary of that trial

CONFIG['buffer_capacity'] = 1000
CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2
CONFIG['initial_replay_gap'] = 1   # Start checking after epoch 1
CONFIG['replay_gap_multiplier'] = 1.5 # How much to increase the gap

# --- Outer loop for tuning the mastery_threshold ---
for threshold in thresholds_to_try:
    print(f"\n--->>> STARTING TRIAL: THRESHOLD = {threshold}% <<<---")
    CONFIG['mastery_threshold'] = threshold

    set_seed(CONFIG['seed'])
    model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
    optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])
    criterion_tfc = nn.CrossEntropyLoss()
    replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])

    # 'schedule_history' receives one record per memory check (see below).
    current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }

    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
        print(f"--> Training on Task {task_id+1}")

        # Offer the current task's samples to the buffer before training on the task.
        for data_point, target, _ in experience.dataset:
            replay_buffer_tfc.add(data_point, target)

        # The schedule restarts for every task: first check after `initial_replay_gap` epochs.
        current_replay_gap = float(CONFIG['initial_replay_gap'])
        replay_timer = int(current_replay_gap)
        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

        model_tfc.train()
        for epoch in range(CONFIG['epochs_per_task']):
            # NOTE(review): replay samples are mixed into EVERY mini-batch, independent of
            # `replay_timer`; the timer only decides when the buffer is evaluated. As written,
            # 'total_replay_batches' therefore does not depend on the threshold. Confirm
            # whether replay itself was meant to be gated by the schedule.
            for new_data, new_targets, _ in train_loader:
                if len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:
                    old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])
                    # Keep only the first replay_batch_size new samples of the loaded batch.
                    new_data = new_data[:CONFIG['replay_batch_size']]

                    combined_data = torch.cat((new_data, old_data), dim=0).to(device)
                    combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)

                    optimizer_tfc.zero_grad()
                    outputs = model_tfc(combined_data)
                    loss = criterion_tfc(outputs, combined_targets)
                    loss.backward()
                    optimizer_tfc.step()

                    current_run_results['total_replay_batches'] += 1

            # --- Adaptive Replay Scheduling Logic with DIAGNOSTICS ---
            if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:
                current_run_results['memory_checks'] += 1
                model_tfc.eval()

                print(f"\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}")

                replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)
                print(f"    Replay Buffer Perf: {replay_perf:.2f}%. Comparing against Threshold: {CONFIG['mastery_threshold']}%")

                mastery_met = replay_perf >= CONFIG['mastery_threshold']
                if mastery_met:
                    # Threshold reached: widen the gap to the next check.
                    current_replay_gap *= CONFIG['replay_gap_multiplier']
                    replay_timer += round(current_replay_gap)
                    print(f"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.")
                else:
                    # Threshold missed: check again after the very next epoch.
                    replay_timer += 1
                    print(f"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.")

                # Record the decision so the schedule can be analysed from the saved results
                # (this list was previously created and saved but never filled).
                current_run_results['schedule_history'].append({
                    'task': task_id + 1,
                    'epoch': epoch + 1,
                    'replay_perf': float(replay_perf),
                    'mastery_met': bool(mastery_met),
                    'next_check_epoch': replay_timer,
                })

                model_tfc.train()

        accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
        current_run_results['accuracies'].append(accuracy)
        print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    all_tfc_results[threshold] = current_run_results

# --- ANALYSIS, SAVING, AND PLOTTING ---

# 1. Save the FULL dictionary of all trial runs for later analysis
all_tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_all_trials.pkl')
save_results(all_tfc_results, all_tfc_results_path)
print(f"\nFull TFC-SR tuning results saved to {all_tfc_results_path}")

# 2. Programmatically find the best result and save it separately for convenience
best_threshold = max(all_tfc_results, key=lambda k: all_tfc_results[k]['accuracies'][-1])
best_tfc_results = all_tfc_results[best_threshold]
print(f"\nBest TFC-SR mastery threshold was {best_threshold}% with final accuracy {best_tfc_results['accuracies'][-1]:.2f}%")
best_tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl')
save_results(best_tfc_results, best_tfc_results_path)

# 3. Plot the Sensitivity Analysis
print("\n--- Generating Sensitivity Analysis Plot ---")
# Load ER results to use as a reference line
er_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_results.pkl'))
sensitivity_plot_data = {f'TFC-SR (Thresh={t}%)': res['accuracies'] for t, res in all_tfc_results.items()}
sensitivity_plot_data['Standard ER'] = er_cifar_results['accuracies']
plot_results(sensitivity_plot_data, title="TFC-SR Sensitivity to Mastery Threshold on Split CIFAR-100")

# 4. Plot the Main Comparison
print("\n--- Generating Main Comparison Plot ---")
# Load all the "best" results files
baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))
tfc_cifar_best_results = load_results(best_tfc_results_path)

main_plot_data = {
    'Baseline': baseline_cifar_results['accuracies'],
    'Standard ER': er_cifar_results['accuracies'],
    f'TFC-SR (Ours, Thresh={best_threshold}%)': tfc_cifar_best_results['accuracies']
}
plot_results(main_plot_data, title="Main Performance Comparison on Split CIFAR-100")

# 5. Report the efficiency metrics
print(f"\n--- Efficiency Comparison ---")
print(f"Standard ER performed {er_cifar_results['total_replay_batches']} replay batches.")
print(f"Best TFC-SR (Thresh={best_threshold}%) performed {best_tfc_results['total_replay_batches']} replay batches with {best_tfc_results['memory_checks']} memory checks.")

In [None]:
# --- EXPERIMENT: STANDARD ER SENSITIVITY TO BUFFER SIZE (OPTIMIZED) ---
# Repeats the standard ER training for additional replay-buffer capacities; the
# capacity-1000 run is not repeated but read back from the results file of the ER cell.

print("\n" + "="*20 + " Starting ER Buffer Size Tuning " + "="*20)

# --- Hyperparameter Search Setup ---
new_buffer_sizes_to_try = [100, 500, 2000]
er_tuning_results = {}  # buffer capacity -> results dictionary of that trial
CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2

# --- Outer loop for tuning the buffer_capacity ---
for capacity in new_buffer_sizes_to_try:
    print(f"\n--->>> STARTING TRIAL: BUFFER CAPACITY = {capacity} <<<---")

    set_seed(CONFIG['seed'])
    model_er = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
    optimizer_er = optim.Adam(model_er.parameters(), lr=CONFIG['lr'])
    criterion_er = nn.CrossEntropyLoss()

    # The capacity is passed directly; CONFIG['buffer_capacity'] is not touched by this cell.
    replay_buffer_er = ReservoirReplayBuffer(capacity=capacity)

    current_run_results = { 'accuracies': [], 'total_replay_batches': 0 }

    # --- Main Continual Learning Loop ---
    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
        print(f"--> Training on Task {task_id+1}")

        # Populate replay buffer
        for data_point, target, _ in experience.dataset:
            replay_buffer_er.add(data_point, target)

        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

        # --- Training loop with mixed batches ---
        model_er.train()
        for epoch in range(CONFIG['epochs_per_task']):
            for new_data, new_targets, _ in train_loader:
                # Batches are skipped entirely while the buffer holds fewer than replay_batch_size items.
                if len(replay_buffer_er) >= CONFIG['replay_batch_size']:
                    old_data, old_targets = replay_buffer_er.sample(CONFIG['replay_batch_size'])
                    # Keep only the first replay_batch_size new samples of the loaded batch.
                    new_data = new_data[:CONFIG['replay_batch_size']]

                    combined_data = torch.cat((new_data, old_data), dim=0).to(device)
                    combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)

                    optimizer_er.zero_grad()
                    outputs = model_er(combined_data)
                    loss = criterion_er(outputs, combined_targets)
                    loss.backward()
                    optimizer_er.step()
                    current_run_results['total_replay_batches'] += 1

        # --- Unified Evaluation ---
        accuracy = evaluate_on_seen_tasks(model_er, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
        current_run_results['accuracies'].append(accuracy)
        print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    # Store the final results for this buffer size
    er_tuning_results[capacity] = current_run_results

# --- ANALYSIS, SAVING, AND PLOTTING ---

# 1. Load the result for the run we already completed
path_to_er_1000_results = os.path.join(CONFIG['results_path'], 'er_cifar_results.pkl')
try:
    er_1000_results = load_results(path_to_er_1000_results)
    # Add the loaded result to our tuning dictionary
    er_tuning_results[1000] = er_1000_results
    print("\nSuccessfully loaded existing results for buffer size 1000.")
except FileNotFoundError:
    # Best effort: the sweep continues without the 1000 entry.
    print("\nWarning: Could not find existing results for buffer size 1000. It will be missing from the plot.")


# 2. Save the FULL dictionary of all trial runs (new and old) for later analysis
all_er_results_path = os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl')
save_results(er_tuning_results, all_er_results_path)
print(f"Full ER buffer tuning results saved to {all_er_results_path}")

# 3. Plot the sensitivity analysis curve
print("\n--- Generating Buffer Size Sensitivity Plot for Standard ER ---")
# Sort the dictionary by key (buffer size) for a clean plot
# (both lists are consumed by the plotting cell that follows).
sorted_capacities = sorted(er_tuning_results.keys())
final_accuracies = [er_tuning_results[cap]['accuracies'][-1] for cap in sorted_capacities]

In [None]:
import matplotlib.pyplot as plt


# Final accuracy as a function of buffer capacity (log-scaled x axis).
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(sorted_capacities, final_accuracies, marker='o')
ax.set_title("Standard ER Performance vs. Buffer Capacity on Split CIFAR-100")
ax.set_xlabel("Replay Buffer Capacity")
ax.set_ylabel("Final Average Accuracy (%) after 10 Tasks")
ax.set_xscale('log')
ax.grid(True, which='both', linestyle='--')
plt.show()

# 4. Plot all the learning curves
print("\n--- Generating Learning Curves for Each Buffer Size ---")
# One labelled accuracy curve per buffer capacity, in ascending capacity order.
plot_data_er = {}
for capacity_key in sorted_capacities:
    plot_data_er[f'ER (Buffer={capacity_key})'] = er_tuning_results[capacity_key]['accuracies']
plot_results(plot_data_er, title="Standard ER Learning Curves by Buffer Size")

In [None]:
# --- EXPERIMENT: TFC-SR SENSITIVITY TO BUFFER SIZE ON SPLIT CIFAR-100 ---
# Repeats the TFC-SR training with a fixed mastery threshold for additional buffer
# capacities and compares the final accuracies with the Standard ER buffer sweep.

print("\n" + "="*20 + " Starting TFC-SR Buffer Size Tuning on Split CIFAR-100 " + "="*20)

# --- Hyperparameter Search Setup ---
buffer_sizes_to_try = [100, 500, 2000]
all_tfc_buffer_results = {}  # buffer capacity -> results dictionary of that trial

# --- Fixed Hyperparameters for this experiment ---
CONFIG['mastery_threshold'] = 10.0 # best threshold we found previously
CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2
CONFIG['initial_replay_gap'] = 1
CONFIG['replay_gap_multiplier'] = 1.5

# --- Outer loop for tuning the buffer_capacity ---
for buffer_size in buffer_sizes_to_try:
    print(f"\n--->>> STARTING TRIAL: BUFFER SIZE = {buffer_size} <<<---")
    # Set the buffer size for this run
    CONFIG['buffer_capacity'] = buffer_size

    # --- Setup for this specific trial ---
    set_seed(CONFIG['seed'])
    model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
    optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])
    criterion_tfc = nn.CrossEntropyLoss()
    replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])

    # 'schedule_history' receives one record per memory check (see below).
    current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }

    # --- Main Continual Learning Loop ---
    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
        print(f"--> Training on Task {task_id+1}")
        # Offer the current task's samples to the buffer before training on the task.
        for data_point, target, _ in experience.dataset:
            replay_buffer_tfc.add(data_point, target)

        # The schedule restarts for every task.
        current_replay_gap = float(CONFIG['initial_replay_gap'])
        replay_timer = int(current_replay_gap)
        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

        model_tfc.train()
        for epoch in range(CONFIG['epochs_per_task']):
            # NOTE(review): replay samples are mixed into every mini-batch regardless of
            # `replay_timer`; the timer only decides when the buffer is evaluated.
            # Confirm whether replay itself was meant to be gated by the schedule.
            for new_data, new_targets, _ in train_loader:
                if len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:
                    old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])
                    # Keep only the first replay_batch_size new samples of the loaded batch.
                    new_data = new_data[:CONFIG['replay_batch_size']]

                    combined_data = torch.cat((new_data, old_data), dim=0).to(device)
                    combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)

                    optimizer_tfc.zero_grad()
                    outputs = model_tfc(combined_data)
                    loss = criterion_tfc(outputs, combined_targets)
                    loss.backward()
                    optimizer_tfc.step()

                    current_run_results['total_replay_batches'] += 1

            # --- Adaptive Replay Scheduling Logic with DIAGNOSTICS ---
            if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:
                current_run_results['memory_checks'] += 1
                model_tfc.eval()

                print(f"\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}")

                replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)
                print(f"    Replay Buffer Perf: {replay_perf:.2f}%. Comparing against Threshold: {CONFIG['mastery_threshold']}%")

                mastery_met = replay_perf >= CONFIG['mastery_threshold']
                if mastery_met:
                    # Threshold reached: widen the gap to the next check.
                    current_replay_gap *= CONFIG['replay_gap_multiplier']
                    replay_timer += round(current_replay_gap)
                    print(f"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.")
                else:
                    # Threshold missed: check again after the very next epoch.
                    replay_timer += 1
                    print(f"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.")

                # Record the decision so the schedule can be analysed from the saved results
                # (this list was previously created and saved but never filled).
                current_run_results['schedule_history'].append({
                    'task': task_id + 1,
                    'epoch': epoch + 1,
                    'replay_perf': float(replay_perf),
                    'mastery_met': bool(mastery_met),
                    'next_check_epoch': replay_timer,
                })

                model_tfc.train()

        accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
        current_run_results['accuracies'].append(accuracy)
        print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    all_tfc_buffer_results[buffer_size] = current_run_results

# --- ANALYSIS, SAVING, AND PLOTTING (This is the corrected part) ---

# 1. Load the result for the run with buffer size 1000
# NOTE(review): this file holds the best-threshold run of the threshold sweep; it matches
# this experiment only if that best threshold equals CONFIG['mastery_threshold'] — confirm.
path_to_tfc_1000_results = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl')
try:
    tfc_1000_results = load_results(path_to_tfc_1000_results)
    all_tfc_buffer_results[1000] = tfc_1000_results
    print("\nSuccessfully loaded existing results for TFC-SR with buffer size 1000.")
except FileNotFoundError:
    print(f"\nWarning: Could not find results file at {path_to_tfc_1000_results}. It will be missing.")

# 2. Save the FULL dictionary of all trial runs
all_tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_buffer_tuning_ALL.pkl')
save_results(all_tfc_buffer_results, all_tfc_results_path)
print(f"Full TFC-SR buffer tuning results saved to {all_tfc_results_path}")

# 3. Plot the TFC-SR Sensitivity to Buffer Size
print("\n--- Generating Buffer Size Sensitivity Plot for TFC-SR ---")
# Load the corresponding Standard ER results for comparison
er_buffer_tuning_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))

# Get the final accuracy for each buffer size for both methods.
# The ER sweep may lack a capacity that TFC-SR has (e.g. 1000, when the ER results file
# was missing at sweep time), so the ER curve only uses the capacities it really contains
# instead of raising a KeyError.
sorted_capacities = sorted(all_tfc_buffer_results.keys())
tfc_final_accuracies = [all_tfc_buffer_results[cap]['accuracies'][-1] for cap in sorted_capacities]
er_capacities = [cap for cap in sorted_capacities if cap in er_buffer_tuning_results]
er_final_accuracies = [er_buffer_tuning_results[cap]['accuracies'][-1] for cap in er_capacities]
missing_er_capacities = [cap for cap in sorted_capacities if cap not in er_buffer_tuning_results]
if missing_er_capacities:
    print(f"Warning: no Standard ER results for buffer sizes {missing_er_capacities}; they are omitted from the ER curve.")

plt.figure(figsize=(8, 5))
plt.plot(sorted_capacities, tfc_final_accuracies, marker='o', label='TFC-SR (Ours)')
plt.plot(er_capacities, er_final_accuracies, marker='o', linestyle='--', label='Standard ER')
plt.title("Performance vs. Buffer Capacity on Split CIFAR-100")
plt.xlabel("Replay Buffer Capacity")
plt.ylabel("Final Average Accuracy (%) after 10 Tasks")
plt.xscale('log')
plt.grid(True, which='both', linestyle='--')
plt.legend()
plt.show()

# 4. Plot the Main Comparison using the best buffer size for TFC-SR
# Find which buffer size gave TFC-SR the best final accuracy
best_buffer_size = max(all_tfc_buffer_results, key=lambda k: all_tfc_buffer_results[k]['accuracies'][-1])
best_tfc_results = all_tfc_buffer_results[best_buffer_size]

# Load other baselines
baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))

main_plot_data = {'Baseline': baseline_cifar_results['accuracies']}
if best_buffer_size in er_buffer_tuning_results:
    er_best_buffer_results = er_buffer_tuning_results[best_buffer_size] # Compare ER at the same buffer size
    main_plot_data[f'Standard ER (Buffer={best_buffer_size})'] = er_best_buffer_results['accuracies']
else:
    print(f"Warning: no Standard ER results for buffer size {best_buffer_size}; the ER curve is omitted from the comparison.")
main_plot_data[f'TFC-SR (Ours, Buffer={best_buffer_size})'] = best_tfc_results['accuracies']
plot_results(main_plot_data, title="Main Performance Comparison on Split CIFAR-100")

In [None]:
# --- EXPERIMENT: EWC HYPERPARAMETER TUNING ON SPLIT CIFAR-100 ---
# Trains Avalanche's EWC strategy once per candidate lambda and keeps the lambda with
# the highest accuracy after the final task.
print("\n" + "="*20 + " Starting EWC Tuning on Split CIFAR-100 " + "="*20)

# --- Hyperparameter Search Setup ---
ewc_lambdas_to_try = [1000.0, 10000.0, 100000.0]
all_ewc_cifar_results = {}  # lambda -> list of accuracies (one entry per task)

# Use the best learning rate we found for the baseline
current_lr = CONFIG.get('lr', 0.001)

# --- Outer loop for tuning lambda ---
for lmbda in ewc_lambdas_to_try:
    print(f"\n--->>> STARTING EWC TRIAL: LAMBDA = {lmbda} <<<---")

    set_seed(CONFIG['seed'])

    model_ewc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
    optimizer_ewc = optim.Adam(model_ewc.parameters(), lr=current_lr)

    # Instantiate the EWC strategy with the current lambda
    ewc_strategy = EWC(
        model_ewc, optimizer_ewc, nn.CrossEntropyLoss(),
        ewc_lambda=lmbda,
        train_mb_size=CONFIG['batch_size'],
        train_epochs=CONFIG['epochs_per_task'],
        device=device
    )

    current_lambda_accuracies = []

    # --- Training and Evaluation Loop ---
    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
        print(f"--> Training on Task {task_id+1}")

        # The strategy object runs the whole training loop for this experience.
        ewc_strategy.train(experience)

        # Evaluate with the same project helper as the other methods so the numbers are comparable.
        accuracy = evaluate_on_seen_tasks(
            ewc_strategy.model,
            split_cifar100_benchmark,
            task_id,
            device,
            CONFIG['batch_size']
        )
        current_lambda_accuracies.append(accuracy)
        print(f"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    all_ewc_cifar_results[lmbda] = current_lambda_accuracies

# --- Find the best EWC result and save it ---
best_lambda_ewc = max(all_ewc_cifar_results, key=lambda k: all_ewc_cifar_results[k][-1])
best_ewc_accuracies = all_ewc_cifar_results[best_lambda_ewc]

print(f"\nBest EWC lambda for CIFAR-100 was {best_lambda_ewc} with final accuracy: {best_ewc_accuracies[-1]:.2f}%")

# --- SAVE AND PLOT ---
# Save the results for the best EWC run. The selected lambda is stored next to the
# accuracies so that it can be recovered from disk after a kernel restart.
ewc_cifar_results_path = os.path.join(CONFIG['results_path'], 'ewc_cifar_best.pkl')
save_results({'accuracies': best_ewc_accuracies, 'ewc_lambda': best_lambda_ewc}, ewc_cifar_results_path)
print(f"Best EWC (CIFAR-100) results saved to {ewc_cifar_results_path}")

# Load all other "best" results to create the final comparison plot
# NOTE(review): both sweep files are indexed with 1000 unconditionally; the entry is absent
# when the corresponding buffer-size sweep could not find its 1000-capacity results file.
baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))
er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Assuming 1000 was best
tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_buffer_tuning_ALL.pkl'))[1000] # Assuming 1000 was best

results_to_plot = {
    'Baseline': baseline_cifar_results['accuracies'],
    'Standard ER': er_cifar_best_results['accuracies'],
    # Show the selected lambda in the legend (same label format as the SI comparison plot).
    f'EWC (Best λ={best_lambda_ewc})': best_ewc_accuracies,
    'TFC-SR (Ours)': tfc_cifar_best_results['accuracies']
}
plot_results(results_to_plot, title="Main Performance Comparison on Split CIFAR-100")

In [None]:
# --- EXPERIMENT: SI HYPERPARAMETER TUNING ON SPLIT CIFAR-100 ---
# Trains Avalanche's SynapticIntelligence strategy once per candidate lambda and keeps
# the lambda with the highest accuracy after the final task.

print("\n" + "="*20 + " Starting SI Tuning on Split CIFAR-100 " + "="*20)

# --- Hyperparameter Search Setup ---
si_lambdas_to_try = [1.0, 10.0, 100.0, 1000.0]
all_si_cifar_results = {}  # lambda -> list of accuracies (one entry per task)

# Use the best learning rate we found for the baseline
current_lr = CONFIG.get('lr', 0.001)

# --- Outer loop for tuning lambda ---
for lmbda in si_lambdas_to_try:
    print(f"\n--->>> STARTING SI TRIAL: LAMBDA = {lmbda} <<<---")

    set_seed(CONFIG['seed'])

    model_si = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
    optimizer_si = optim.Adam(model_si.parameters(), lr=current_lr)

    # Instantiate the SynapticIntelligence strategy
    si_strategy = SynapticIntelligence(
        model_si, optimizer_si, nn.CrossEntropyLoss(),
        si_lambda=lmbda,
        train_mb_size=CONFIG['batch_size'],
        train_epochs=CONFIG['epochs_per_task'],
        device=device
    )

    current_lambda_accuracies = []

    # --- Training and Evaluation Loop ---
    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
        print(f"--> Training on Task {task_id+1}")

        # The strategy object runs the whole training loop for this experience.
        si_strategy.train(experience)

        # Evaluate with the same project helper as the other methods so the numbers are comparable.
        accuracy = evaluate_on_seen_tasks(
            si_strategy.model,
            split_cifar100_benchmark,
            task_id,
            device,
            CONFIG['batch_size']
        )
        current_lambda_accuracies.append(accuracy)
        print(f"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

    all_si_cifar_results[lmbda] = current_lambda_accuracies

# --- Find the best SI result and save it ---
best_lambda_si = max(all_si_cifar_results, key=lambda k: all_si_cifar_results[k][-1])
best_si_accuracies = all_si_cifar_results[best_lambda_si]

print(f"\nBest SI lambda for CIFAR-100 was {best_lambda_si} with final accuracy: {best_si_accuracies[-1]:.2f}%")

# --- SAVE AND PLOT ---
# Save the results for the best SI run (the selected lambda is stored with the accuracies).
si_cifar_results_path = os.path.join(CONFIG['results_path'], 'si_cifar_best.pkl')
save_results({'accuracies': best_si_accuracies, 'si_lambda': best_lambda_si}, si_cifar_results_path)
print(f"Best SI (CIFAR-100) results saved to {si_cifar_results_path}")

# Load all other "best" results to create the final comparison plot
baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))
er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Use your best ER run
ewc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'ewc_cifar_best.pkl'))
tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))

# The EWC accuracies come from disk, so this cell must also work when the EWC cell was
# not run in the current kernel session. Resolve the lambda for the legend without
# relying on kernel state: use an 'ewc_lambda' entry of the results file if it has one,
# otherwise a `best_lambda_ewc` variable left in the namespace, otherwise a neutral label.
ewc_lambda_label = ewc_cifar_best_results.get('ewc_lambda', globals().get('best_lambda_ewc', 'n/a'))

results_to_plot = {
    'Baseline': baseline_cifar_results['accuracies'],
    'Standard ER': er_cifar_best_results['accuracies'],
    f'EWC (Best λ={ewc_lambda_label})': ewc_cifar_best_results['accuracies'],
    f'SI (Best λ={best_lambda_si})': best_si_accuracies,
    'TFC-SR': tfc_cifar_best_results['accuracies']
}
plot_results(results_to_plot, title="Main Performance Comparison on Split CIFAR-100")

In [None]:
# --- EXPERIMENT: TFC-SR STRESS TEST ON SPLIT CIFAR-100 ---
from utils import evaluate_replay_buffer

print("\n" + "="*20 + " Starting TFC-SR Stress Test on Split CIFAR-100 " + "="*20)


CONFIG['buffer_capacity'] = 1000
CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2
CONFIG['initial_replay_gap'] = 1   # Start checking after epoch 1
CONFIG['replay_gap_multiplier'] = 1.5 # How much to increase the gap
threshold = 99.0

print(f"\n--->>> STARTING TRIAL: THRESHOLD = {threshold}% <<<---")
CONFIG['mastery_threshold'] = threshold

set_seed(CONFIG['seed'])
model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])
criterion_tfc = nn.CrossEntropyLoss()
replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])

current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }

for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
    print(f"--> Training on Task {task_id+1}")

    for data_point, target, _ in experience.dataset: replay_buffer_tfc.add(data_point, target)

    current_replay_gap = float(CONFIG['initial_replay_gap'])
    replay_timer = int(current_replay_gap)
    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    model_tfc.train()
    for epoch in range(CONFIG['epochs_per_task']):
        for new_data, new_targets, _ in train_loader:
              if len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:
                old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])
                new_data = new_data[:CONFIG['replay_batch_size']]

                combined_data = torch.cat((new_data, old_data), dim=0).to(device)
                combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)

                optimizer_tfc.zero_grad()
                outputs = model_tfc(combined_data)
                loss = criterion_tfc(outputs, combined_targets)
                loss.backward()
                optimizer_tfc.step()

                current_run_results['total_replay_batches'] += 1

        # --- Adaptive Replay Scheduling Logic with DIAGNOSTICS ---
        if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:
            current_run_results['memory_checks'] += 1
            model_tfc.eval()

            print(f"\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}")

            replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)
            print(f"    Replay Buffer Perf: {replay_perf:.2f}%. Comparing against Threshold: {CONFIG['mastery_threshold']}%")

            if replay_perf >= CONFIG['mastery_threshold']:
                current_replay_gap *= CONFIG['replay_gap_multiplier']
                replay_timer += round(current_replay_gap)
                print(f"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.")
            else:
                replay_timer += 1
                print(f"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.")

            model_tfc.train()

    accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
    current_run_results['accuracies'].append(accuracy)
    print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

# --- ANALYSIS, SAVING, AND PLOTTING ---
tfc_stress_cifar_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_99_thresh_cifar.pkl')
save_results(current_run_results, tfc_stress_cifar_results_path)
print(f"TFC_SR Stress Test (CIFAR-100) results saved to {tfc_stress_cifar_results_path}")

baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))
er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Use your best ER run
tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))

results_to_plot = {
    'Baseline': baseline_cifar_results['accuracies'],
    'Standard ER': er_cifar_best_results['accuracies'],
    'TFC-SR (threshold = 10.0)': tfc_cifar_best_results['accuracies'],
    f'TFC-SR (threshold = {threshold})': current_run_results['accuracies'],
}
plot_results(results_to_plot, title="TFC-SR Stress Test on Split CIFAR-100")

In [None]:
# --- EXPERIMENT: TFC-SR with Spaced Replay ON SPLIT CIFAR-100 ---
# Replay is only performed on "replay epochs" (the epoch the replay timer
# points at, and never during the first task). All other epochs train on
# new-task data alone. After each replay epoch a memory check on the replay
# buffer decides how far away the next replay epoch is.
from utils import evaluate_replay_buffer

print("\n" + "="*20 + " Starting TFC-SR Spaced Replay on Split CIFAR-100 " + "="*20)


CONFIG['buffer_capacity'] = 1000
CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2
CONFIG['initial_replay_gap'] = 1   # First replay epoch is epoch 1
CONFIG['replay_gap_multiplier'] = 1.5 # Factor applied to the gap each time mastery is met
threshold = 10.0

print(f"\n--->>> STARTING TRIAL: THRESHOLD = {threshold}% <<<---")
CONFIG['mastery_threshold'] = threshold

set_seed(CONFIG['seed'])
model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])
criterion_tfc = nn.CrossEntropyLoss()
replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])

# NOTE(review): 'schedule_history' is initialised here but nothing in this
# cell ever appends to it, so it is saved as an empty list.
current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }

for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
    print(f"--> Training on Task {task_id+1}")

    # The buffer receives the new task's samples BEFORE training starts, so
    # samples drawn from it below can include current-task data.
    for data_point, target, _ in experience.dataset:
        replay_buffer_tfc.add(data_point, target)

    # The replay schedule restarts from the initial gap for every task.
    current_replay_gap = float(CONFIG['initial_replay_gap'])
    replay_timer = int(current_replay_gap)
    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    model_tfc.train()
    for epoch in range(CONFIG['epochs_per_task']):
        # Replay happens only on the scheduled epoch and only after the first task.
        is_replay_epoch = (epoch + 1) == replay_timer and task_id > 0

        # Re-enter train mode every epoch: a preceding memory check leaves
        # the model in eval mode.
        model_tfc.train()
        if is_replay_epoch:
            print(f"\n--- Epoch {epoch+1}: Performing Spaced Replay & Memory Check ---")
        else:
            print(f"\n--- Epoch {epoch+1}: Training on New Task Data Only ---")

        for new_data, new_targets, _ in train_loader:
            if is_replay_epoch and len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:
                # Mixed batch: the new-task part is truncated to
                # replay_batch_size samples and an equally sized replay sample is appended.
                old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])
                new_data = new_data[:CONFIG['replay_batch_size']]

                batch_data = torch.cat((new_data, old_data), dim=0).to(device)
                batch_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)

                current_run_results['total_replay_batches'] += 1
            else:
                # Regular epoch, or a replay epoch whose buffer is too small
                # to sample from: train on the new batch alone rather than
                # skipping the batch.
                batch_data, batch_targets = new_data.to(device), new_targets.to(device)

            optimizer_tfc.zero_grad()
            outputs = model_tfc(batch_data)
            loss = criterion_tfc(outputs, batch_targets)
            loss.backward()
            optimizer_tfc.step()

        if is_replay_epoch:
            # After the replay epoch, the memory check schedules the NEXT replay.
            current_run_results['memory_checks'] += 1

            model_tfc.eval()
            print(f"\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}")
            replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)
            # Log the measured value so the MET/FAILED decision below can be
            # interpreted (same diagnostic line as the stress-test cell).
            print(f"    Replay Buffer Perf: {replay_perf:.2f}%. Comparing against Threshold: {CONFIG['mastery_threshold']}%")

            if replay_perf >= CONFIG['mastery_threshold']:
                # Mastery met: widen the gap before the next replay epoch.
                current_replay_gap *= CONFIG['replay_gap_multiplier']
                replay_timer += round(current_replay_gap)
                print(f"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.")
            else:
                # Mastery failed: replay again on the very next epoch.
                replay_timer += 1
                print(f"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.")

    accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
    current_run_results['accuracies'].append(accuracy)
    print(f"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----")

# --- ANALYSIS, SAVING, AND PLOTTING ---
tfc_spaced_replay_cifar_results_path = os.path.join(CONFIG['results_path'], 'tfc_spaced_replay_cifar.pkl')
save_results(current_run_results, tfc_spaced_replay_cifar_results_path)
print(f"TFC_SR with spaced replay (CIFAR-100) results saved to {tfc_spaced_replay_cifar_results_path}")

baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))
er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Entry under key 1000 is used as the best ER run
tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))

results_to_plot = {
    'Baseline': baseline_cifar_results['accuracies'],
    'Standard ER': er_cifar_best_results['accuracies'],
    'TFC-SR': tfc_cifar_best_results['accuracies'],
    'TFC-SR (with Spaced Replay)': current_run_results['accuracies'],
}
plot_results(results_to_plot, title="TFC-SR with Spaced Replay vs Other methods on Split CIFAR-100")

print("Results for Spaced Replay:")
print(f"Total Replay Batches: {current_run_results['total_replay_batches']}")
print(f"Memory Checks: {current_run_results['memory_checks']}")

In [None]:
# --- EXPERIMENT: Mastery-Gated Progression (MGP) ---
# Each task is trained epoch by epoch until BOTH the current-task accuracy
# and the replay-buffer accuracy reach their thresholds, or until the
# per-task epoch cap is hit. From the second task onward, batches mix
# new-task samples with samples drawn from the replay buffer.
from utils import evaluate_replay_buffer

print("\n" + "="*20 + " Starting MGP Experiment " + "="*20)

# --- MGP-Specific Hyperparameters in CONFIG ---
CONFIG['new_task_mastery_thresh'] = 85.0 # Must reach >= 85% on the current task's test set
CONFIG['retention_thresh'] = 15.0      # Must keep >= 15% accuracy on the replay buffer
CONFIG['max_epochs_per_task'] = 50    # Safety cap so the while loop always terminates

# The replay settings are otherwise only written by the TFC-SR cells.
# setdefault keeps any value already present and fills in the same values
# those cells use, so this cell also runs on its own.
CONFIG.setdefault('buffer_capacity', 1000)
CONFIG.setdefault('replay_batch_size', CONFIG['batch_size'] // 2)

set_seed(CONFIG['seed'])

# --- Setup for the experiment ---
model_mgp = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)
optimizer_mgp = optim.Adam(model_mgp.parameters(), lr=CONFIG['lr'])
criterion_mgp = nn.CrossEntropyLoss()
replay_buffer_mgp = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])

# --- Results Dictionary ---
mgp_results = {
    'accuracies': [],
    'epochs_per_task': [], # Number of epochs actually spent on each task
    'final_accuracy': 0.0  # Filled in after the last task, before saving
}

# --- Main Continual Learning Loop ---
for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):
    print(f"\n--- Starting to learn Task {task_id+1} ---")

    # Populate replay buffer with the new task's data. This happens before
    # training, so the buffer evaluated below already holds current-task samples.
    for data, target, _ in experience.dataset:
        replay_buffer_mgp.add(data, target)

    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)

    # The loader over the current task's test split does not change between
    # epochs, so it is built once per task.
    # NOTE(review): mastery gating reads the TEST split, so the stopping
    # decision is tuned on the data later used for reporting - consider a
    # held-out validation split.
    current_task_test_loader = DataLoader(split_cifar100_benchmark.test_stream[task_id].dataset, batch_size=CONFIG['batch_size'])

    # --- The "Practice Until Mastery" While Loop ---
    epoch_count = 0
    mastery_achieved = False
    while not mastery_achieved and epoch_count < CONFIG['max_epochs_per_task']:
        epoch_count += 1
        model_mgp.train()

        # --- Training is always on mixed batches (except for task 1) ---
        for new_data, new_targets, _ in train_loader:
            if task_id > 0 and len(replay_buffer_mgp) >= CONFIG['replay_batch_size']:
                # Mixed batch: the new-task part is truncated to
                # replay_batch_size samples and an equally sized replay sample is appended.
                old_data, old_targets = replay_buffer_mgp.sample(CONFIG['replay_batch_size'])
                new_data = new_data[:CONFIG['replay_batch_size']]
                batch_data = torch.cat((new_data, old_data), dim=0).to(device)
                batch_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)
            else:
                # Task 1 (or a buffer too small to sample): new data only.
                batch_data, batch_targets = new_data.to(device), new_targets.to(device)

            optimizer_mgp.zero_grad()
            outputs = model_mgp(batch_data)
            loss = criterion_mgp(outputs, batch_targets)
            loss.backward()
            optimizer_mgp.step()

        # --- Mastery Check at the end of each epoch ---
        # 1. Check performance on the CURRENT task's test set
        model_mgp.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data, targets, _ in current_task_test_loader:
                data, targets = data.to(device), targets.to(device)
                predicted = model_mgp(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        new_task_perf = 100.0 * correct / total

        # 2. Check performance on the replay buffer (if not the first task).
        #    The first task has nothing to retain, so it passes by default.
        retention_perf = 100.0
        if task_id > 0:
            retention_perf = evaluate_replay_buffer(model_mgp, replay_buffer_mgp, device)

        print(f"Epoch {epoch_count}: New Task Perf: {new_task_perf:.2f}%, Retention Perf: {retention_perf:.2f}%")

        # 3. Check if both conditions are met
        if new_task_perf >= CONFIG['new_task_mastery_thresh'] and retention_perf >= CONFIG['retention_thresh']:
            mastery_achieved = True
            print(f"*** Mastery achieved for Task {task_id+1} in {epoch_count} epochs! ***")

    if not mastery_achieved:
        print(f"!!! Max epochs reached for Task {task_id+1}. Moving on without mastery. !!!")

    # --- Record metrics for this task ---
    mgp_results['epochs_per_task'].append(epoch_count)
    final_task_accuracy = evaluate_on_seen_tasks(model_mgp, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])
    mgp_results['accuracies'].append(final_task_accuracy)
    print(f"----- Overall Accuracy after Task {task_id+1}: {final_task_accuracy:.2f}% -----")

# --- ANALYSIS, SAVING, AND PLOTTING ---
# Record the final accuracy BEFORE saving; otherwise the pickle keeps the
# 0.0 placeholder from the results dictionary.
mgp_results['final_accuracy'] = mgp_results['accuracies'][-1]

mgp_cifar_results_path = os.path.join(CONFIG['results_path'], 'mgp_cifar.pkl')
save_results(mgp_results, mgp_cifar_results_path)
print(f"MGP (CIFAR-100) results saved to {mgp_cifar_results_path}")

baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))
er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Entry under key 1000 is used as the best ER run
tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))

results_to_plot = {
    'Baseline': baseline_cifar_results['accuracies'],
    'Standard ER': er_cifar_best_results['accuracies'],
    'TFC-SR': tfc_cifar_best_results['accuracies'],
    'MGP': mgp_results['accuracies'],
}
plot_results(results_to_plot, title="MGP vs Other methods on Split CIFAR-100")

# --- Final Results ---
print("\n--- MGP Experiment Finished ---")
print(f"Final Overall Accuracy: {mgp_results['final_accuracy']:.2f}%")
print(f"Epochs taken per task: {mgp_results['epochs_per_task']}")

In [None]:
# 70: 11.21