In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
from torch.utils.data import TensorDataset, Dataset
import copy
import json
import os
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import time
from torch.utils.data import ConcatDataset

sys.path.append("..")

from models.mfvi import *
from tasks.permuted_mnist import *
from tasks.split_mnist import *
from utils.corset import *
from utils.misc import *

In [None]:
# Device on which the model and every batch are placed. It is read as a
# notebook-level global by train(), test(), get_scores() and experiment().
device = torch.device("cpu")

In [None]:
# Task lists consumed by experiment(); each task is unpacked there as a
# (train, test) pair.
# NOTE(review): the argument 5 is presumably the number of tasks to build —
# confirm against tasks/permuted_mnist.py and tasks/split_mnist.py.
permute_tasks = permute_mnist(5)
split_tasks = split_mnist(5)

In [None]:
# Root folder under which store_experiment_results() creates one timestamped
# sub-directory per experiment run.
result_directory = "./experiment_results"

In [None]:
def _endless_batches(loader):
    """Yield batches from ``loader`` forever, restarting it after each pass.

    ``itertools.cycle`` would cache the batches of the first pass and replay
    them unchanged, so a shuffling loader would never reshuffle and all of its
    batches would stay in memory. Re-iterating the loader avoids both problems.
    The generator stops if a pass yields nothing, so an empty loader cannot
    cause an infinite loop.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


def train(
    model,
    epochs,
    train_set: Dataset,
    batch_size=256,
    lr=0.001,
    deterministic=False,
    head=0,
    train_size=None,
    ignore_kl=False,
    num_samples=10,
    experience=None,
    experience_batch_size=40,
):
    """Optimise ``model`` on ``train_set`` with Adam, minimising ``kl + nll``.

    Args:
        model: Object exposing ``parameters()``, ``train()``, ``kl()`` and
            ``nll(inputs, targets, deterministic=, head=, num_samples=)``.
        epochs: Number of passes over ``train_set``.
        train_set: Dataset yielding ``(inputs, targets)`` pairs; it is wrapped
            in a shuffling ``DataLoader`` here.
        batch_size: Mini-batch size for ``train_set``.
        lr: Adam learning rate.
        deterministic: Forwarded to ``model.nll``; also drops the KL term.
        head: Forwarded to ``model.nll``.
        train_size: Divisor applied to ``model.kl()``. Defaults to
            ``len(train_set)`` when falsy.
        ignore_kl: When true, the KL term is replaced by zero.
        num_samples: Forwarded to ``model.nll``.
        experience: Optional replay dataset. When non-empty, one replay batch
            is concatenated to every training batch.
        experience_batch_size: Mini-batch size for the replay dataset.

    Returns:
        None. Per-epoch averages of loss, KL and NLL are printed.
    """
    if deterministic:
        print("Training in deterministic mode")

    if not train_size:
        train_size = len(train_set)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)

    replay_batches = None
    if experience:
        print("Using experience replay")
        # Endless stream of replay batches, reshuffled on every pass.
        replay_batches = _endless_batches(
            DataLoader(experience, batch_size=experience_batch_size, shuffle=True)
        )

    model.train()
    start_time = time.time()
    for epoch in range(epochs):
        epoch_start_time = time.time()
        total_loss = 0
        total_kl = 0
        total_nll = 0

        for inputs, targets in loader:
            optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)

            if replay_batches is not None:
                exp_inputs, exp_targets = next(replay_batches)

                # Replay data must live on the same device as the task batch.
                inputs = torch.cat([inputs, exp_inputs.to(device)], dim=0)
                targets = torch.cat([targets, exp_targets.to(device)], dim=0)

            # The KL term is divided by the dataset size, not the batch size.
            kl = model.kl() / train_size

            nll = model.nll(
                inputs,
                targets,
                deterministic=deterministic,
                head=head,
                num_samples=num_samples,
            )

            if deterministic or ignore_kl:
                kl = torch.tensor(0.0)

            # NOTE: MFVI.kl_alpha is not applied to the loss here.
            loss = kl + nll

            total_loss += loss.item()
            total_kl += kl.item()
            total_nll += nll.item()

            loss.backward()
            optimizer.step()

        # Report per-batch averages for the epoch.
        total_loss /= len(loader)
        total_kl /= len(loader)
        total_nll /= len(loader)

        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        print(
            f"Epoch {epoch} Loss: {total_loss} KL: {total_kl} NLL: {total_nll} Time: {epoch_time:.2f} sec"
        )

    total_train_time = time.time() - start_time
    print(f"Total training time: {total_train_time:.2f} sec")

In [None]:
def test(model, loader, deterministic=False, head=0):
    """Evaluate ``model`` on every batch of ``loader`` and return its accuracy.

    The model is switched to eval mode and gradients are disabled.
    ``model.predict`` is expected to return a pair: predicted labels and a
    per-sample tensor whose mean and standard deviation are printed as
    "Uncertainty".
    NOTE(review): the second value is named ``probabilities`` upstream —
    confirm what it measures.

    Returns:
        Fraction of correctly classified samples over ``len(loader.dataset)``.
    """
    model.eval()

    num_correct = 0
    per_sample_values = []
    with torch.no_grad():
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device)
            predicted, probs = model.predict(
                features, deterministic=deterministic, head=head
            )
            matches = predicted == labels
            num_correct += matches.sum().item()
            per_sample_values += probs.tolist()

    accuracy = num_correct / len(loader.dataset)
    mean_value = np.mean(per_sample_values)
    std_value = np.std(per_sample_values)
    print(
        f"Accuracy: {accuracy}, Uncertainty: {mean_value}, Uncertainty std: {std_value}"
    )
    return accuracy

## VCL


In [None]:
def get_scores(
    model,
    test_loaders,
    single_head=True,
    deterministic=False,
    coresets=None,
    epochs=10,
    batch_size=256,
    lr=0.001,
    train_size=None,
):
    """Evaluate ``model`` on every loader in ``test_loaders``.

    When ``coresets`` is given, earlier tasks are scored with a deep copy of
    ``model`` fine-tuned on coreset data; ``model`` itself is never trained
    here. The last loader is always scored with the un-tuned ``model``.

    * Single-head: one copy is fine-tuned on the merged coresets and shared by
      all earlier tasks.
    * Multi-head: a fresh copy is fine-tuned on ``coresets[i]`` with head ``i``
      for each earlier task ``i``.

    Fine-tuning is skipped whenever its result would be discarded (the last
    task, or a single-head call with only one loader).

    Args:
        model: Trained model to score.
        test_loaders: Sequence of test loaders, one per task seen so far.
        single_head: Use head 0 for every task if true, head ``i`` otherwise.
        deterministic: Forwarded to ``train`` and ``test``.
        coresets: Optional sequence of coresets aligned with ``test_loaders``.
        epochs, batch_size, lr, train_size: Forwarded to ``train`` for the
            coreset fine-tuning.

    Returns:
        List of accuracies, one per entry of ``test_loaders``.
    """
    last_index = len(test_loaders) - 1
    score_model = model

    # Only fine-tune when some non-final task will use the tuned copy.
    if single_head and coresets and last_index > 0:
        merged_coreset = merge_coresets(*coresets)

        score_model = copy.deepcopy(model).to(device)
        train(
            score_model,
            epochs,
            merged_coreset,
            batch_size=batch_size,
            lr=lr,
            deterministic=deterministic,
            train_size=train_size,
        )

    scores = []
    for i, loader in enumerate(test_loaders):
        is_last_task = i == last_index

        if not single_head and coresets and not is_last_task:
            score_model = copy.deepcopy(model).to(device)
            train(
                score_model,
                epochs,
                coresets[i],
                batch_size=batch_size,
                lr=lr,
                deterministic=deterministic,
                head=i,
                train_size=train_size,
            )

        head = 0 if single_head else i

        # Do not use the coreset fine-tuned model for the last task.
        used_model = model if is_last_task else score_model

        acc = test(used_model, loader, head=head, deterministic=deterministic)

        scores.append(acc)
    return scores

In [None]:
def plot_scores_with_average(
    scores,
    title,
    filename=None,
):
    """Plot per-task accuracy curves together with the per-step average.

    Args:
        scores: Sequence of steps; ``scores[s][t]`` is the accuracy on task
            ``t`` after training step ``s``. Steps may differ in length and may
            contain ``None`` for missing measurements.
        title: Figure title.
        filename: If given, the figure is saved to this path before it is shown.

    Missing entries (``None``, or tasks not yet seen at a step) are drawn as
    gaps and are excluded from the average of that step.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    avg_accuracies = []
    for step in scores:
        measured = [score for score in step if score is not None]
        # Average over the values actually summed; NaN when nothing was measured.
        avg_accuracies.append(
            sum(measured) / len(measured) if measured else float("nan")
        )

    print("Average accuracies:", avg_accuracies)
    ax.plot(
        avg_accuracies,
        label="Average Accuracy",
        marker="o",
        linestyle="--",
        color="black",
    )

    # Use the longest step so no task is dropped and empty input does not fail.
    num_tasks = max((len(step) for step in scores), default=0)
    for task_num in range(num_tasks):
        task_accuracies = [
            step[task_num]
            if task_num < len(step) and step[task_num] is not None
            else float("nan")
            for step in scores
        ]
        ax.plot(task_accuracies, label=f"Task {task_num + 1}", marker="o")

    # Labels, title and 1-based step ticks.
    ax.set_title(title)
    ax.set_xlabel("Steps")
    ax.set_ylabel("Accuracy")
    ax.set_xticks(list(range(len(scores))))
    ax.set_xticklabels(list(range(1, len(scores) + 1)))
    ax.legend()
    ax.grid(True)

    # Save before showing, because showing may release the figure.
    if filename:
        fig.savefig(filename)

    plt.show()

In [None]:
def store_experiment_results(experiment_params, scores):
    """Persist one experiment run to a new timestamped folder.

    A sub-directory named ``<YYYYmmdd_HHMMSS>_<title>`` is created under the
    notebook-level ``result_directory``. ``experiment_params`` (which must
    contain a ``"title"`` key) is written to ``parameters.json`` and ``scores``
    is written to ``scores.csv`` via a pandas DataFrame.

    Returns:
        Path of the created sub-directory.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = "{}_{}".format(timestamp, experiment_params["title"])
    subdirectory = os.path.join(result_directory, run_name)

    if not os.path.exists(subdirectory):
        os.makedirs(subdirectory)

    # Hyper-parameters as pretty-printed JSON.
    with open(os.path.join(subdirectory, "parameters.json"), "w") as params_handle:
        json.dump(experiment_params, params_handle, indent=4)

    # One CSV row per training step.
    scores_frame = pd.DataFrame(scores)
    scores_frame.to_csv(os.path.join(subdirectory, "scores.csv"), index=False)

    print(f"Experiment results saved in: {subdirectory}")

    return subdirectory

In [None]:
def experiment(
    title,
    tasks,
    batch_size=256,
    lr=0.01,
    epochs=100,
    hidden_layers=[100, 100],
    single_head=True,
    deterministic=False,
    coreset_size=0,
    coreset_method="random",
    initial_deterministic=True,
    update_prior=True,
    model=None,
    num_samples=10,
    output_features=10,
    experience_size=0,
):
    """Train sequentially on ``tasks``, score after each task, and store results.

    For every task the model is trained on that task's training set, then
    evaluated on the test sets of all tasks seen so far. The scores and the
    hyper-parameters are written out with ``store_experiment_results``.

    Args:
        title: Experiment name; also used in the results directory name.
        tasks: Sequence of ``(train, test)`` dataset pairs.
        batch_size: Batch size for training, coreset fine-tuning and testing.
        lr: Learning rate passed to ``train``.
        epochs: Epochs per task (also used for coreset fine-tuning).
        hidden_layers: Hidden layer sizes passed to ``MFVI``.
        single_head: One shared head if true, otherwise one head per task.
        deterministic: Forwarded to ``train`` and ``get_scores``.
        coreset_size: If > 0, build one ``Coreset`` per task for scoring.
        coreset_method: Selection method for coresets and the replay buffer.
        initial_deterministic: If true (and ``deterministic`` is false), run a
            deterministic pre-training pass on the first task.
        update_prior: Call ``model.update_prior()`` after every task.
        model: Optional existing model to continue training; a new ``MFVI`` is
            created when ``None``.
        num_samples: Forwarded to ``train`` for the main training pass.
        output_features: Output size of a newly created ``MFVI``.
        experience_size: If > 0, size of the experience-replay ``Coreset``.

    Returns:
        ``(scores, model)`` where ``scores[s]`` lists the accuracies on tasks
        ``0..s`` after training on task ``s``.

    Side effects:
        Increments the class attribute ``MFVI.kl_alpha`` once per task, so the
        value leaks into later calls unless the caller resets it.
    """
    test_loaders = []
    scores = []

    num_tasks = len(tasks)
    num_heads = 1 if single_head else num_tasks

    coresets = []
    experience = None
    if experience_size > 0:
        experience = Coreset(experience_size, coreset_method)

    # Explicit None check: do not rely on the truthiness of a module object.
    if model is None:
        model = MFVI(28 * 28, hidden_layers, output_features, num_heads=num_heads).to(
            device
        )

    for i, task in enumerate(tasks):
        print(f"Training on task {i}")

        mnist_train, mnist_test = task

        train_set = to_tensor_dataset(mnist_train)

        train_size = len(train_set)

        test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
        test_loaders.append(test_loader)

        head = 0 if single_head else i

        if initial_deterministic and i == 0 and not deterministic:
            # Deterministic warm-up on the first task, then compare both modes.
            train(
                model,
                epochs,
                train_set,
                batch_size=batch_size,
                lr=lr,
                deterministic=True,
            )
            acc = test(model, test_loader, deterministic=True)
            print("Acc deterministic", acc)
            acc = test(model, test_loader, deterministic=False)
            print("Acc stochastic", acc)

        if coreset_size > 0:
            coreset = Coreset(coreset_size, coreset_method)
            coreset.with_method(train_set)

            coresets.append(coreset)

        if experience_size > 0:
            # The replay buffer is updated with the current task BEFORE training
            # on it, so replay batches may contain current-task samples.
            # NOTE(review): whether with_method appends or replaces is not
            # visible here — confirm in utils/corset.py.
            experience.with_method(train_set)

        train(
            model,
            epochs,
            train_set,
            batch_size=batch_size,
            lr=lr,
            deterministic=deterministic,
            head=head,
            train_size=train_size,
            num_samples=num_samples,
            experience=experience,
        )

        score = get_scores(
            model,
            test_loaders,
            single_head=single_head,
            deterministic=deterministic,
            coresets=coresets,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            train_size=train_size,
        )
        print(score)
        scores.append(score)

        if update_prior:
            model.update_prior()
        MFVI.kl_alpha += 1

    # Record every hyper-parameter so the run can be reproduced from the file.
    store_experiment_results(
        {
            "title": title,
            "batch_size": batch_size,
            "lr": lr,
            "epochs": epochs,
            "hidden_layers": hidden_layers,
            "single_head": single_head,
            "deterministic": deterministic,
            "coreset_size": coreset_size,
            "coreset_method": coreset_method,
            "initial_deterministic": initial_deterministic,
            "update_prior": update_prior,
            "num_samples": num_samples,
            "output_features": output_features,
            "experience_size": experience_size,
        },
        scores,
    )
    return scores, model

## Experiments


In [None]:
raise Exception("Stop here")

In [None]:
experiment(
    "Permuted MNIST VCL",
    permute_tasks,
    batch_size=256,
    lr=0.005,
    epochs=80,
    hidden_layers=[100, 100],
    single_head=True,
    initial_deterministic=False,
    # experience_size=40,
)

In [None]:
experiment(
    "Permuted MNIST VCL ER",
    permute_tasks,
    batch_size=256,
    lr=0.005,
    epochs=80,
    hidden_layers=[100, 100],
    single_head=True,
    initial_deterministic=False,
    experience_size=200,
    coreset_method="random",
)

In [None]:
# Coreset run (no experience replay). The title is used as the results
# directory name, so it must describe this configuration.
experiment(
    "Permuted MNIST VCL Coreset Random",
    permute_tasks,
    batch_size=256,
    lr=0.005,
    epochs=80,
    hidden_layers=[100, 100],
    single_head=True,
    initial_deterministic=False,
    coreset_size=200,
    coreset_method="random",
)

In [None]:
MFVI.kl_alpha = 1
experiment(
    "Permuted MNIST VCL ER",
    permute_tasks,
    batch_size=256,
    lr=0.005,
    epochs=80,
    hidden_layers=[100, 100],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=200,
    # coreset_method="k_center",
)

In [None]:
MFVI.kl_alpha = 1
experiment(
    "Permuted MNIST VCL ER",
    permute_tasks,
    batch_size=256,
    lr=0.005,
    epochs=80,
    hidden_layers=[100, 100],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=200,
    coreset_method="random",
    experience_size=200,
)

In [None]:
raise Exception("Stop here")

In [None]:
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="k_center",
    update_prior=True,
    # experience_size=40,
)

In [None]:
experiment(
    "Split MNIST VCL ER",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    coreset_method="random",
    update_prior=True,
    experience_size=40,
)

In [None]:
experiment(
    "Split MNIST VCL Coreset",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    coreset_size=40,
    coreset_method="random",
    update_prior=True,
    # experience_size=40,
)

In [None]:
MFVI.kl_alpha = 1
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="k_center",
    update_prior=True,
    # experience_size=40,
)

In [None]:
MFVI.kl_alpha = 1
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    coreset_method="random",
    update_prior=True,
    experience_size=40,
)

In [None]:
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    40,
    [256, 256],
    single_head=False,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="random",
    # update_prior=True,
    # experience_size=40,
)

In [None]:
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    40,
    [256, 256],
    single_head=False,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    coreset_method="random",
    # update_prior=True,
    experience_size=40,
)

In [None]:
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    40,
    [256, 256],
    single_head=False,
    initial_deterministic=False,
    output_features=2,
    coreset_size=40,
    coreset_method="random",
    # update_prior=True,
    # experience_size=40,
)

In [None]:
MFVI.kl_alpha = 1
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    40,
    [256, 256],
    single_head=False,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="random",
    # update_prior=True,
    # experience_size=40,
)

In [None]:
MFVI.kl_alpha = 1
experiment(
    "Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    40,
    [256, 256],
    single_head=False,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="random",
    # update_prior=True,
    experience_size=40,
)

In [None]:
raise Exception("Stop here")

### Permuted MNIST


In [None]:
experiment(
    "Permuted MNIST VCL",
    permute_tasks,
    batch_size=256,
    lr=0.01,
    epochs=10,
    hidden_layers=[100, 100],
    single_head=True,
    initial_deterministic=False,
    experience_size=40,
)

In [None]:
experiment(
    "Permuted MNIST VCL",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=100,
    hidden_layers=[100, 100],
    single_head=True,
)

In [None]:
experiment(
    "Permuted MNIST VCL",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=100,
    hidden_layers=[100, 100],
    single_head=True,
)

In [None]:
experiment(
    "Permuted MNIST VCL Coreset Random",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=100,
    hidden_layers=[100, 100],
    single_head=True,
    coreset_size=200,
    coreset_method="random",
)

In [None]:
experiment(
    "Permuted MNIST VCL Coreset k-center",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=15,
    hidden_layers=[100, 100],
    single_head=True,
    coreset_size=200,
    coreset_method="k_center",
)

### Split MNIST


In [None]:
model_it1 = copy.deepcopy(model)

In [None]:
model_it2 = copy.deepcopy(model)

In [None]:
used_model = copy.deepcopy(model_it2)

scores, model = experiment(
    "Split MNIST VCL xxx!",
    split_tasks[2:3],
    256,
    0.001,
    15,
    [100, 100],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
    model=used_model,
)
test(
    model,
    DataLoader(split_tasks[0][1], batch_size=256, shuffle=False),
    deterministic=False,
)
test(
    model,
    DataLoader(split_tasks[1][1], batch_size=256, shuffle=False),
    deterministic=False,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    10,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    40,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.01,
    10,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
)

In [None]:
# MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.005,
    40,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="k_center",
    update_prior=True,
    experience_size=40,
)

In [None]:
# MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="k_center",
    update_prior=True,
    # experience_size=40,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.01,
    100,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    # coreset_method="k_center",
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    10,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    coreset_size=40,
    coreset_method="k_center",
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    10,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    10,
    [256, 256],
    single_head=True,
    initial_deterministic=True,
    # coreset_size=40,
    update_prior=True,
    num_samples=10,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    70,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
    num_samples=20,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    40,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    40,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    10,
    [512, 512],
    single_head=True,
    initial_deterministic=False,
    coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    10,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    10,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
)

In [None]:
MFVI.kl_alpha = 1
scores, _ = experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.001,
    30,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=40,
    update_prior=True,
)

In [None]:
# Sweep the starting value of MFVI.kl_alpha over 2, 4, ..., 128 and keep the
# value with the best final-step average accuracy.
# experiment() increments MFVI.kl_alpha after every task, so `alpha` is only
# the starting value of each run.
highscore = 0
alpha = 1
best_alpha = 1

for i in range(7):
    print(f"Experiment {i}")
    alpha *= 2
    MFVI.kl_alpha = alpha
    scores, _ = experiment(
        "Split MNIST VCL xxx",
        split_tasks,
        256,
        0.01,
        5,
        [256, 256],
        single_head=True,
        initial_deterministic=False,
        coreset_size=40,
        update_prior=True,
    )
    # Average over the accuracies of the final step: divide by the number of
    # values summed, not by the number of steps.
    last_avg = sum(scores[-1]) / len(scores[-1])
    print(f"Last average: {last_avg}")
    if last_avg > highscore:
        highscore = last_avg
        best_alpha = alpha
print(f"Best alpha: {best_alpha}")

In [None]:
experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.01,
    5,
    [100, 100],
    single_head=True,
    initial_deterministic=False,
    # coreset_size=100,
    update_prior=True,
)

In [None]:
experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.01,
    5,
    [100, 100],
    single_head=True,
    initial_deterministic=True,
    coreset_size=100,
    update_prior=True,
)

In [None]:
experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.01,
    5,
    [100, 100],
    single_head=True,
    initial_deterministic=True,
    coreset_size=100,
)

In [None]:
experiment(
    "Split MNIST VCL xxx",
    split_tasks,
    256,
    0.01,
    5,
    [100, 100],
    single_head=True,
    initial_deterministic=True,
)

In [None]:
experiment(
    "Split MNIST VCL Multi-Head",
    split_tasks,
    256,
    0.001,
    40,
    [256, 256],
    single_head=False,
    initial_deterministic=False,
)

In [None]:
experiment(
    "Split MNIST VCL Multi-Head Coreset Random",
    split_tasks,
    256,
    0.001,
    40,
    [256, 256],
    single_head=False,
    coreset_size=40,
    coreset_method="random",
)

In [None]:
experiment(
    "Split MNIST VCL Multi-Head Coreset k-center",
    split_tasks,
    256,
    0.001,
    40,
    [256, 256],
    single_head=False,
    coreset_size=40,
    coreset_method="k_center",
)

In [None]:
experiment(
    "Split MNIST VCL Single-Head",
    split_tasks,
    256,
    0.001,
    40,
    [256, 256],
    single_head=True,
)

In [None]:
experiment(
    "Split MNIST VCL Single-Head Coreset Random",
    split_tasks,
    256,
    0.001,
    2,
    [256, 256],
    single_head=True,
    coreset_size=40,
    coreset_method="random",
)

In [None]:
experiment(
    "Split MNIST VCL Single-Head Coreset k-center",
    split_tasks,
    256,
    0.001,
    10,
    [256, 256],
    single_head=True,
    coreset_size=40,
    coreset_method="k_center",
)

In [None]:
raise Exception("Stop here")

### Novel experiments


In [None]:
experiment(
    "Split MNIST VCL Single-Head Experiment",
    split_tasks,
    256,
    0.001,
    5,
    [256, 256],
    single_head=True,
)

In [None]:
experiment(
    "Split MNIST VCL Single-Head Experiment Deterministic",
    split_tasks,
    256,
    0.01,
    5,
    [256, 256],
    single_head=True,
    coreset_size=40,
    deterministic=True,
)

In [None]:
experiment(
    "Split MNIST VCL Single-Head Experiment",
    split_tasks,
    256,
    0.01,
    5,
    [256, 256],
    single_head=True,
    coreset_size=40,
)

## Split MNIST experiments


In [None]:
raise Exception("Stop here")

In [None]:
experiment(
    "Single-Head Split MNIST VCL",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="k_center",
    update_prior=True,
    # experience_size=40,
)

In [None]:
experiment(
    "Single-Head Split MNIST VCL Coreset Random",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    coreset_size=40,
    coreset_method="random",
    update_prior=True,
    # experience_size=40,
)

In [None]:
experiment(
    "Single-Head Split MNIST VCL Coreset K-Center",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    coreset_size=40,
    coreset_method="k_center",
    update_prior=True,
    # experience_size=40,
)

In [None]:
experiment(
    "Single-Head Split MNIST VCL ER Random",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    coreset_method="random",
    update_prior=True,
    experience_size=40,
)

In [None]:
experiment(
    "Single-Head Split MNIST VCL ER K-Center",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    coreset_method="k_center",
    update_prior=True,
    experience_size=40,
)

In [None]:
MFVI.kl_alpha = 1
experiment(
    "Single-Head Split MNIST VCL KL exponential",
    split_tasks,
    256,
    0.005,
    80,
    [256, 256],
    single_head=True,
    initial_deterministic=False,
    output_features=2,
    # coreset_size=40,
    # coreset_method="k_center",
    update_prior=True,
    # experience_size=40,
)

### Other experiments


In [None]:
raise Exception("Stop here")

In [None]:
experiment(
    "Permuted MNIST VCL Coreset k-center",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=100,
    hidden_layers=[100, 100],
    single_head=True,
    coreset_size=200,
    coreset_method="k_center",
)

In [None]:
experiment(
    "Permuted MNIST VCL",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=100,
    hidden_layers=[100, 100],
    single_head=True,
)

In [None]:
experiment(
    "Split MNIST multi", split_tasks, 256, 0.001, 5, [256, 256], single_head=False
)

In [None]:
experiment(
    "Split MNIST single", split_tasks, 256, 0.001, 5, [256, 256], single_head=True
)

In [None]:
experiment(
    "Split MNIST multi", split_tasks, 256, 0.001, 5, [256, 256], single_head=False
)

In [None]:
experiment(
    "Split MNIST multi random",
    split_tasks,
    256,
    0.001,
    15,
    [256, 256],
    single_head=False,
    coreset_size=40,
    coreset_method="random",
)

In [None]:
scores = [
    [0.9995271867612293],
    [0.9995271867612293, 0.9799216454456415],
    [0.9995271867612293, 0.9696376101860921, 0.9935965848452508],
    [0.9995271867612293, 0.9735553379040157, 0.9882604055496265, 0.9939577039274925],
    [
        0.9981087470449173,
        0.9490695396669931,
        0.9919957310565635,
        0.9798590130916415,
        0.9798285426122038,
    ],
]
print(scores)
plot_scores(scores, "Split MNIST single")

In [None]:
experiment(
    "Permuted MNIST 2 epochs",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=2,
    hidden_layers=[100, 100],
    single_head=True,
    coreset_size=200,
    coreset_method="random",
    initial_deterministic=True,
)

In [None]:
experiment(
    "Permuted MNIST VCL random",
    permute_tasks,
    batch_size=256,
    lr=0.001,
    epochs=15,
    hidden_layers=[100, 100],
    single_head=True,
    coreset_size=200,
    coreset_method="random",
    initial_deterministic=True,
)

In [None]:
raise Exception("Stop here")

In [None]:
import torch.utils.bottleneck as bottleneck

# experiment(
#     "Permuted MNIST 2 epochs",
#     permute_tasks,
#     batch_size=256,
#     lr=0.001,
#     epochs=2,
#     hidden_layers=[100, 100],
#     single_head=True,
# )

# bottleneck.run

In [None]:
x = [
    [0.9681],
    [0.9427, 0.9392],
    [0.9091, 0.9072, 0.9427],
    [0.8277, 0.882, 0.9244, 0.9393],
    [0.7394, 0.8362, 0.8892, 0.9189, 0.9433],
]

plot_scores_with_average(x, "Split MNIST")

In [None]:
# Minimal sequential-training loop without coresets or experience replay.
batch_size = 128

test_loaders = []
scores = []

model = MFVI(28 * 28, [100, 100], 10).to(device)

# NOTE(review): `tasks` is not defined in any visible cell of this notebook;
# it presumably should be `permute_tasks` or `split_tasks` — confirm before running.
for i, task in enumerate(tasks):
    print(f"Training on task {i}")

    mnist_train, mnist_test = task

    # train() builds its own DataLoader, so it must receive a dataset (as in
    # experiment()), not an already-batched loader.
    task_train_set = to_tensor_dataset(mnist_train)
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

    test_loaders.append(test_loader)

    train(model, 10, task_train_set, batch_size=batch_size)

    score = get_scores(model, test_loaders)
    print(score)
    scores.append(score)

    model.update_prior()

## Plot results


In [None]:
# Mean accuracy per training step, taken over all tasks evaluated at that step.
average_scores = [sum(step) / len(step) for step in scores]

# Draw the averages against the 1-based number of tasks learned so far.
fig, ax = plt.subplots()
ax.plot(
    [count for count, _ in enumerate(average_scores, start=1)],
    average_scores,
    linestyle="-",
    marker="o",
)

ax.set(
    xlabel="Number of tasks",
    ylabel="Average Accuracy",
    title="Average Model Accuracy Across Different Tasks",
)
ax.grid(True)

# Task counts are integers, so restrict the x ticks to integers.
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

plt.show()

In [None]:
# For every training step, plot the accuracies of all tasks evaluated at that step.
fig, ax = plt.subplots()
colors = ["b", "g", "r", "c", "m", "y", "k"]  # cycled per training step
markers = ["o", "v", "^", "<", ">", "s", "p", "*", "+", "x"]  # cycled per training step

for i, score in enumerate(scores):
    # scores[i] holds one accuracy per task evaluated at step i, so each series
    # is a training step (a vertical strip at x = i + 1), not a single task.
    ax.plot(
        [i + 1] * len(score),
        score,
        marker=markers[i % len(markers)],
        linestyle="-",
        color=colors[i % len(colors)],
        label=f"After task {i+1}",
    )

ax.set_xlabel("Number of tasks")
ax.set_ylabel("Accuracy")
ax.set_title("Model Accuracy Across Different Tasks")
ax.grid(True)
# The legend sits outside the axes so it does not cover the data.
ax.legend(title="Training step", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.show()

In [None]:
task0_train, task0_test = permute_tasks[0]

train_set = to_tensor_dataset(task0_train)
print(len(train_set))

coreset = Coreset(500, "random")

coreset.add_random(train_set)


print("Coreset size:", len(coreset))
print("Train set size:", len(train_set))

for t in coreset.tensors:
    print(t.shape)
for t in train_set.tensors:
    print(t.shape)

In [None]:
coreset.coreset_size = 5

coreset.add_k_center(train_set)

print("Coreset size:", len(coreset))
print("Train set size:", len(train_set))

In [None]:
coreset.coreset_size = 5

coreset.add_k_center(train_set)

print("Coreset size:", len(coreset))
print("Train set size:", len(train_set))

In [None]:
coreset.coreset_size = 5

coreset.add_k_center(train_set)

print("Coreset size:", len(coreset))
print("Train set size:", len(train_set))

In [None]:
coreset_2 = Coreset(5, "random")

coreset_2.with_method(train_set)

print("Coreset size:", len(coreset_2))
print("Train set size:", len(train_set))

In [None]:
coreset_3 = merge_coresets(coreset, coreset_2)

print("Coreset size:", len(coreset_3))

print(coreset_3.tensors[0].shape)
print(coreset_3.tensors[1].shape)