In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F

sys.path.append("..")

from models.mlp import MLP, train_mlp, test_mlp

from tasks.permuted_mnist import *

In [None]:
# Pick the best available accelerator and fall back to the CPU, so the notebook
# also runs on machines without Apple-silicon MPS support (a hard-coded "mps"
# device makes every later `.to(device)` fail there).
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the task list. Later cells index `tasks` and unpack every entry as a
# (train_dataset, test_dataset) pair.
# NOTE(review): the argument 5 is presumably the number of permuted tasks —
# confirm against permute_mnist in tasks.permuted_mnist.
tasks = permute_mnist(5)

In [None]:
# Network used by the sequential multi-task training loop further down.
# NOTE(review): MLP arguments are presumably (input size, hidden sizes, classes)
# — confirm against models.mlp.
model = MLP(784, [100], 10).to(device)

# NOTE(review): `optimizer` and `criterion` are rebound by the single-task
# experiment cells below, so after a top-to-bottom run these names no longer
# refer to the objects created here.
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.000)
criterion = nn.CrossEntropyLoss()

In [None]:
def get_scores(model, test_loaders):
    """Evaluate `model` on every loader and return the results in loader order.

    Each loader is passed to `test_mlp` together with the notebook-level
    `device`; whatever `test_mlp` returns is collected unchanged.

    Args:
        model: Model handed to `test_mlp`.
        test_loaders: Iterable of evaluation loaders.

    Returns:
        list: One `test_mlp` result per loader (empty if no loaders are given).
    """
    return [test_mlp(model, eval_loader, device) for eval_loader in test_loaders]

In [None]:
# Baseline experiment: train a fresh MLP on the first task only, using plain
# SGD (no momentum, no weight decay), then report its test score.
model1 = MLP(28 * 28, [100], 10).to(device)
# NOTE(review): these two assignments rebind the notebook-level `optimizer` and
# `criterion` names that were created next to `model` in an earlier cell.
optimizer = SGD(model1.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# NOTE(review): `sample_tmp` is not read anywhere in the visible notebook.
sample_tmp = 1

# First task only: (train dataset, test dataset).
task0_train, task0_test = tasks[0]

train_loader = DataLoader(task0_train, batch_size=128, shuffle=True)
test_loader = DataLoader(task0_test, batch_size=128, shuffle=False)

# NOTE(review): 10 is presumably the number of epochs — confirm against
# train_mlp's signature in models.mlp.
train_mlp(model1, train_loader, device, optimizer, criterion, 10)

# get_scores returns one entry per loader, so `acc` is a one-element list.
acc = get_scores(model1, [test_loader])

print(acc)

In [None]:
from models.mlp import MLP2


# MLP2 experiment on the first task with the last train_mlp flag set to False.
# The continual-learning cell passes a variable named `sample` in that position.
model1 = MLP2(28 * 28, [100], 10).to(device)
# NOTE(review): rebinds the notebook-level `optimizer`/`criterion` names; the
# optimizer now tracks `model1`'s parameters.
optimizer = SGD(model1.parameters(), lr=0.01, momentum=0.9, weight_decay=0.000)
criterion = nn.CrossEntropyLoss()

# NOTE(review): `sample_tmp` is not read anywhere in the visible notebook.
sample_tmp = 1

# First task only: (train dataset, test dataset).
task0_train, task0_test = tasks[0]

train_loader = DataLoader(task0_train, batch_size=128, shuffle=True)
test_loader = DataLoader(task0_test, batch_size=128, shuffle=False)

# NOTE(review): 5 is presumably the number of epochs — confirm against
# train_mlp's signature in models.mlp.
train_mlp(model1, train_loader, device, optimizer, criterion, 5, False)

# get_scores returns one entry per loader, so `acc` is a one-element list.
acc = get_scores(model1, [test_loader])

print(acc)

In [None]:
from models.mlp import MLP2


# Same MLP2 experiment as the previous cell, but with the last train_mlp flag
# set to True. The continual-learning cell passes a variable named `sample` in
# that position.
model1 = MLP2(28 * 28, [100], 10).to(device)
# NOTE(review): rebinds the notebook-level `optimizer`/`criterion` names; the
# optimizer now tracks this `model1`'s parameters.
optimizer = SGD(model1.parameters(), lr=0.01, momentum=0.9, weight_decay=0.000)
criterion = nn.CrossEntropyLoss()

# NOTE(review): `sample_tmp` is not read anywhere in the visible notebook.
sample_tmp = 1

# First task only: (train dataset, test dataset).
task0_train, task0_test = tasks[0]

train_loader = DataLoader(task0_train, batch_size=128, shuffle=True)
test_loader = DataLoader(task0_test, batch_size=128, shuffle=False)

# NOTE(review): 5 is presumably the number of epochs — confirm against
# train_mlp's signature in models.mlp.
train_mlp(model1, train_loader, device, optimizer, criterion, 5, True)

# get_scores returns one entry per loader, so `acc` is a one-element list.
acc = get_scores(model1, [test_loader])

print(acc)

In [None]:
# Continual-learning run: train `model` on each task in turn and, after every
# task, re-evaluate it on the test sets of all tasks seen so far.
test_loaders = []
scores = []

# Build the optimizer and loss for `model` in this cell. The single-task
# experiment cells above rebind the notebook-level `optimizer` to `model1`'s
# parameters, so reusing that object here would step the wrong network and
# leave `model` untrained on a top-to-bottom run.
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.000)
criterion = nn.CrossEntropyLoss()

for i, task in enumerate(tasks):
    print(f"Training on task {i}")

    mnist_train, mnist_test = task

    train_loader = DataLoader(
        mnist_train,
        batch_size=64,
        shuffle=True,
        # num_workers=8,
        # prefetch_factor=8,
        # pin_memory=True,
        # persistent_workers=True,
    )
    test_loader = DataLoader(
        mnist_test,
        batch_size=64,
        shuffle=False,
        # num_workers=4,
        # prefetch_factor=4,
        # pin_memory=True,
        # persistent_workers=True,
    )

    # Keep every test loader so earlier tasks are re-scored after each stage.
    test_loaders.append(test_loader)

    # The flag is off for the first task and on for every later one.
    # NOTE(review): its exact effect is defined by train_mlp in models.mlp.
    sample = i > 0

    train_mlp(model, train_loader, device, optimizer, criterion, 5, sample)

    # `score` holds one accuracy per task seen so far, so scores[k] has k + 1
    # entries.
    score = get_scores(model, test_loaders)
    print(score)
    scores.append(score)

## Plot results


In [None]:
# Mean accuracy over all tasks evaluated at each training stage
# (one value per entry of `scores`).
average_scores = [sum(stage) / len(stage) for stage in scores]

# Line plot of the stage-wise mean, with stages numbered from 1.
fig, ax = plt.subplots()
ax.plot(range(1, len(average_scores) + 1), average_scores, "o-")

ax.set(
    xlabel="Number of tasks",
    ylabel="Average Accuracy",
    title="Average Model Accuracy Across Different Tasks",
)
ax.grid(True)

# Task counts are whole numbers, so restrict the x ticks to integers.
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

plt.show()

In [None]:
# Per-task accuracy curves: one line per task, followed from the stage at which
# the task was first evaluated through every later training stage.
fig, ax = plt.subplots()
colors = [
    "b",
    "g",
    "r",
    "c",
    "m",
    "y",
    "k",
]  # Define a list of colors for different tasks
markers = [
    "o",
    "v",
    "^",
    "<",
    ">",
    "s",
    "p",
    "*",
    "+",
    "x",
]  # Define a list of markers for variety

# scores[stage] lists the accuracy of every task seen up to that stage, so the
# longest entry tells how many tasks there are in total.
num_tasks = max((len(stage_scores) for stage_scores in scores), default=0)

for task_idx in range(num_tasks):
    # (stage number, accuracy) for each stage in which this task was evaluated.
    points = [
        (stage, stage_scores[task_idx])
        for stage, stage_scores in enumerate(scores, start=1)
        if task_idx < len(stage_scores)
    ]
    ax.plot(
        [stage for stage, _ in points],
        [accuracy for _, accuracy in points],
        marker=markers[task_idx % len(markers)],
        linestyle="-",
        color=colors[task_idx % len(colors)],
        label=f"Task {task_idx + 1}",
    )

ax.set_xlabel("Number of tasks")
ax.set_ylabel("Accuracy")
ax.set_title("Model Accuracy Across Different Tasks")
ax.grid(True)
# Stage counts are whole numbers, so restrict the x ticks to integers.
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
ax.legend(title="Tasks", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.show()

In [None]:
def calculate_accuracy(outputs, targets):
    """Return the fraction of rows whose arg-max over the last axis equals the target.

    Args:
        outputs: Tensor of per-class scores; the prediction is the arg-max
            along the last dimension.
        targets: Tensor of reference labels, compared element-wise with the
            predictions.

    Returns:
        The NumPy mean of the element-wise matches.
    """
    predicted = outputs.argmax(dim=-1).cpu().numpy()
    expected = targets.cpu().numpy()
    return np.mean(predicted == expected)

In [None]:
class ELBO(nn.Module):
    """Loss combining a mean NLL term with a scaled KL penalty.

    The KL value is multiplied by `beta` and divided by the number of trainable
    parameters of the model given at construction time. `train_size` is stored
    on the instance but is not used by `forward`.
    """

    def __init__(self, model, train_size, beta):
        """Record the trainable-parameter count of `model`, `beta` and `train_size`."""
        super().__init__()
        trainable_sizes = (
            param.numel() for param in model.parameters() if param.requires_grad
        )
        self.num_params = sum(trainable_sizes)
        self.beta = beta
        self.train_size = train_size

    def forward(self, outputs, targets, kl):
        """Return `nll_loss(outputs, targets) + beta * kl / num_params`.

        `outputs` is passed straight to `F.nll_loss` (mean reduction), so it is
        expected to hold log-probabilities. `targets` must not require grad.
        """
        assert not targets.requires_grad
        data_term = F.nll_loss(outputs, targets, reduction="mean")
        kl_term = self.beta * kl / self.num_params
        return data_term + kl_term

In [None]:
def predict(model, dataloader, single_head, task_id, T=10):
    """Estimate classification accuracy from `T` averaged forward passes.

    For every batch the model is called `T` times, the per-pass class
    probabilities are averaged (in log space), and the arg-max of that average
    is compared with the targets.

    Args:
        model: Callable as `model(inputs, task_id)`. When `single_head` is
            false it must expose `classifiers[0].out_features`.
        dataloader: Iterable yielding `(inputs, targets)` batches.
        single_head: If true, targets are used as-is; otherwise they are
            shifted down by `task_id * classifiers[0].out_features`.
        task_id: Index of the task, forwarded to the model.
        T: Number of forward passes per batch (must be at least 1).

    Returns:
        float: Correct predictions divided by the number of samples, or NaN
        when the dataloader yields no samples.
    """
    if single_head:
        offset = 0
    else:
        offset = task_id * model.classifiers[0].out_features

    # NOTE(review): train mode is kept from the original implementation —
    # presumably so the model keeps sampling during the T passes; confirm
    # against the model definition.
    model.train()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            # Out-of-place shift: never mutate tensors owned by the caller.
            targets = targets.to(device) - offset

            log_probs = torch.stack(
                [F.log_softmax(model(inputs, task_id), dim=-1) for _ in range(T)],
                dim=-1,
            )
            # log of the mean probability over the T passes
            log_output = torch.logsumexp(log_probs, dim=-1) - np.log(T)

            # Count samples instead of averaging per-batch accuracies, so a
            # short final batch is not over-weighted.
            correct += (log_output.argmax(dim=-1) == targets).sum().item()
            total += targets.numel()

    return correct / total if total else float("nan")

In [None]:
def train(
    model, num_epochs, dataloader, single_head, task_id, beta, T=10, replay=False
):
    """Train `model` with SGD on an ELBO-style loss built from `T` forward passes.

    For every batch the model is called `T` times; the per-pass class
    probabilities are averaged in log space and fed, together with
    `model.get_kl(task_id)`, to the `ELBO` loss.

    Args:
        model: Callable as `model(inputs, task_id)` and exposing
            `get_kl(task_id)`. When `single_head` is false it must also expose
            `classifiers[0].out_features`.
        num_epochs: Number of passes over `dataloader`.
        dataloader: Yields `(inputs, targets)` batches. Its `dataset`
            (single-head) or `sampler` (multi-head) gives the training-set size.
        single_head: If true, targets are used as-is; otherwise they are
            shifted down by `task_id * classifiers[0].out_features`.
        task_id: Index of the task, forwarded to the model.
        beta: Weight of the KL term; forced to 0 when `replay` is true.
        T: Number of forward passes per batch (must be at least 1).
        replay: If true, the KL term is disabled.

    Returns:
        None. The model is updated in place and the epoch index is printed.
    """
    beta = 0 if replay else beta
    lr_start = 1e-3

    if single_head:
        offset = 0
        train_size = len(dataloader.dataset)
    else:
        offset = task_id * model.classifiers[0].out_features
        # Accept tensor, array or plain-list `indices`, and fall back to the
        # sampler's own length when it has no `indices` attribute.
        sampler = dataloader.sampler
        indices = getattr(sampler, "indices", None)
        train_size = len(indices) if indices is not None else len(sampler)

    elbo = ELBO(model, train_size, beta)
    optimizer = SGD(model.parameters(), lr=lr_start, momentum=0.9)

    model.train()
    for epoch in range(num_epochs):
        print(f"Epoch {epoch}")
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            inputs = inputs.to(device)
            # Out-of-place shift: never mutate tensors owned by the caller.
            targets = targets.to(device) - offset

            log_probs = torch.stack(
                [F.log_softmax(model(inputs, task_id), dim=-1) for _ in range(T)],
                dim=-1,
            )
            # log of the mean probability over the T passes
            log_output = torch.logsumexp(log_probs, dim=-1) - np.log(T)

            kl = model.get_kl(task_id)
            loss = elbo(log_output, targets, kl)
            loss.backward()
            optimizer.step()

In [None]:
# NOTE(review): wildcard import placed after the notebook's own `ELBO`,
# `predict` and `train` definitions — any same-named export from models.bbb
# would silently replace them. Confirm, or import the needed names explicitly.
from models.bbb import *

# Model trained and evaluated by the `train` / `predict` helpers defined above.
# NOTE(review): `PermutedModel` is presumably provided by one of the wildcard
# imports — confirm its origin.
model_perm = PermutedModel().to(device)

In [None]:
# Train and evaluate `model_perm` on the first task only.
task0_train, task0_test = tasks[0]

train_loader = DataLoader(task0_train, batch_size=128, shuffle=True)
test_loader = DataLoader(task0_test, batch_size=128, shuffle=False)


# Positional arguments: num_epochs=10, single_head=True, task_id=0, beta=1.
train(model_perm, 10, train_loader, True, 0, 1)

# Positional arguments: single_head=True, task_id=0 (T keeps its default of 10).
accuracy = predict(model_perm, test_loader, True, 0)
print(accuracy)

In [None]:
def train_final(model, epochs, loader, lr=0.01):
    """Train `model` with SGD (momentum 0.9) by minimising `kl - log-likelihood`.

    For every batch the loss is `model.kl() / len(loader.dataset)` minus
    `model.logpred(inputs, targets)`. After each epoch the per-batch averages
    of the loss, the scaled KL term and the log-likelihood are printed.

    Args:
        model: Module exposing `kl()` and `logpred(inputs, targets)`.
        epochs: Number of passes over `loader`.
        loader: Yields `(inputs, targets)` batches; must support `len()` and
            expose a sized `dataset`.
        lr: SGD learning rate.

    Returns:
        None. The model is updated in place.
    """
    sgd = SGD(model.parameters(), lr=lr, momentum=0.9)

    model.train()
    for epoch in range(epochs):
        # Running sums of the per-batch scalars for this epoch.
        running = {"loss": 0, "kl": 0, "ll": 0}

        for batch_inputs, batch_targets in loader:
            sgd.zero_grad()
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # The KL term is scaled by the dataset size on every batch.
            kl_term = model.kl() / len(loader.dataset)
            log_lik = model.logpred(batch_inputs, batch_targets)
            objective = kl_term - log_lik

            running["loss"] += objective.item()
            running["kl"] += kl_term.item()
            running["ll"] += log_lik.item()

            objective.backward()
            sgd.step()

        num_batches = len(loader)
        mean_loss = running["loss"] / num_batches
        mean_kl = running["kl"] / num_batches
        mean_ll = running["ll"] / num_batches

        print(f"Epoch {epoch} Loss: {mean_loss} KL: {mean_kl} LL: {mean_ll}")

In [None]:
def test_final(model, loader):
    model.eval()
    accs = []
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        accs.append(calculate_accuracy(outputs, targets))

    return np.mean(accs)

In [None]:
# NOTE(review): wildcard import placed after the notebook's own helper
# definitions (`train_final`, `test_final`, `calculate_accuracy`, ...) — any
# same-named export from models.mfvi would silently replace them. Confirm, or
# import the needed names explicitly.
from models.mfvi import *

# NOTE(review): MFVI arguments are presumably (input size, hidden sizes,
# classes) — confirm against models.mfvi.
model_final = MFVI(28 * 28, [100, 100], 10).to(device)

# First task only: (train dataset, test dataset).
task0_train, task0_test = tasks[0]

train_loader = DataLoader(task0_train, batch_size=128, shuffle=True)
test_loader = DataLoader(task0_test, batch_size=128, shuffle=False)

# 10 epochs with train_final's default learning rate.
train_final(model_final, 10, train_loader)

accuracy = test_final(model_final, test_loader)
print(accuracy)