In [1]:
from models import VCLModel, VanillaMLP
from models.coreset import RandomCoreset
from datasets import PermutedMnist

import matplotlib.pyplot as plt
import numpy as np

import os

import torch
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

In [14]:
NUM_EPOCHS = 100
LR = 1e-3
TRAIN_NUM_SAMPLES = 10
TEST_NUM_SAMPLES = 100
CORESET_SIZE = 200
BATCH_SIZE = 256
INIT_MLP = False

MLP_INIT_DESC = 'withoutmlpinit' if not INIT_MLP else 'withmlpinit'

random_seed = 1
torch.manual_seed(random_seed + 1)
np.random.seed(random_seed)

model_base_name = 'vcl_lr_{}_{}_batch_{}_coresetsize_{}_epochs_{}_run_{}_no_tasks_{}_no_train_samples_{}'

In [15]:
no_tasks = 10

dataset = PermutedMnist(no_tasks)

In [None]:
NO_RUNS = 3


no_tasks = dataset.no_tasks

# Modify as need be.
lrs = [1e-2, 1e-4]
train_samples = [50, 100]

epochs = [50]
coreset_sizes = [0]

for train_sample in train_samples:
    TRAIN_NUM_SAMPLES = train_sample
    train_loaders, test_loaders = dataset.get_tasks(batch_size=BATCH_SIZE)
    for epoch in epochs:
        NUM_EPOCHS = epoch
        for coreset_size in coreset_sizes:
            CORESET_SIZE = coreset_size
            for run_number in range(NO_RUNS):
                model_name = model_base_name.format(LR, MLP_INIT_DESC, BATCH_SIZE, CORESET_SIZE, NUM_EPOCHS, run_number, no_tasks, TRAIN_NUM_SAMPLES) 
               
                LOG_DIR = 'logs/'
                os.makedirs(LOG_DIR, exist_ok=True)
                
                LOG_FILE_NAME = os.path.join(LOG_DIR, model_name) + '.txt'
                log_file_handler = open(LOG_FILE_NAME, "w")
                print("Run Number: {}. Log File: {}".format(run_number, LOG_FILE_NAME))
                
                if CORESET_SIZE > 0:
                    random_coreset_selector = RandomCoreset(dataset)
                    coreset_loaders, sans_coreset_loaders = random_coreset_selector.get_coreset_loaders(batch_size=BATCH_SIZE, coreset_size=CORESET_SIZE)    
                    
                    # On each task, first train on non-coreset data. 
                    # Else, don't modify train_loaders (train on all the training data).
                    train_loaders = sans_coreset_loaders
                
                model = VCLModel(784, 10, [100, 100])
                accuracies = np.zeros((no_tasks, no_tasks))
                
                for i in range(no_tasks):
                    train_loader = train_loaders[i]
                    
                    if INIT_MLP and i == 0:
                        print("Training MLP model to init first task")
                        mlp = VanillaMLP(784, 10, [100, 100])
                        mlp.train_model(NUM_EPOCHS, train_loader, LR)
                        model.init_mle(mlp)
                    
                    model.train_model(NUM_EPOCHS, train_loader, LR, TRAIN_NUM_SAMPLES)
                    
                    model.update_priors()
                    
                    # Train on coreset after training on non-coreset.
                    if CORESET_SIZE > 0:
                        coreset_loader = coreset_loaders[i]
                        model.train_model(NUM_EPOCHS, coreset_loader, LR, TRAIN_NUM_SAMPLES)
                        
                    task_accs = []
                    for j in range(i + 1):
                        test_loader = test_loaders[j]
                        accuracy = model.get_accuracy(test_loader, 100)
                        
                        msg = "[Number Tasks: {}]: Task {} Accuracy: {}".format(j + 1, j + 1, accuracy ) 
                        log_file_handler.write(msg + '\n')
                        log_file_handler.flush()
                        print(msg)
                        
                        task_accs.append(accuracy)
                        accuracies[i][j] = accuracy
                    
                    msg = "Task {} / {}. Mean Accuracy: {}".format(i + 1, no_tasks, np.mean(task_accs))
                    log_file_handler.write(msg + '\n')
                    log_file_handler.flush()
                    print(msg)
                log_file_handler.close()

Run Number: 0. Log File: logs/vcl_lr_0.001_withoutmlpinit_batch_256_coresetsize_0_epochs_50_run_0_no_tasks_10_no_train_samples_50.txt


Total Loss: 0.48908327691098474, KL: 0.3679376728991245, Lik Loss: 0.1211456041703833: 100%|██████████| 50/50 [03:34<00:00,  4.30s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9788
Task 1 / 10. Mean Accuracy: 0.9788


Total Loss: 0.2880435473741369, KL: 0.12966347279700827, Lik Loss: 0.1583800754331528: 100%|██████████| 50/50 [03:42<00:00,  4.46s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9642
[Number Tasks: 2]: Task 2 Accuracy: 0.9709
Task 2 / 10. Mean Accuracy: 0.9675499999999999


Total Loss: 0.27553309614353994, KL: 0.1213170366718414, Lik Loss: 0.15421606070817784: 100%|██████████| 50/50 [03:40<00:00,  4.42s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9548
[Number Tasks: 2]: Task 2 Accuracy: 0.9596
[Number Tasks: 3]: Task 3 Accuracy: 0.9673
Task 3 / 10. Mean Accuracy: 0.9605666666666668


Total Loss: 0.27878701572722575, KL: 0.12303725259101138, Lik Loss: 0.15574976275575922: 100%|██████████| 50/50 [03:43<00:00,  4.47s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.9437
[Number Tasks: 2]: Task 2 Accuracy: 0.9528
[Number Tasks: 3]: Task 3 Accuracy: 0.9647
[Number Tasks: 4]: Task 4 Accuracy: 0.9665
Task 4 / 10. Mean Accuracy: 0.956925


Total Loss: 0.2787152958043078, KL: 0.12152422381208298, Lik Loss: 0.15719107154836046: 100%|██████████| 50/50 [03:47<00:00,  4.56s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9314
[Number Tasks: 2]: Task 2 Accuracy: 0.9484
[Number Tasks: 3]: Task 3 Accuracy: 0.9597
[Number Tasks: 4]: Task 4 Accuracy: 0.9616
[Number Tasks: 5]: Task 5 Accuracy: 0.9662
Task 5 / 10. Mean Accuracy: 0.95346


Total Loss: 0.28452814690610195, KL: 0.12584469587878977, Lik Loss: 0.15868345137606277: 100%|██████████| 50/50 [03:53<00:00,  4.67s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.9106
[Number Tasks: 2]: Task 2 Accuracy: 0.932
[Number Tasks: 3]: Task 3 Accuracy: 0.9451
[Number Tasks: 4]: Task 4 Accuracy: 0.9536
[Number Tasks: 5]: Task 5 Accuracy: 0.9554
[Number Tasks: 6]: Task 6 Accuracy: 0.9652
Task 6 / 10. Mean Accuracy: 0.94365


Total Loss: 0.2843763279153946, KL: 0.12399228551286332, Lik Loss: 0.16038404084900593: 100%|██████████| 50/50 [03:47<00:00,  4.55s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.863
[Number Tasks: 2]: Task 2 Accuracy: 0.9027
[Number Tasks: 3]: Task 3 Accuracy: 0.9206
[Number Tasks: 4]: Task 4 Accuracy: 0.9284
[Number Tasks: 5]: Task 5 Accuracy: 0.9461
[Number Tasks: 6]: Task 6 Accuracy: 0.9564
[Number Tasks: 7]: Task 7 Accuracy: 0.9646
Task 7 / 10. Mean Accuracy: 0.9259714285714286


Total Loss: 0.2935102131138457, KL: 0.1295607733599683, Lik Loss: 0.16394943994410494: 100%|██████████| 50/50 [03:46<00:00,  4.54s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.8427
[Number Tasks: 2]: Task 2 Accuracy: 0.9099
[Number Tasks: 3]: Task 3 Accuracy: 0.9081
[Number Tasks: 4]: Task 4 Accuracy: 0.9208
[Number Tasks: 5]: Task 5 Accuracy: 0.9321
[Number Tasks: 6]: Task 6 Accuracy: 0.9469
[Number Tasks: 7]: Task 7 Accuracy: 0.9567
[Number Tasks: 8]: Task 8 Accuracy: 0.9606
Task 8 / 10. Mean Accuracy: 0.9222250000000001


Total Loss: 0.320349402693992, KL: 0.14434116267143413, Lik Loss: 0.1760082409736958: 100%|██████████| 50/50 [03:44<00:00,  4.50s/it]   


[Number Tasks: 1]: Task 1 Accuracy: 0.8428
[Number Tasks: 2]: Task 2 Accuracy: 0.8806
[Number Tasks: 3]: Task 3 Accuracy: 0.8741
[Number Tasks: 4]: Task 4 Accuracy: 0.8869
[Number Tasks: 5]: Task 5 Accuracy: 0.9052
[Number Tasks: 6]: Task 6 Accuracy: 0.9211
[Number Tasks: 7]: Task 7 Accuracy: 0.9424
[Number Tasks: 8]: Task 8 Accuracy: 0.9512
[Number Tasks: 9]: Task 9 Accuracy: 0.9603
Task 9 / 10. Mean Accuracy: 0.9071777777777777


Total Loss: 0.3073765841570306, KL: 0.13659645908690513, Lik Loss: 0.1707801250067163: 100%|██████████| 50/50 [03:49<00:00,  4.59s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.8501
[Number Tasks: 2]: Task 2 Accuracy: 0.8488
[Number Tasks: 3]: Task 3 Accuracy: 0.8622
[Number Tasks: 4]: Task 4 Accuracy: 0.8695
[Number Tasks: 5]: Task 5 Accuracy: 0.8952
[Number Tasks: 6]: Task 6 Accuracy: 0.8886
[Number Tasks: 7]: Task 7 Accuracy: 0.9363
[Number Tasks: 8]: Task 8 Accuracy: 0.9405
[Number Tasks: 9]: Task 9 Accuracy: 0.9526
[Number Tasks: 10]: Task 10 Accuracy: 0.9606
Task 10 / 10. Mean Accuracy: 0.9004399999999999
Run Number: 1. Log File: logs/vcl_lr_0.001_withoutmlpinit_batch_256_coresetsize_0_epochs_50_run_1_no_tasks_10_no_train_samples_50.txt


Total Loss: 0.5003423432086377, KL: 0.37765789513892317, Lik Loss: 0.1226844476892593: 100%|██████████| 50/50 [03:49<00:00,  4.59s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9799
Task 1 / 10. Mean Accuracy: 0.9799


Total Loss: 0.299081681700463, KL: 0.13249990604025252, Lik Loss: 0.16658177597725646: 100%|██████████| 50/50 [03:49<00:00,  4.58s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9351
[Number Tasks: 2]: Task 2 Accuracy: 0.9668
Task 2 / 10. Mean Accuracy: 0.95095


Total Loss: 0.28711239429230384, KL: 0.12413130665713168, Lik Loss: 0.1629810889033561: 100%|██████████| 50/50 [03:49<00:00,  4.60s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9313
[Number Tasks: 2]: Task 2 Accuracy: 0.9617
[Number Tasks: 3]: Task 3 Accuracy: 0.9679
Task 3 / 10. Mean Accuracy: 0.9536333333333333


Total Loss: 0.28044854212314524, KL: 0.11976170508151358, Lik Loss: 0.16068683732697303: 100%|██████████| 50/50 [04:27<00:00,  5.35s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.8956
[Number Tasks: 2]: Task 2 Accuracy: 0.942
[Number Tasks: 3]: Task 3 Accuracy: 0.9598
[Number Tasks: 4]: Task 4 Accuracy: 0.9627
Task 4 / 10. Mean Accuracy: 0.9400249999999999


Total Loss: 0.28329336510059683, KL: 0.12046383930013535, Lik Loss: 0.1628295252931879: 100%|██████████| 50/50 [04:41<00:00,  5.63s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.8758
[Number Tasks: 2]: Task 2 Accuracy: 0.9385
[Number Tasks: 3]: Task 3 Accuracy: 0.9498
[Number Tasks: 4]: Task 4 Accuracy: 0.9595
[Number Tasks: 5]: Task 5 Accuracy: 0.9648
Task 5 / 10. Mean Accuracy: 0.9376800000000001


Total Loss: 0.29670211178191164, KL: 0.13098837073813094, Lik Loss: 0.1657137410754853: 100%|██████████| 50/50 [05:13<00:00,  6.27s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.782
[Number Tasks: 2]: Task 2 Accuracy: 0.9169
[Number Tasks: 3]: Task 3 Accuracy: 0.9386
[Number Tasks: 4]: Task 4 Accuracy: 0.9498
[Number Tasks: 5]: Task 5 Accuracy: 0.9585
[Number Tasks: 6]: Task 6 Accuracy: 0.9634
Task 6 / 10. Mean Accuracy: 0.9182


Total Loss: 0.2958703151408662, KL: 0.12801106242423363, Lik Loss: 0.16785925363606594: 100%|██████████| 50/50 [04:56<00:00,  5.93s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.7676
[Number Tasks: 2]: Task 2 Accuracy: 0.9049
[Number Tasks: 3]: Task 3 Accuracy: 0.9301
[Number Tasks: 4]: Task 4 Accuracy: 0.9368
[Number Tasks: 5]: Task 5 Accuracy: 0.9506
[Number Tasks: 6]: Task 6 Accuracy: 0.9567
[Number Tasks: 7]: Task 7 Accuracy: 0.9619
Task 7 / 10. Mean Accuracy: 0.9155142857142856


Total Loss: 0.29875755639786417, KL: 0.1274625466859087, Lik Loss: 0.17129501018752444: 100%|██████████| 50/50 [08:24<00:00, 10.10s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.7706
[Number Tasks: 2]: Task 2 Accuracy: 0.8967
[Number Tasks: 3]: Task 3 Accuracy: 0.9266
[Number Tasks: 4]: Task 4 Accuracy: 0.9202
[Number Tasks: 5]: Task 5 Accuracy: 0.939
[Number Tasks: 6]: Task 6 Accuracy: 0.9495
[Number Tasks: 7]: Task 7 Accuracy: 0.9529
[Number Tasks: 8]: Task 8 Accuracy: 0.9626
Task 8 / 10. Mean Accuracy: 0.9147625


Total Loss: 0.3042948330970521, KL: 0.12754030608116312, Lik Loss: 0.17675452739634412: 100%|██████████| 50/50 [08:38<00:00, 10.37s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.768
[Number Tasks: 2]: Task 2 Accuracy: 0.8802
[Number Tasks: 3]: Task 3 Accuracy: 0.9062
[Number Tasks: 4]: Task 4 Accuracy: 0.9031
[Number Tasks: 5]: Task 5 Accuracy: 0.9204
[Number Tasks: 6]: Task 6 Accuracy: 0.9452
[Number Tasks: 7]: Task 7 Accuracy: 0.9304
[Number Tasks: 8]: Task 8 Accuracy: 0.9564
[Number Tasks: 9]: Task 9 Accuracy: 0.9572
Task 9 / 10. Mean Accuracy: 0.9074555555555555


Total Loss: 0.31158628146699135, KL: 0.13571406114608683, Lik Loss: 0.17587222003556313: 100%|██████████| 50/50 [05:40<00:00,  6.81s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.7448
[Number Tasks: 2]: Task 2 Accuracy: 0.8459
[Number Tasks: 3]: Task 3 Accuracy: 0.8923
[Number Tasks: 4]: Task 4 Accuracy: 0.8933
[Number Tasks: 5]: Task 5 Accuracy: 0.9115
[Number Tasks: 6]: Task 6 Accuracy: 0.9291
[Number Tasks: 7]: Task 7 Accuracy: 0.9201
[Number Tasks: 8]: Task 8 Accuracy: 0.9522
[Number Tasks: 9]: Task 9 Accuracy: 0.9512
[Number Tasks: 10]: Task 10 Accuracy: 0.9598
Task 10 / 10. Mean Accuracy: 0.9000199999999999
Run Number: 2. Log File: logs/vcl_lr_0.001_withoutmlpinit_batch_256_coresetsize_0_epochs_50_run_2_no_tasks_10_no_train_samples_50.txt


Total Loss: 0.48145367985076093, KL: 0.3616487482760815, Lik Loss: 0.11980492938706215: 100%|██████████| 50/50 [04:17<00:00,  5.14s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.9803
Task 1 / 10. Mean Accuracy: 0.9803


Total Loss: 0.28471452631848926, KL: 0.12990550506622234, Lik Loss: 0.15480902226681406: 100%|██████████| 50/50 [04:16<00:00,  5.13s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.9676
[Number Tasks: 2]: Task 2 Accuracy: 0.9715
Task 2 / 10. Mean Accuracy: 0.96955


Total Loss: 0.2884335774690547, KL: 0.13220717913292823, Lik Loss: 0.1562263992238552: 100%|██████████| 50/50 [04:43<00:00,  5.67s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.957
[Number Tasks: 2]: Task 2 Accuracy: 0.9564
[Number Tasks: 3]: Task 3 Accuracy: 0.9685
Task 3 / 10. Mean Accuracy: 0.9606333333333333


Total Loss: 0.2812632905041918, KL: 0.12585925069895196, Lik Loss: 0.15540403894921567: 100%|██████████| 50/50 [04:36<00:00,  5.53s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9421
[Number Tasks: 2]: Task 2 Accuracy: 0.9536
[Number Tasks: 3]: Task 3 Accuracy: 0.9642
[Number Tasks: 4]: Task 4 Accuracy: 0.9672
Task 4 / 10. Mean Accuracy: 0.956775


Total Loss: 0.28330349034451424, KL: 0.12668268388890205, Lik Loss: 0.15662080677265816: 100%|██████████| 50/50 [04:21<00:00,  5.23s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.9157
[Number Tasks: 2]: Task 2 Accuracy: 0.9354
[Number Tasks: 3]: Task 3 Accuracy: 0.9494
[Number Tasks: 4]: Task 4 Accuracy: 0.9612
[Number Tasks: 5]: Task 5 Accuracy: 0.9649
Task 5 / 10. Mean Accuracy: 0.94532


Total Loss: 0.2899506405947056, KL: 0.127391993936072, Lik Loss: 0.16255864738783937: 100%|██████████| 50/50 [04:29<00:00,  5.39s/it]   


[Number Tasks: 1]: Task 1 Accuracy: 0.8812
[Number Tasks: 2]: Task 2 Accuracy: 0.9171
[Number Tasks: 3]: Task 3 Accuracy: 0.9257
[Number Tasks: 4]: Task 4 Accuracy: 0.9521
[Number Tasks: 5]: Task 5 Accuracy: 0.9588
[Number Tasks: 6]: Task 6 Accuracy: 0.9636
Task 6 / 10. Mean Accuracy: 0.9330833333333333


Total Loss: 0.29930424753655777, KL: 0.13703879168693056, Lik Loss: 0.16226545548502436: 100%|██████████| 50/50 [04:27<00:00,  5.36s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.8665
[Number Tasks: 2]: Task 2 Accuracy: 0.9088
[Number Tasks: 3]: Task 3 Accuracy: 0.9036
[Number Tasks: 4]: Task 4 Accuracy: 0.9314
[Number Tasks: 5]: Task 5 Accuracy: 0.9496
[Number Tasks: 6]: Task 6 Accuracy: 0.9565
[Number Tasks: 7]: Task 7 Accuracy: 0.9642
Task 7 / 10. Mean Accuracy: 0.9258


Total Loss: 0.30261132304972793, KL: 0.1350346732646861, Lik Loss: 0.16757665019720158: 100%|██████████| 50/50 [04:30<00:00,  5.41s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.8309
[Number Tasks: 2]: Task 2 Accuracy: 0.8917
[Number Tasks: 3]: Task 3 Accuracy: 0.8882
[Number Tasks: 4]: Task 4 Accuracy: 0.92
[Number Tasks: 5]: Task 5 Accuracy: 0.9321
[Number Tasks: 6]: Task 6 Accuracy: 0.9479
[Number Tasks: 7]: Task 7 Accuracy: 0.9547
[Number Tasks: 8]: Task 8 Accuracy: 0.9598
Task 8 / 10. Mean Accuracy: 0.9156625


Total Loss: 0.3340493710751229, KL: 0.15148221298735193, Lik Loss: 0.18256715916572733: 100%|██████████| 50/50 [04:16<00:00,  5.12s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.7795
[Number Tasks: 2]: Task 2 Accuracy: 0.8532
[Number Tasks: 3]: Task 3 Accuracy: 0.8789
[Number Tasks: 4]: Task 4 Accuracy: 0.9033
[Number Tasks: 5]: Task 5 Accuracy: 0.8987
[Number Tasks: 6]: Task 6 Accuracy: 0.9375
[Number Tasks: 7]: Task 7 Accuracy: 0.9447
[Number Tasks: 8]: Task 8 Accuracy: 0.9574
[Number Tasks: 9]: Task 9 Accuracy: 0.958
Task 9 / 10. Mean Accuracy: 0.9012444444444445


Total Loss: 0.325737911209147, KL: 0.14571773460570803, Lik Loss: 0.1800201766668482: 100%|██████████| 50/50 [04:42<00:00,  5.64s/it]   


[Number Tasks: 1]: Task 1 Accuracy: 0.7954
[Number Tasks: 2]: Task 2 Accuracy: 0.8357
[Number Tasks: 3]: Task 3 Accuracy: 0.8698
[Number Tasks: 4]: Task 4 Accuracy: 0.9119
[Number Tasks: 5]: Task 5 Accuracy: 0.8921
[Number Tasks: 6]: Task 6 Accuracy: 0.8921
[Number Tasks: 7]: Task 7 Accuracy: 0.9428
[Number Tasks: 8]: Task 8 Accuracy: 0.9506
[Number Tasks: 9]: Task 9 Accuracy: 0.9497
[Number Tasks: 10]: Task 10 Accuracy: 0.9574
Task 10 / 10. Mean Accuracy: 0.8997499999999998
Run Number: 0. Log File: logs/vcl_lr_0.001_withoutmlpinit_batch_256_coresetsize_0_epochs_50_run_0_no_tasks_10_no_train_samples_100.txt


Total Loss: 0.42458437731925475, KL: 0.2945742239343359, Lik Loss: 0.13001015359099874: 100%|██████████| 50/50 [04:30<00:00,  5.41s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9781
Task 1 / 10. Mean Accuracy: 0.9781


Total Loss: 0.3176800160331929, KL: 0.14319396418459873, Lik Loss: 0.17448605178518498: 100%|██████████| 50/50 [05:35<00:00,  6.71s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.8535
[Number Tasks: 2]: Task 2 Accuracy: 0.9638
Task 2 / 10. Mean Accuracy: 0.90865


Total Loss: 0.28554316499131793, KL: 0.11760341258759194, Lik Loss: 0.16793975256224897: 100%|██████████| 50/50 [05:39<00:00,  6.79s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.8041
[Number Tasks: 2]: Task 2 Accuracy: 0.9512
[Number Tasks: 3]: Task 3 Accuracy: 0.9614
Task 3 / 10. Mean Accuracy: 0.9055666666666667


Total Loss: 0.29198614954948426, KL: 0.12247356002001052, Lik Loss: 0.16951258927583696: 100%|██████████| 50/50 [05:01<00:00,  6.03s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.7451
[Number Tasks: 2]: Task 2 Accuracy: 0.9423
[Number Tasks: 3]: Task 3 Accuracy: 0.9489
[Number Tasks: 4]: Task 4 Accuracy: 0.9629
Task 4 / 10. Mean Accuracy: 0.8997999999999999


Total Loss: 0.32107245478224244, KL: 0.13858917279446378, Lik Loss: 0.1824832818292557: 100%|██████████| 50/50 [05:32<00:00,  6.65s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.6694
[Number Tasks: 2]: Task 2 Accuracy: 0.9393
[Number Tasks: 3]: Task 3 Accuracy: 0.9423
[Number Tasks: 4]: Task 4 Accuracy: 0.9446
[Number Tasks: 5]: Task 5 Accuracy: 0.958
Task 5 / 10. Mean Accuracy: 0.89072


Total Loss: 0.3102401822171313, KL: 0.13146038049079004, Lik Loss: 0.1787798008386125: 100%|██████████| 50/50 [05:49<00:00,  6.99s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.6526
[Number Tasks: 2]: Task 2 Accuracy: 0.9322
[Number Tasks: 3]: Task 3 Accuracy: 0.9271
[Number Tasks: 4]: Task 4 Accuracy: 0.9297
[Number Tasks: 5]: Task 5 Accuracy: 0.9475
[Number Tasks: 6]: Task 6 Accuracy: 0.9586
Task 6 / 10. Mean Accuracy: 0.8912833333333333


Total Loss: 0.30992640821223566, KL: 0.12706133628145178, Lik Loss: 0.1828650729770356: 100%|██████████| 50/50 [05:18<00:00,  6.38s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.6734
[Number Tasks: 2]: Task 2 Accuracy: 0.925
[Number Tasks: 3]: Task 3 Accuracy: 0.9055
[Number Tasks: 4]: Task 4 Accuracy: 0.906
[Number Tasks: 5]: Task 5 Accuracy: 0.9322
[Number Tasks: 6]: Task 6 Accuracy: 0.9514
[Number Tasks: 7]: Task 7 Accuracy: 0.9585
Task 7 / 10. Mean Accuracy: 0.8931428571428571


Total Loss: 0.3159539947484402, KL: 0.13078827008287958, Lik Loss: 0.1851657247289698: 100%|██████████| 50/50 [05:22<00:00,  6.45s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.6256
[Number Tasks: 2]: Task 2 Accuracy: 0.9089
[Number Tasks: 3]: Task 3 Accuracy: 0.8777
[Number Tasks: 4]: Task 4 Accuracy: 0.8729
[Number Tasks: 5]: Task 5 Accuracy: 0.9045
[Number Tasks: 6]: Task 6 Accuracy: 0.9377
[Number Tasks: 7]: Task 7 Accuracy: 0.943
[Number Tasks: 8]: Task 8 Accuracy: 0.9549
Task 8 / 10. Mean Accuracy: 0.87815


Total Loss: 0.3259190507391666, KL: 0.13923009139426212, Lik Loss: 0.1866889586791079: 100%|██████████| 50/50 [05:25<00:00,  6.52s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.625
[Number Tasks: 2]: Task 2 Accuracy: 0.9032
[Number Tasks: 3]: Task 3 Accuracy: 0.8818
[Number Tasks: 4]: Task 4 Accuracy: 0.854
[Number Tasks: 5]: Task 5 Accuracy: 0.8655
[Number Tasks: 6]: Task 6 Accuracy: 0.923
[Number Tasks: 7]: Task 7 Accuracy: 0.9308
[Number Tasks: 8]: Task 8 Accuracy: 0.9464
[Number Tasks: 9]: Task 9 Accuracy: 0.955
Task 9 / 10. Mean Accuracy: 0.8760777777777778


Total Loss: 0.3377294881546751, KL: 0.1474585697371909, Lik Loss: 0.1902709184491888: 100%|██████████| 50/50 [04:55<00:00,  5.91s/it]   


[Number Tasks: 1]: Task 1 Accuracy: 0.5933
[Number Tasks: 2]: Task 2 Accuracy: 0.8652
[Number Tasks: 3]: Task 3 Accuracy: 0.8527
[Number Tasks: 4]: Task 4 Accuracy: 0.844
[Number Tasks: 5]: Task 5 Accuracy: 0.8364
[Number Tasks: 6]: Task 6 Accuracy: 0.8887
[Number Tasks: 7]: Task 7 Accuracy: 0.9236
[Number Tasks: 8]: Task 8 Accuracy: 0.9305
[Number Tasks: 9]: Task 9 Accuracy: 0.9457
[Number Tasks: 10]: Task 10 Accuracy: 0.9565
Task 10 / 10. Mean Accuracy: 0.86366
Run Number: 1. Log File: logs/vcl_lr_0.001_withoutmlpinit_batch_256_coresetsize_0_epochs_50_run_1_no_tasks_10_no_train_samples_100.txt


Total Loss: 0.41761752367019656, KL: 0.28907260083137676, Lik Loss: 0.1285449224583646: 100%|██████████| 50/50 [05:25<00:00,  6.51s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9779
Task 1 / 10. Mean Accuracy: 0.9779


Total Loss: 0.3434606742351613, KL: 0.14476000026185462, Lik Loss: 0.19870067356114693: 100%|██████████| 50/50 [04:46<00:00,  5.73s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.7948
[Number Tasks: 2]: Task 2 Accuracy: 0.9593
Task 2 / 10. Mean Accuracy: 0.87705


Total Loss: 0.3081944280482353, KL: 0.12506235720629388, Lik Loss: 0.1831320714443288: 100%|██████████| 50/50 [05:28<00:00,  6.57s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.7596
[Number Tasks: 2]: Task 2 Accuracy: 0.943
[Number Tasks: 3]: Task 3 Accuracy: 0.9592
Task 3 / 10. Mean Accuracy: 0.8872666666666666


Total Loss: 0.3173211880186771, KL: 0.13230139695583507, Lik Loss: 0.18501979017511327: 100%|██████████| 50/50 [05:21<00:00,  6.43s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.7363
[Number Tasks: 2]: Task 2 Accuracy: 0.9288
[Number Tasks: 3]: Task 3 Accuracy: 0.9451
[Number Tasks: 4]: Task 4 Accuracy: 0.9607
Task 4 / 10. Mean Accuracy: 0.892725


Total Loss: 0.30978207271149816, KL: 0.1291526660006097, Lik Loss: 0.18062940686941148: 100%|██████████| 50/50 [05:10<00:00,  6.20s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.6999
[Number Tasks: 2]: Task 2 Accuracy: 0.9188
[Number Tasks: 3]: Task 3 Accuracy: 0.9369
[Number Tasks: 4]: Task 4 Accuracy: 0.9546
[Number Tasks: 5]: Task 5 Accuracy: 0.9581
Task 5 / 10. Mean Accuracy: 0.89366


Total Loss: 0.31891128192556667, KL: 0.1292149195011626, Lik Loss: 0.1896963619171305: 100%|██████████| 50/50 [04:56<00:00,  5.93s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.6642
[Number Tasks: 2]: Task 2 Accuracy: 0.9096
[Number Tasks: 3]: Task 3 Accuracy: 0.9169
[Number Tasks: 4]: Task 4 Accuracy: 0.9434
[Number Tasks: 5]: Task 5 Accuracy: 0.9514
[Number Tasks: 6]: Task 6 Accuracy: 0.9566
Task 6 / 10. Mean Accuracy: 0.8903500000000001


Total Loss: 0.31919813435128397, KL: 0.12940902545097027, Lik Loss: 0.18978910902713209: 100%|██████████| 50/50 [05:35<00:00,  6.70s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.6664
[Number Tasks: 2]: Task 2 Accuracy: 0.8895
[Number Tasks: 3]: Task 3 Accuracy: 0.9004
[Number Tasks: 4]: Task 4 Accuracy: 0.9268
[Number Tasks: 5]: Task 5 Accuracy: 0.9411
[Number Tasks: 6]: Task 6 Accuracy: 0.9408
[Number Tasks: 7]: Task 7 Accuracy: 0.9526
Task 7 / 10. Mean Accuracy: 0.8882285714285715


Total Loss: 0.31826364245820554, KL: 0.1320645813612228, Lik Loss: 0.18619905973368503: 100%|██████████| 50/50 [05:21<00:00,  6.44s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.6437
[Number Tasks: 2]: Task 2 Accuracy: 0.874
[Number Tasks: 3]: Task 3 Accuracy: 0.8989
[Number Tasks: 4]: Task 4 Accuracy: 0.922
[Number Tasks: 5]: Task 5 Accuracy: 0.932
[Number Tasks: 6]: Task 6 Accuracy: 0.9204
[Number Tasks: 7]: Task 7 Accuracy: 0.947
[Number Tasks: 8]: Task 8 Accuracy: 0.9577
Task 8 / 10. Mean Accuracy: 0.8869625000000001


Total Loss: 0.3300857106421856, KL: 0.13801354968801458, Lik Loss: 0.19207216104928485: 100%|██████████| 50/50 [04:28<00:00,  5.36s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.6414
[Number Tasks: 2]: Task 2 Accuracy: 0.8553
[Number Tasks: 3]: Task 3 Accuracy: 0.8665
[Number Tasks: 4]: Task 4 Accuracy: 0.8978
[Number Tasks: 5]: Task 5 Accuracy: 0.9149
[Number Tasks: 6]: Task 6 Accuracy: 0.9241
[Number Tasks: 7]: Task 7 Accuracy: 0.9351
[Number Tasks: 8]: Task 8 Accuracy: 0.9479
[Number Tasks: 9]: Task 9 Accuracy: 0.9532
Task 9 / 10. Mean Accuracy: 0.8818


Total Loss: 0.33427082154345006, KL: 0.14224208587027612, Lik Loss: 0.1920287345001038: 100%|██████████| 50/50 [05:09<00:00,  6.19s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.6454
[Number Tasks: 2]: Task 2 Accuracy: 0.8192
[Number Tasks: 3]: Task 3 Accuracy: 0.8637
[Number Tasks: 4]: Task 4 Accuracy: 0.8681
[Number Tasks: 5]: Task 5 Accuracy: 0.9021
[Number Tasks: 6]: Task 6 Accuracy: 0.8777
[Number Tasks: 7]: Task 7 Accuracy: 0.92
[Number Tasks: 8]: Task 8 Accuracy: 0.9416
[Number Tasks: 9]: Task 9 Accuracy: 0.9452
[Number Tasks: 10]: Task 10 Accuracy: 0.9546
Task 10 / 10. Mean Accuracy: 0.8737599999999999
Run Number: 2. Log File: logs/vcl_lr_0.001_withoutmlpinit_batch_256_coresetsize_0_epochs_50_run_2_no_tasks_10_no_train_samples_100.txt


Total Loss: 0.41031713295490185, KL: 0.2812678060633071, Lik Loss: 0.12904932693915164: 100%|██████████| 50/50 [05:37<00:00,  6.76s/it]


[Number Tasks: 1]: Task 1 Accuracy: 0.977
Task 1 / 10. Mean Accuracy: 0.977


Total Loss: 0.3076999562851926, KL: 0.13715709733202103, Lik Loss: 0.17054285946044517: 100%|██████████| 50/50 [10:58<00:00, 13.17s/it] 


[Number Tasks: 1]: Task 1 Accuracy: 0.9412
[Number Tasks: 2]: Task 2 Accuracy: 0.9654
Task 2 / 10. Mean Accuracy: 0.9533


Total Loss: 0.2843289720885297, KL: 0.12309240267631856, Lik Loss: 0.1612365702365307: 100%|██████████| 50/50 [11:10<00:00, 13.40s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9182
[Number Tasks: 2]: Task 2 Accuracy: 0.9549
[Number Tasks: 3]: Task 3 Accuracy: 0.9653
Task 3 / 10. Mean Accuracy: 0.9461333333333334


Total Loss: 0.29100911623619974, KL: 0.123182339110273, Lik Loss: 0.16782677814047386: 100%|██████████| 50/50 [10:34<00:00, 12.70s/it]  


[Number Tasks: 1]: Task 1 Accuracy: 0.9013
[Number Tasks: 2]: Task 2 Accuracy: 0.9486
[Number Tasks: 3]: Task 3 Accuracy: 0.9542
[Number Tasks: 4]: Task 4 Accuracy: 0.9645
Task 4 / 10. Mean Accuracy: 0.94215


Total Loss: 0.3075224941715281, KL: 0.11335781834861065, Lik Loss: 0.19416467610825883:  48%|████▊     | 24/50 [05:58<06:02, 13.93s/it] 

In [None]:
TRAIN_NUM_SAMPLES = 10
it = iter(task_loaders[3][1])

In [85]:
correct_entropy = 0
incorrect_entropy = 0

correct_prob_std = 0
incorrect_prob_std = 0 

incorrect_count, correct_count = 0, 0

model.eval()

for inputs, targets in train_loaders[3]:
    with torch.no_grad():
        inputs, targets = inputs.cuda(), targets.cuda() 
        batch = inputs.size(0)
        
        inputs = torch.tile(inputs, (TEST_NUM_SAMPLES, 1, 1))
        samples = model(inputs)
        
        samples = samples.view(TEST_NUM_SAMPLES, batch, -1)
        samples_mean = torch.mean(samples, dim=0)
        
        samples_prob = F.softmax(samples, dim=-1)
        samples_prob_mean = torch.mean(samples_prob, dim=0)
    
        samples_prob_std = torch.std(samples_prob, dim=0)
        samples_entropy = (-samples_prob_mean * torch.log2(samples_prob_mean + 1e-9)).sum(dim=1)
        
        preds = samples_mean.argmax(dim=-1)
        
        incorrect_indices = (preds != targets).nonzero()
        correct_indices = (preds == targets).nonzero()
        
        incorrect_prob_std += torch.sum(torch.mean(samples_prob_std[incorrect_indices], dim=-1))
        correct_prob_std += torch.sum(torch.mean(samples_prob_std[correct_indices], dim=-1)) 
        
        incorrect_entropy += torch.sum(samples_entropy[incorrect_indices])
        correct_entropy += torch.sum(samples_entropy[incorrect_indices])
        
        incorrect_count += incorrect_indices.count_nonzero()
        correct_count += correct_indices.count_nonzero()
 
incorrect_entropy /= incorrect_count
incorrect_prob_std /= incorrect_count

correct_entropy /= correct_count
correct_prob_std /= correct_count

tensor([7, 9, 2, 3, 1, 5, 2, 1, 6, 5, 1, 4, 3, 6, 3, 0, 9, 2, 7, 2, 9, 9, 0, 5,
        8, 2, 1, 2, 3, 0, 2, 8, 5, 8, 0, 4, 4, 5, 2, 3, 8, 6, 1, 6, 6, 5, 6, 9,
        7, 0, 7, 5, 6, 9, 1, 7, 7, 0, 5, 3, 5, 9, 9, 4, 1, 8, 0, 5, 0, 4, 4, 2,
        5, 3, 2, 9, 1, 0, 8, 9, 0, 8, 5, 2, 8, 0, 6, 4, 9, 8, 8, 9, 9, 3, 3, 2,
        8, 6, 3, 1, 5, 6, 5, 8, 5, 0, 2, 9, 9, 9, 4, 9, 6, 8, 1, 1, 7, 1, 7, 5,
        2, 6, 6, 3, 3, 9, 5, 6, 3, 1, 9, 7, 5, 6, 7, 8, 4, 0, 3, 2, 4, 1, 5, 4,
        3, 5, 5, 1, 5, 1, 4, 9, 4, 9, 8, 0, 4, 6, 7, 1, 5, 7, 2, 2, 1, 9, 5, 0,
        0, 1, 8, 3, 8, 1, 0, 7, 7, 5, 6, 8, 8, 1, 3, 2, 5, 5, 4, 2, 1, 3, 7, 6,
        5, 2, 4, 8, 6, 3, 5, 7, 4, 6, 8, 8, 6, 7, 7, 9, 5, 0, 1, 5, 1, 6, 4, 3,
        2, 4, 0, 2, 3, 8, 1, 4, 1, 0, 3, 3, 2, 9, 8, 8, 4, 1, 7, 6, 3, 8, 3, 0,
        6, 2, 3, 6, 2, 7, 5, 9, 2, 6, 8, 7, 9, 8, 6, 6])


ZeroDivisionError: division by zero

In [66]:
from models.coreset import UncertaintyCoreset

In [115]:
uc = UncertaintyCoreset(dataset, 100)

In [116]:
train_loaders = uc.get_train_loaders(256)

In [117]:
uc_loader = uc.update_uncertainty_coreset_and_get_loader(task_idx=0, 
                                          batch_size=256,
                                          training_task_loader=train_loaders[0],
                                          model=model,
                                          uncertainty_coreset_size=100,
                                          no_samples=100)

In [118]:
correct_entropy = 0
incorrect_entropy = 0

correct_prob_std = 0
incorrect_prob_std = 0 

incorrect_count, correct_count = 0, 0

model.eval()

for inputs, targets in uc_loader:
    with torch.no_grad():
        print(targets)
        inputs, targets = inputs.cuda(), targets.cuda()
        batch = inputs.size(0)
        
        inputs = torch.tile(inputs, (TEST_NUM_SAMPLES, 1, 1))
        samples = model(inputs)
        
        samples_mean = torch.mean(samples, dim=0)
        
        samples_prob = F.softmax(samples, dim=-1)
        samples_prob_mean = torch.mean(samples_prob, dim=0)
    
        samples_prob_std = torch.std(samples_prob, dim=0)
        samples_entropy = (-samples_prob_mean * torch.log2(samples_prob_mean + 1e-9)).sum(dim=1)
        
        preds = samples_mean.argmax(dim=-1)
        
        incorrect_indices = (preds != targets).nonzero()
        correct_indices = (preds == targets).nonzero()
        
        incorrect_prob_std += torch.sum(torch.mean(samples_prob_std[incorrect_indices], dim=-1))
        correct_prob_std += torch.sum(torch.mean(samples_prob_std[correct_indices], dim=-1)) 
        
        incorrect_entropy += torch.sum(samples_entropy[incorrect_indices])
        correct_entropy += torch.sum(samples_entropy[incorrect_indices])
        
        incorrect_count += incorrect_indices.size(0)
        correct_count += correct_indices.size(0)

print(correct_count, incorrect_count)

incorrect_entropy /= incorrect_count
incorrect_prob_std /= incorrect_count

correct_entropy /= correct_count
correct_prob_std /= correct_count

tensor([1, 0, 1, 2, 9, 2, 1, 3, 7, 1, 6, 0, 0, 1, 0, 8, 1, 7, 7, 0, 7, 1, 6, 1,
        7, 1, 1, 1, 9, 4, 1, 1, 1, 5, 0, 1, 8, 2, 7, 1, 2, 1, 2, 6, 2, 3, 1, 8,
        1, 3, 7, 7, 0, 0, 0, 1, 2, 0, 8, 6, 9, 2, 6, 2, 3, 3, 1, 6, 1, 2, 1, 4,
        7, 7, 1, 1, 9, 7, 6, 1, 4, 6, 5, 7, 1, 6, 1, 5, 3, 7, 1, 4, 1, 3, 1, 1,
        2, 7, 1, 2, 0, 1, 6, 1, 0, 5, 1, 5, 9, 7, 5, 1, 3, 1, 1, 2, 1, 0, 7, 1,
        7, 8, 6, 5, 1, 1, 3, 8, 2, 6, 1, 8, 7, 2, 9, 2, 7, 7, 7, 2, 7, 1, 3, 1,
        9, 1, 4, 7, 1, 8, 1, 4, 0, 1, 1, 4, 1, 2, 1, 1, 0, 1, 1, 7, 2, 2, 7, 7,
        1, 6, 7, 0, 5, 1, 3, 1, 1, 1, 5, 1, 0, 4, 5, 2, 7, 1, 1, 1, 0, 9, 7, 5,
        7, 1, 1, 0, 4, 7, 2, 3])
155 45


In [119]:
correct_entropy

tensor(0.7789, device='cuda:0')

In [120]:
samples_prob_std.shape

torch.Size([200, 10])

In [121]:
incorrect_entropy, correct_entropy

(tensor(2.6830, device='cuda:0'), tensor(0.7789, device='cuda:0'))

In [122]:
incorrect_count, correct_count

(45, 155)

In [123]:
incorrect_prob_std, correct_prob_std

(tensor(0.0880, device='cuda:0'), tensor(0.0258, device='cuda:0'))

In [None]:
accuracies

In [None]:
for p in model.module_list[0].parameters():
    print(p)

In [None]:
model.module_list[0].prior_weight_mu