In [1]:
import random
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torch import Tensor
from torch.utils.data import DataLoader
from model import (QuadraticDendriticNet, LinearNeuralNet, QuadraticNeuralNet)

D = 28 * 28  # input dimension: one flattened 28x28 MNIST image
K = 10  # number of classes (MNIST digits)
features = (D, 2048, K)  # presumably layer widths for the model classes (input, hidden, output) — confirm
num_tasks = 20  # number of permuted-MNIST tasks built by prepare_dataset
num_epochs = 10  # not used in the cells shown
batch_size = 6000  # samples per DataLoader batch (train and test)
learning_rate = 0.01  # not used in the cells shown

# Seed every RNG the notebook relies on (torch.randperm for the task
# permutations, DataLoader shuffling, random.shuffle of the batch list) so a
# Restart & Run All reproduces the same permutations and cluster assignments.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

def prepare_dataset():
    """Build the permuted-MNIST task suite.

    Task 0 keeps the original pixel order. Every later task applies a fresh
    random pixel permutation drawn with ``torch.randperm`` (global torch RNG).

    Returns:
        train_loader: dict task -> DataLoader over that task's permuted
            training set (shuffled).
        test_loader: dict task -> DataLoader over that task's permuted test
            set (same permutation as its training set, not shuffled).
        prototype: dict task -> float tensor of shape (D,), the mean of the
            task's permuted ``.data`` pixels. ``.data`` holds the stored pixel
            values; the ToTensor/Normalize transform is applied only when
            samples are fetched, not to ``.data``.
    """
    train_loader, test_loader, prototype = {}, {}, {}

    # The transform is identical for every task, so build it once.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    permute_idx = list(range(D))  # task 0: identity permutation
    for task in range(num_tasks):
        # Fresh dataset objects per task, because `.data` is overwritten below.
        train_set = torchvision.datasets.MNIST(
            root='', train=True, download=True, transform=transform
        )
        test_set = torchvision.datasets.MNIST(
            root='', train=False, download=True, transform=transform
        )

        # Flatten every image to D pixels and reorder them by this task's
        # permutation. Samples are 1-D afterwards, so consumers reshape
        # batches to (-1, D) themselves.
        train_set.data = train_set.data.reshape(-1, D)[:, permute_idx]
        test_set.data = test_set.data.reshape(-1, D)[:, permute_idx]

        # Context vector of the task: mean pixel vector of its training data.
        prototype[task] = train_set.data.float().mean(dim=0)

        train_loader[task] = DataLoader(
            dataset=train_set, batch_size=batch_size, shuffle=True
        )
        test_loader[task] = DataLoader(
            dataset=test_set, batch_size=batch_size, shuffle=False
        )

        # Draw the permutation used by the next task.
        permute_idx = torch.randperm(D).tolist()

    return train_loader, test_loader, prototype

def test_all_task(model, test_loader):
    with torch.no_grad():
        for task in range(num_tasks):
            correct = 0
            total = 0
            for images, labels in test_loader[task]:
                images = images.reshape(-1, 28*28)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            print('Task [{}/{}], Accuracy: {} %'
                  .format(task+1, num_tasks, 100 * correct / total))

# Build every task's train/test loaders once, together with each task's mean
# training pixel vector (kept as the per-task context vector).
(train_loader, test_loader, context_vector) = prepare_dataset()

In [12]:
def is_match(x, y, threshold=0.99):  # N_x * D, N_y * D
    """Decide whether two sample sets plausibly share the same mean.

    Computes the two-sample Hotelling's T^2 statistic with a pooled covariance
    (inverted with a pseudo-inverse, so singular covariances are tolerated)
    and converts it to the F statistic

        F = T^2 * (N_x + N_y - D - 1) / (D * (N_x + N_y - 2)).

    Args:
        x: (N_x, D) float tensor of samples.
        y: (N_y, D) float tensor of samples with the same D.
        threshold: the two sets "match" when F <= threshold.
            NOTE(review): this compares the raw F statistic, not a p-value or
            an F-distribution critical value; the 0.99 default reads like a
            confidence level — confirm the intended semantics.

    Returns:
        bool: True when the F statistic is at most ``threshold``.

    Raises:
        ValueError: if N_x + N_y <= D + 1. There the F numerator factor is
            <= 0 (or the pooled covariance divides by zero), so F <= 0 would
            otherwise always report a spurious match.
    """
    x_n, d = x.shape  # N_x, D
    y_n, _ = y.shape  # N_y
    if x_n + y_n <= d + 1:
        raise ValueError(
            'is_match needs N_x + N_y > D + 1, got N_x={}, N_y={}, D={}'
            .format(x_n, y_n, d)
        )

    mean_x = torch.mean(x, dim=0, keepdim=True)  # 1 * D
    mean_y = torch.mean(y, dim=0, keepdim=True)  # 1 * D
    mean_diff = mean_x - mean_y  # 1 * D
    centered_x = x - mean_x  # N_x * D
    centered_y = y - mean_y  # N_y * D

    # Unnormalised scatter matrices, pooled with N_x + N_y - 2 dof.
    cov_x = torch.t(centered_x) @ centered_x  # D * D
    cov_y = torch.t(centered_y) @ centered_y  # D * D
    pooled_cov = (cov_x + cov_y) / (x_n + y_n - 2)  # D * D

    t2 = mean_diff @ torch.pinverse(pooled_cov) @ torch.t(mean_diff)  # 1 * 1
    t2 = (x_n * y_n * t2) / (x_n + y_n)
    f = (t2 * (x_n + y_n - d - 1)) / (d * (x_n + y_n - 2))

    return bool((f <= threshold).item())


def cluster(x, y_t, p, match_fn=None) -> int:  # N_x * D, T * N_y * D, T
    """Assign a batch to an existing cluster, or open a new one.

    Clusters are tried from the most recently created id down to 0. The batch
    joins the first cluster whose stored samples ``match_fn`` accepts. That
    cluster's samples and prototype are then updated. If none match, a new
    cluster is created from the batch alone.

    Args:
        x: (N_x, D) tensor, the incoming batch.
        y_t: dict cluster id -> (N, D) tensor of every sample assigned so far.
            Mutated in place.
        p: dict cluster id -> (D,) prototype (mean sample). Mutated in place.
            Ids are assumed to be 0..len(p)-1, with the same keys as ``y_t``.
        match_fn: callable ``(x, samples) -> truthy``; defaults to ``is_match``.

    Returns:
        int: the 0-based id of the cluster the batch was assigned to. The same
        id space is used for both matched and newly created clusters.
    """
    if match_fn is None:
        match_fn = is_match

    for task in reversed(range(len(p))):
        if match_fn(x, y_t[task]):
            y_t[task] = torch.cat((x, y_t[task]))  # (N_x + N_y) * D
            p[task] = torch.mean(y_t[task], dim=0)
            return task

    # No match: open a new cluster. Capture its id before inserting, because
    # len(p) grows by one as soon as p is updated.
    new_task = len(p)
    y_t[new_task] = x
    p[new_task] = torch.mean(x, dim=0)
    return new_task

In [13]:
# Feed every task's training batches, task after task, into the online
# clusterer, and report how many clusters exist once each task is consumed.
#   y_t: cluster id -> every sample assigned to that cluster, (N, D) tensor
#   p:   cluster id -> prototype (mean sample) of that cluster, (D,) tensor
y_t, p = {}, {}

for task in range(num_tasks):
    for step, (x, _) in enumerate(train_loader[task]):
        cluster(x.reshape(-1, D), y_t, p)
    print("task={}".format(task + 1), len(p), sep="\n")

task=1
1
task=2
2
task=3
3
task=4
4
task=5
5
task=6
6
task=7
8
task=8
10
task=9
12
task=10
13
task=11
14
task=12
15
task=13
16
task=14
17
task=15
18
task=16
19
task=17
20
task=18
21
task=19
23
task=20
24


In [None]:
# Pool every (task id, batch) pair across all tasks, shuffle them, and feed
# them to the clusterer in random order. Record which cluster ids each true
# task's batches are assigned to.
all_set = [
    (task, (x, labels))
    for task in range(num_tasks)
    for x, labels in train_loader[task]
]
random.shuffle(all_set)

y_t = {}  # cluster id -> every sample assigned to that cluster, (N, D) tensor
p = {}  # cluster id -> prototype (mean sample), (D,) tensor

# True task id -> cluster ids its batches were assigned to, in feed order.
# Deliberately not named `map`, which would shadow the builtin.
task_to_clusters = {task: [] for task in range(num_tasks)}

already = set()  # true task ids fed so far
for task_idx, (x, _) in all_set:
    predict_task = cluster(x.reshape(-1, D), y_t, p)
    already.add(task_idx)
    # progress line: clusters so far : distinct true tasks seen so far
    print('{} : {}'.format(len(p), len(already)))

    task_to_clusters[task_idx].append(predict_task)
print(task_to_clusters)