In [10]:
from models.kcoreset import KClusteringCoreset
from models.vanilla_mlp import VanillaMLP
from models.vcl import VCLModel
from datasets import PermutedMnist
import matplotlib.pyplot as plt

import numpy as np

import os

import torch
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# --- Experiment configuration ---------------------------------------------
# Optimisation hyper-parameters shared by every run.
num_epochs, lr = 50, 1e-3

# NOTE(review): num_samples is not referenced in this cell — confirm it is used elsewhere.
num_samples = 10
NO_RUNS = 3  # independent training runs per (method, coreset size) pair

model_type = 'mlp'
# NOTE(review): shadowed by the `for method in methods` loop in the training cell.
method = 'kcenter_greedy'

# Reproducibility: numpy is seeded with random_seed, torch with random_seed + 1.
random_seed = 1
np.random.seed(random_seed)
torch.manual_seed(random_seed + 1)

# Positional fields: method, model_type, lr, batch_size, coreset_size, num_epochs, run_number.
model_base_name = ('clustering_coreset_only_{}_model_type_{}_lr_{}'
                   '_batch_{}_coresetsize_{}_epochs_{}_run_{}')

In [12]:
# The argument 10 is presumably the number of permuted-MNIST tasks — confirm in datasets.PermutedMnist.
dataset = PermutedMnist(10)
batch_size = 256  # batch size used when building the per-task test loaders

In [13]:
# Coreset-only continual-learning baseline: for each clustering method and
# coreset size, train a fresh MLP sequentially on each task's coreset and,
# after every task, evaluate on the test sets of all tasks seen so far.
clustering_coreset = KClusteringCoreset(dataset)
_, test_loaders = dataset.get_tasks(batch_size)

methods = ['kcenter_greedy', 'kmedians']
coreset_sizes = [5000]

# The log directory is the same for every run, so create it once up front.
LOG_DIR = 'logs/cluster_coreset_only'
os.makedirs(LOG_DIR, exist_ok=True)


def log_message(handle, msg):
    """Append ``msg`` as a line to the open file ``handle`` and echo it to stdout.

    The handle is flushed after every write so the log stays complete even if
    this long-running cell is interrupted.
    """
    handle.write(msg + '\n')
    handle.flush()
    print(msg)


for method in methods:
    for coreset_size in coreset_sizes:
        # batch_size is set equal to coreset_size — presumably so each task's
        # coreset is consumed as a single batch; confirm in KClusteringCoreset.
        coreset_loaders, _ = clustering_coreset.get_coreset_loaders(
            method=method, batch_size=coreset_size, coreset_size=coreset_size)
        no_tasks = len(coreset_loaders)
        # Evaluation below indexes test_loaders[j] for every j < no_tasks.
        assert len(test_loaders) >= no_tasks, (
            "expected at least {} test loaders, got {}".format(no_tasks, len(test_loaders)))

        for run_number in range(NO_RUNS):
            model = VanillaMLP(784, 10, [100, 100])

            # accuracies[i, j]: accuracy on task j after training through task i
            # (lower-triangular; entries with j > i stay 0).
            accuracies = np.zeros((no_tasks, no_tasks))

            model_name = model_base_name.format(
                method, model_type, lr, batch_size, coreset_size, num_epochs, run_number)
            LOG_FILE_NAME = os.path.join(LOG_DIR, model_name) + '.txt'

            # `with` guarantees the log is closed even on KeyboardInterrupt or an
            # exception during training; previously the handle was never closed.
            with open(LOG_FILE_NAME, "w") as log_file_handler:
                print("Run Number: {}. Log File: {}".format(run_number, LOG_FILE_NAME))

                for i in range(no_tasks):
                    train_loader = coreset_loaders[i]
                    model.train_model(num_epochs, train_loader, lr)

                    for j in range(i + 1):
                        accuracy = model.get_accuracy(test_loaders[j])
                        accuracies[i, j] = accuracy
                        # "Number Tasks" is how many tasks have been trained so far
                        # (i + 1), not the index of the task being evaluated.
                        log_message(log_file_handler,
                                    "[Number Tasks: {}]: Task {} Accuracy: {}".format(i + 1, j + 1, accuracy))

                    mean_accuracy = np.mean(accuracies[i, :i + 1])
                    log_message(log_file_handler,
                                "Task {} / {}. Mean Accuracy: {}".format(i + 1, no_tasks, mean_accuracy))

            # Persist the per-run accuracy matrix; it was otherwise discarded
            # when the next run re-initialised `accuracies`.
            np.save(os.path.join(LOG_DIR, model_name) + '_accuracies.npy', accuracies)

 10%|█         | 1/10 [22:59<3:26:59, 1379.92s/it]


KeyboardInterrupt: 