In [10]:
from models import VCLModel, VanillaMLP
from models.kcoreset import KClusteringCoreset
from datasets import PermutedMnist

import matplotlib.pyplot as plt
import numpy as np

import os

import torch
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
NUM_EPOCHS = 50
LR = 1e-3
BATCH_SIZE = 256
TRAIN_NUM_SAMPLES = 10
TEST_NUM_SAMPLES = 100
CORESET_SIZE = 200
INIT_MLP = False

MLP_INIT_DESC = 'withoutmlpinit' if not INIT_MLP else 'withmlpinit'

random_seed = 1
torch.manual_seed(random_seed + 1)
np.random.seed(random_seed)

model_base_name = 'vcl_clusters_{}_lr_{}_{}_batch_{}_coresetsize_{}_epochs_{}_run_{}_no_tasks_{}'

In [15]:
no_tasks = 10

dataset = PermutedMnist(no_tasks)

In [16]:
NO_RUNS = 1

train_loaders, test_loaders = dataset.get_tasks(batch_size=BATCH_SIZE)
no_tasks = dataset.no_tasks

# Modify as need be.
epochs = [50]
coreset_sizes = [1000, 5000]
methods = ['kmeans']

for method in methods:
    METHOD = method
    
    clustering_coreset_selector = KClusteringCoreset(dataset)
    coreset_loaders, sans_coreset_loaders = clustering_coreset_selector.get_coreset_loaders(batch_size=BATCH_SIZE, coreset_size=CORESET_SIZE, method=METHOD) 
    
    for coreset_size in coreset_sizes:
        CORESET_SIZE = coreset_size
        
        for epoch in epochs:
            NUM_EPOCHS = epoch
            
            for run_number in range(NO_RUNS):
                model_name = model_base_name.format(METHOD, LR, MLP_INIT_DESC, BATCH_SIZE, CORESET_SIZE, NUM_EPOCHS, run_number, no_tasks) 
               
                LOG_DIR = 'logs/k_clusters/'
                os.makedirs(LOG_DIR, exist_ok=True)
                
                LOG_FILE_NAME = os.path.join(LOG_DIR, model_name) + '.txt'
                log_file_handler = open(LOG_FILE_NAME, "w")
                print("Run Number: {}. Log File: {}".format(run_number, LOG_FILE_NAME))
                
                train_loaders = sans_coreset_loaders
                
                model = VCLModel(784, 10, [100, 100])
                accuracies = np.zeros((no_tasks, no_tasks))
                
                for i in range(no_tasks):
                    train_loader = train_loaders[i]
                    
                    if INIT_MLP and i == 0:
                        print("Training MLP model to init first task")
                        mlp = VanillaMLP(784, 10, [100, 100])
                        mlp.train_model(NUM_EPOCHS, train_loader, LR)
                        model.init_mle(mlp)
                    
                    model.train_model(NUM_EPOCHS, train_loader, LR, TRAIN_NUM_SAMPLES)
                    
                    model.update_priors()
                    
                    # Train on coreset after training on non-coreset.
                    if CORESET_SIZE > 0:
                        coreset_loader = coreset_loaders[i]
                        model.train_model(NUM_EPOCHS, coreset_loader, LR, TRAIN_NUM_SAMPLES)
                        
                    task_accs = []
                    for j in range(i + 1):
                        test_loader = test_loaders[j]
                        accuracy = model.get_accuracy(test_loader, 100)
                        
                        msg = "[Number Tasks: {}]: Task {} Accuracy: {}".format(j + 1, j + 1, accuracy ) 
                        log_file_handler.write(msg + '\n')
                        log_file_handler.flush()
                        print(msg)
                        
                        task_accs.append(accuracy)
                        accuracies[i][j] = accuracy
                    
                    msg = "Task {} / {}. Mean Accuracy: {}".format(i + 1, no_tasks, np.mean(task_accs))
                    log_file_handler.write(msg + '\n')
                    log_file_handler.flush()
                    print(msg)
                log_file_handler.close()

  0%|          | 0/10 [00:00<?, ?it/s]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 10%|█         | 1/10 [00:42<06:26, 42.95s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 20%|██        | 2/10 [01:16<04:57, 37.21s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 30%|███       | 3/10 [01:49<04:07, 35.33s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 40%|████      | 4/10 [02:23<03:28, 34.82s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 50%|█████     | 5/10 [02:56<02:50, 34.09s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 60%|██████    | 6/10 [03:31<02:17, 34.40s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 70%|███████   | 7/10 [04:05<01:43, 34.44s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 80%|████████  | 8/10 [04:39<01:08, 34.37s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


 90%|█████████ | 9/10 [05:14<00:34, 34.33s/it]

Processing class: 0
Processing class: 1
Processing class: 2
Processing class: 3
Processing class: 4
Processing class: 5
Processing class: 6
Processing class: 7
Processing class: 8
Processing class: 9


100%|██████████| 10/10 [05:47<00:00, 34.75s/it]


Run Number: 0. Log File: logs/k_clusters/vcl_clusters_kmeans_lr_0.001_withoutmlpinit_batch_256_coresetsize_1000_epochs_50_run_0_no_tasks_10.txt


Total Loss: 0.9106097733363127, KL: 0.8213848869005839, Lik Loss: 0.08922488543276604: 100%|██████████| 50/50 [04:59<00:00,  5.99s/it]
Total Loss: 0.0008123741135932505, KL: 0.0005741327768191695, Lik Loss: 0.000238241336774081: 100%|██████████| 50/50 [00:13<00:00,  3.70it/s]   


[Number Tasks: 1]: Task 1 Accuracy: 0.9804
Task 1 / 10. Mean Accuracy: 0.9804


Total Loss: 0.2886152234342363, KL: 0.12804170436838752, Lik Loss: 0.1605735179196056: 100%|██████████| 50/50 [05:00<00:00,  6.01s/it]  
Total Loss: 0.003928313031792641, KL: 0.0018627461977303028, Lik Loss: 0.002065566717647016: 100%|██████████| 50/50 [00:16<00:00,  2.96it/s]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9695
[Number Tasks: 2]: Task 2 Accuracy: 0.9652
Task 2 / 10. Mean Accuracy: 0.9673499999999999


Total Loss: 0.2761793969533382, KL: 0.11980870264208215, Lik Loss: 0.15637069437493625: 100%|██████████| 50/50 [05:13<00:00,  6.26s/it] 
Total Loss: 0.005493680636088054, KL: 0.0025933715514838696, Lik Loss: 0.0029003091622143984: 100%|██████████| 50/50 [00:17<00:00,  2.82it/s]


[Number Tasks: 1]: Task 1 Accuracy: 0.9658
[Number Tasks: 2]: Task 2 Accuracy: 0.9611
[Number Tasks: 3]: Task 3 Accuracy: 0.9665
Task 3 / 10. Mean Accuracy: 0.9644666666666666


Total Loss: 0.2911843461358649, KL: 0.13084696252376604, Lik Loss: 0.16033738415338036: 100%|██████████| 50/50 [04:48<00:00,  5.77s/it] 
Total Loss: 0.009040696313604712, KL: 0.004285736242309213, Lik Loss: 0.004754960187710822: 100%|██████████| 50/50 [00:16<00:00,  3.02it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9645
[Number Tasks: 2]: Task 2 Accuracy: 0.9542
[Number Tasks: 3]: Task 3 Accuracy: 0.9621
[Number Tasks: 4]: Task 4 Accuracy: 0.9651
Task 4 / 10. Mean Accuracy: 0.961475


Total Loss: 0.2801150904379339, KL: 0.12402773597556302, Lik Loss: 0.1560873545578912: 100%|██████████| 50/50 [04:20<00:00,  5.20s/it]  
Total Loss: 0.010822712443768978, KL: 0.005284530343487859, Lik Loss: 0.0055381819838657975: 100%|██████████| 50/50 [00:16<00:00,  2.97it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9602
[Number Tasks: 2]: Task 2 Accuracy: 0.9503
[Number Tasks: 3]: Task 3 Accuracy: 0.9616
[Number Tasks: 4]: Task 4 Accuracy: 0.9625
[Number Tasks: 5]: Task 5 Accuracy: 0.9644
Task 5 / 10. Mean Accuracy: 0.9598000000000001


Total Loss: 0.28456659907968634, KL: 0.12703625112771988, Lik Loss: 0.15753034782460612: 100%|██████████| 50/50 [04:17<00:00,  5.15s/it]
Total Loss: 0.012661109678447247, KL: 0.005886852834373712, Lik Loss: 0.006774257030338049: 100%|██████████| 50/50 [00:18<00:00,  2.71it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9547
[Number Tasks: 2]: Task 2 Accuracy: 0.9442
[Number Tasks: 3]: Task 3 Accuracy: 0.9547
[Number Tasks: 4]: Task 4 Accuracy: 0.9573
[Number Tasks: 5]: Task 5 Accuracy: 0.9625
[Number Tasks: 6]: Task 6 Accuracy: 0.9622
Task 6 / 10. Mean Accuracy: 0.9559333333333334


Total Loss: 0.2851472428848601, KL: 0.12624977166071916, Lik Loss: 0.1588974715743819: 100%|██████████| 50/50 [04:17<00:00,  5.15s/it]  
Total Loss: 0.014337186235934496, KL: 0.007118773569042484, Lik Loss: 0.007218412666892012: 100%|██████████| 50/50 [00:19<00:00,  2.53it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.955
[Number Tasks: 2]: Task 2 Accuracy: 0.9377
[Number Tasks: 3]: Task 3 Accuracy: 0.9482
[Number Tasks: 4]: Task 4 Accuracy: 0.9539
[Number Tasks: 5]: Task 5 Accuracy: 0.9581
[Number Tasks: 6]: Task 6 Accuracy: 0.9571
[Number Tasks: 7]: Task 7 Accuracy: 0.9616
Task 7 / 10. Mean Accuracy: 0.9530857142857142


Total Loss: 0.2891560975685079, KL: 0.1285900578666956, Lik Loss: 0.1605660396699722: 100%|██████████| 50/50 [04:15<00:00,  5.11s/it]   
Total Loss: 0.01570473052561283, KL: 0.008239557700497764, Lik Loss: 0.007465172625545945: 100%|██████████| 50/50 [00:20<00:00,  2.42it/s]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9524
[Number Tasks: 2]: Task 2 Accuracy: 0.9382
[Number Tasks: 3]: Task 3 Accuracy: 0.9463
[Number Tasks: 4]: Task 4 Accuracy: 0.9514
[Number Tasks: 5]: Task 5 Accuracy: 0.9557
[Number Tasks: 6]: Task 6 Accuracy: 0.9532
[Number Tasks: 7]: Task 7 Accuracy: 0.9602
[Number Tasks: 8]: Task 8 Accuracy: 0.9615
Task 8 / 10. Mean Accuracy: 0.9523625


Total Loss: 0.3023987409268689, KL: 0.13490087755470195, Lik Loss: 0.16749786340400705: 100%|██████████| 50/50 [04:21<00:00,  5.24s/it] 
Total Loss: 0.023137051379308105, KL: 0.00988178770057857, Lik Loss: 0.01325526344589889: 100%|██████████| 50/50 [00:21<00:00,  2.31it/s]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9476
[Number Tasks: 2]: Task 2 Accuracy: 0.9306
[Number Tasks: 3]: Task 3 Accuracy: 0.9415
[Number Tasks: 4]: Task 4 Accuracy: 0.9468
[Number Tasks: 5]: Task 5 Accuracy: 0.9501
[Number Tasks: 6]: Task 6 Accuracy: 0.9517
[Number Tasks: 7]: Task 7 Accuracy: 0.9547
[Number Tasks: 8]: Task 8 Accuracy: 0.96
[Number Tasks: 9]: Task 9 Accuracy: 0.9617
Task 9 / 10. Mean Accuracy: 0.9494111111111112


Total Loss: 0.31030034866088474, KL: 0.13952010456058714, Lik Loss: 0.17078024368637648: 100%|██████████| 50/50 [04:42<00:00,  5.66s/it]
Total Loss: 0.02163622062653303, KL: 0.011022720136679709, Lik Loss: 0.010613500606268644: 100%|██████████| 50/50 [00:23<00:00,  2.15it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9448
[Number Tasks: 2]: Task 2 Accuracy: 0.9276
[Number Tasks: 3]: Task 3 Accuracy: 0.9393
[Number Tasks: 4]: Task 4 Accuracy: 0.9428
[Number Tasks: 5]: Task 5 Accuracy: 0.9472
[Number Tasks: 6]: Task 6 Accuracy: 0.9461
[Number Tasks: 7]: Task 7 Accuracy: 0.9495
[Number Tasks: 8]: Task 8 Accuracy: 0.9552
[Number Tasks: 9]: Task 9 Accuracy: 0.9583
[Number Tasks: 10]: Task 10 Accuracy: 0.9599
Task 10 / 10. Mean Accuracy: 0.9470699999999999
Run Number: 0. Log File: logs/k_clusters/vcl_clusters_kmeans_lr_0.001_withoutmlpinit_batch_256_coresetsize_5000_epochs_50_run_0_no_tasks_10.txt


Total Loss: 0.9270972505084469, KL: 0.8374140573363019, Lik Loss: 0.08968319341094576: 100%|██████████| 50/50 [05:12<00:00,  6.25s/it]
Total Loss: 0.0008808227139525115, KL: 0.0006382733699865639, Lik Loss: 0.0002425493294140324: 100%|██████████| 50/50 [00:13<00:00,  3.67it/s]  


[Number Tasks: 1]: Task 1 Accuracy: 0.982
Task 1 / 10. Mean Accuracy: 0.982


Total Loss: 0.298520256120425, KL: 0.1291269316123082, Lik Loss: 0.16939332431707627: 100%|██████████| 50/50 [04:52<00:00,  5.86s/it]   
Total Loss: 0.004713722039014101, KL: 0.0021947257919237018, Lik Loss: 0.00251899630529806: 100%|██████████| 50/50 [00:15<00:00,  3.20it/s]   


[Number Tasks: 1]: Task 1 Accuracy: 0.9743
[Number Tasks: 2]: Task 2 Accuracy: 0.9628
Task 2 / 10. Mean Accuracy: 0.96855


Total Loss: 0.27830978291921127, KL: 0.1210331233839194, Lik Loss: 0.15727665966265222: 100%|██████████| 50/50 [05:01<00:00,  6.03s/it] 
Total Loss: 0.00671516094977657, KL: 0.0032301320073505244, Lik Loss: 0.00348502902003626: 100%|██████████| 50/50 [00:17<00:00,  2.88it/s]    


[Number Tasks: 1]: Task 1 Accuracy: 0.9705
[Number Tasks: 2]: Task 2 Accuracy: 0.9593
[Number Tasks: 3]: Task 3 Accuracy: 0.9656
Task 3 / 10. Mean Accuracy: 0.9651333333333335


Total Loss: 0.2804910606171331, KL: 0.1225872685193506, Lik Loss: 0.15790379168386134: 100%|██████████| 50/50 [05:04<00:00,  6.08s/it]  
Total Loss: 0.01044404273852706, KL: 0.004044418688863516, Lik Loss: 0.0063996238168329: 100%|██████████| 50/50 [00:18<00:00,  2.72it/s]    


[Number Tasks: 1]: Task 1 Accuracy: 0.967
[Number Tasks: 2]: Task 2 Accuracy: 0.9534
[Number Tasks: 3]: Task 3 Accuracy: 0.9619
[Number Tasks: 4]: Task 4 Accuracy: 0.9657
Task 4 / 10. Mean Accuracy: 0.962


Total Loss: 0.27365835851583725, KL: 0.11965774958077659, Lik Loss: 0.15400060982658312: 100%|██████████| 50/50 [05:09<00:00,  6.19s/it]
Total Loss: 0.01104302378371358, KL: 0.0054869180312380195, Lik Loss: 0.005556105636060238: 100%|██████████| 50/50 [00:19<00:00,  2.56it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9619
[Number Tasks: 2]: Task 2 Accuracy: 0.947
[Number Tasks: 3]: Task 3 Accuracy: 0.9567
[Number Tasks: 4]: Task 4 Accuracy: 0.9615
[Number Tasks: 5]: Task 5 Accuracy: 0.9636
Task 5 / 10. Mean Accuracy: 0.95814


Total Loss: 0.28475108195064414, KL: 0.12683621647520962, Lik Loss: 0.15791486531623408: 100%|██████████| 50/50 [05:16<00:00,  6.34s/it]
Total Loss: 0.014714601449668407, KL: 0.006738590635359287, Lik Loss: 0.00797601081430912: 100%|██████████| 50/50 [00:21<00:00,  2.38it/s]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9573
[Number Tasks: 2]: Task 2 Accuracy: 0.9419
[Number Tasks: 3]: Task 3 Accuracy: 0.951
[Number Tasks: 4]: Task 4 Accuracy: 0.9558
[Number Tasks: 5]: Task 5 Accuracy: 0.9585
[Number Tasks: 6]: Task 6 Accuracy: 0.9636
Task 6 / 10. Mean Accuracy: 0.9546833333333332


Total Loss: 0.2839383624303035, KL: 0.12672890971104303, Lik Loss: 0.15720945249637988: 100%|██████████| 50/50 [04:55<00:00,  5.92s/it] 
Total Loss: 0.016679708225031693, KL: 0.007994025169561306, Lik Loss: 0.008685683133080602: 100%|██████████| 50/50 [00:20<00:00,  2.40it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9548
[Number Tasks: 2]: Task 2 Accuracy: 0.9345
[Number Tasks: 3]: Task 3 Accuracy: 0.9457
[Number Tasks: 4]: Task 4 Accuracy: 0.9526
[Number Tasks: 5]: Task 5 Accuracy: 0.9557
[Number Tasks: 6]: Task 6 Accuracy: 0.9599
[Number Tasks: 7]: Task 7 Accuracy: 0.9632
Task 7 / 10. Mean Accuracy: 0.952342857142857


Total Loss: 0.3006718086126523, KL: 0.1326124561138642, Lik Loss: 0.16805935176646608: 100%|██████████| 50/50 [04:59<00:00,  5.98s/it]  
Total Loss: 0.018881171941757202, KL: 0.009622908596481596, Lik Loss: 0.009258263278752565: 100%|██████████| 50/50 [00:22<00:00,  2.19it/s]


[Number Tasks: 1]: Task 1 Accuracy: 0.9515
[Number Tasks: 2]: Task 2 Accuracy: 0.9347
[Number Tasks: 3]: Task 3 Accuracy: 0.9404
[Number Tasks: 4]: Task 4 Accuracy: 0.9433
[Number Tasks: 5]: Task 5 Accuracy: 0.9518
[Number Tasks: 6]: Task 6 Accuracy: 0.9539
[Number Tasks: 7]: Task 7 Accuracy: 0.958
[Number Tasks: 8]: Task 8 Accuracy: 0.9586
Task 8 / 10. Mean Accuracy: 0.949025


Total Loss: 0.3108655105416591, KL: 0.13719172329984158, Lik Loss: 0.17367378717813736: 100%|██████████| 50/50 [05:05<00:00,  6.12s/it] 
Total Loss: 0.023140120087191463, KL: 0.011127818375825882, Lik Loss: 0.012012301594950259: 100%|██████████| 50/50 [00:23<00:00,  2.13it/s]


[Number Tasks: 1]: Task 1 Accuracy: 0.9488
[Number Tasks: 2]: Task 2 Accuracy: 0.9263
[Number Tasks: 3]: Task 3 Accuracy: 0.9353
[Number Tasks: 4]: Task 4 Accuracy: 0.9394
[Number Tasks: 5]: Task 5 Accuracy: 0.947
[Number Tasks: 6]: Task 6 Accuracy: 0.9502
[Number Tasks: 7]: Task 7 Accuracy: 0.9512
[Number Tasks: 8]: Task 8 Accuracy: 0.9556
[Number Tasks: 9]: Task 9 Accuracy: 0.9575
Task 9 / 10. Mean Accuracy: 0.9457


Total Loss: 0.30728894841467214, KL: 0.1350776156426495, Lik Loss: 0.1722113328675429: 100%|██████████| 50/50 [04:47<00:00,  5.75s/it]  
Total Loss: 0.026060479460284114, KL: 0.013166406308300793, Lik Loss: 0.01289407315198332: 100%|██████████| 50/50 [00:21<00:00,  2.36it/s] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9475
[Number Tasks: 2]: Task 2 Accuracy: 0.9231
[Number Tasks: 3]: Task 3 Accuracy: 0.9332
[Number Tasks: 4]: Task 4 Accuracy: 0.9364
[Number Tasks: 5]: Task 5 Accuracy: 0.9465
[Number Tasks: 6]: Task 6 Accuracy: 0.9468
[Number Tasks: 7]: Task 7 Accuracy: 0.9487
[Number Tasks: 8]: Task 8 Accuracy: 0.9508
[Number Tasks: 9]: Task 9 Accuracy: 0.9547
[Number Tasks: 10]: Task 10 Accuracy: 0.9585
Task 10 / 10. Mean Accuracy: 0.9446200000000001


In [4]:
kcc =KClusteringCoreset(dataset)
kcc.get_coreset_datasets(200, method='kcenter_greedy')

Computing Cluster Coresets...:  20%|██        | 2/10 [00:49<03:16, 24.59s/it]


KeyboardInterrupt: 