In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
torch.cuda.empty_cache()
import sys
sys.path.append('./../src')
import globals
from model import MnistCNN, TinyImageNetCNN
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [2]:
SEED = globals.SEED
DEVICE = globals.DEVICE
full_trainset = globals.full_trainset
trainset = globals.trainset
testset = globals.testset
trainloaders = globals.trainloaders
valloaders = globals.valloaders
testloaders = globals.testloaders

In [3]:
from collections import defaultdict
import json
file_path = "class_to_indices.json"
def initialize_data():
    data_folder = './../data/'
    if globals.dataset == 'mnist':
        transform = transforms.Compose([
            transforms.ToTensor()
            #transforms.Normalize((0.5,), (0.5,))  # Normalizes to mean 0.5 and std 0.5 for the single channel
        ])
        globals.full_trainset = torchvision.datasets.MNIST(data_folder, train=True, download=True,
                                    transform=transform)
        targets = np.array(globals.full_trainset.targets)
        globals.testset = torchvision.datasets.MNIST(data_folder, train=False, download=True,
                                transform=transform)
    elif globals.dataset == 'tiny_imagenet':
        transform = transforms.Compose([
            transforms.ToTensor()
            #transforms.Normalize((0.5,), (0.5,))  # Normalizes to mean 0.5 and std 0.5 for the single channel
        ])
        globals.full_trainset = datasets.ImageFolder(root=data_folder + "tiny-imagenet-200/train", transform=transform)
        globals.testset = datasets.ImageFolder(root=data_folder + "tiny-imagenet-200/val", transform=transform)
        targets = [sample[1] for sample in globals.full_trainset.samples]
        targets = np.array(targets)
    else:
        raise NotImplementedError("unsupported dataset")
    if globals.val_set_size != 0:
        # Perform stratified split
        train_indices, val_indices = train_test_split(
            np.arange(len(targets)),
            test_size=globals.val_set_size,
            stratify=targets
        )
    else:
        train_indices = np.arange(len(targets))
        val_indices = []

    # Create subsets
    valset = Subset(globals.full_trainset, val_indices)
    globals.trainset = Subset(globals.full_trainset, train_indices)

    # Define class pairs for each subset
    class_pairs = [tuple(range(i*globals.CLASSES_PER_ITER,(i+1)*globals.CLASSES_PER_ITER)) for i in range(globals.ITERATIONS)]
    #print(class_pairs)

    # Dictionary to hold data loaders for each subset
    globals.trainloaders = []
    globals.testloaders = []
    globals.valloaders = []
    subset_indices = []

    def compute_class_to_indices(dataset):
        class_to_indices = defaultdict(list)
        for idx, (_, label) in enumerate(dataset):
            if idx%1000 == 0:
                print(idx, end=' ')
            class_to_indices[label].append(idx)
        return class_to_indices

    train_class_to_indices = compute_class_to_indices(globals.trainset)
    val_class_to_indices = compute_class_to_indices(valset)
    test_class_to_indices = compute_class_to_indices(globals.testset)

    # Loop over each class pair
    for i, t in enumerate(class_pairs):
        # Get indices of images belonging to the specified class pair
        subs_ind = [idx for cls in t for idx in train_class_to_indices[cls]]
        val_subset_indices = [idx for cls in t for idx in val_class_to_indices[cls]]
        test_subset_indices = [idx for cls in t for idx in test_class_to_indices[cls]]

        # Create a subset for the current class pair
        train_subset = Subset(globals.trainset, subs_ind)
        globals.trainloaders.append(DataLoader(train_subset, batch_size=globals.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers = 0))
        subset_indices.append(subs_ind)
        
        val_subset = Subset(valset, val_subset_indices)
        globals.valloaders.append(DataLoader(val_subset, batch_size=100, shuffle=False))

        test_subset = Subset(globals.testset, test_subset_indices)
        globals.testloaders.append(DataLoader(test_subset, batch_size=100, shuffle=False))


In [4]:
def run_experiment(
        verbose = False,
        stopOnLoss = None,
        full_CE = True,
        with_OOD = False,
        ood_method = 'jigsaw',
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False,
        ogd = False,
        dataset = None,
        joint_training = False
        ):
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)
    if dataset is not None:
        globals.dataset = dataset
    else:
        dataset = globals.dataset
    if dataset == 'tiny_imagenet':
        stopOnLoss = 1.3
        if joint_training:
            globals.CLASSES_PER_ITER = 200
            globals.ITERATIONS = 1
        else:
            globals.CLASSES_PER_ITER = 40
            globals.ITERATIONS = 5
    elif dataset == 'mnist':
        stopOnLoss = 0.03
        if joint_training:
            globals.CLASSES_PER_ITER = 10
            globals.ITERATIONS = 1
        else:
            globals.CLASSES_PER_ITER = 2
            globals.ITERATIONS = 5
    if with_OOD:
        globals.toggle_OOD(ood_method)
    else:
        globals.disable_OOD()
    initialize_data()
    prevModel = None
    globals.BATCH_SIZE=4
    
    globals.WITH_DROPOUT = with_dropout

    #[Denis] added code:
    Feature_Importance_Eval=Feature_Importance_Evaluations(globals.testloaders, DEVICE)

    for i in tqdm(range(globals.ITERATIONS), desc="Experiment Progress"):
        if globals.dataset == 'mnist':
            model = MnistCNN((i+1)*(globals.CLASSES_PER_ITER+globals.OOD_CLASS))
        elif globals.dataset == 'tiny_imagenet':
            model = TinyImageNetCNN((i+1)*(globals.CLASSES_PER_ITER+globals.OOD_CLASS))
        if prevModel is not None:
            with torch.no_grad():
                model.copyPrev(prevModel)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        if prevModel:
            _print("CL TRAIN!!")
            train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                freeze_nonzero_params=False,
                l1_loss=0,
                ewc_loss=0,
                kd_loss=kd_loss,
                distance_loss=0,
                center_loss=0,
                param_reuse_loss=0,
                stopOnLoss=stopOnLoss,
                stopOnValAcc = stopOnValAcc,
                full_CE=full_CE,
                ogd=ogd,
                )
        else:
            print("TRAINING!")
            train_model(
                model, 
                train_loader, 
                val_loader, 
                verbose, 
                epochs=epochs, 
                l1_loss=0,
                stopOnLoss=stopOnLoss,
                center_loss=0,
                ogd=ogd
                )

        #[Denis] added code:
        Feature_Importance_Eval.Task_Feature_Attribution(model, i)
        
        if verbose or i == globals.ITERATIONS-1:
            _print("Starting evaluation")
            _print("ITERATION", i+1)
            _print("ACCURACIES PER TASK:")
            accumPred = []
            all_labels = []
            all_embeddings = []
            with torch.no_grad():
                for j in range(i+1):
                    val_loader = globals.testloaders[j]
                    val_labels = get_labels(val_loader).to(DEVICE)
                    all_labels.append(val_labels)
                    model.eval()
                    pred, embeddings = model.get_pred_and_embeddings((get_features(val_loader).to(DEVICE)))
                    model.train()
                    accumPred.append(pred)
                    all_embeddings.append(embeddings)
                    sliced_pred = pred[:, j*(globals.CLASSES_PER_ITER+globals.OOD_CLASS):(j+1)*(globals.CLASSES_PER_ITER+globals.OOD_CLASS)]
                    _, predicted = torch.max(sliced_pred, 1)  # Get the class predictions
                    predicted += j*globals.CLASSES_PER_ITER
                    correct = (predicted == val_labels).sum().item()  # Count how many were correct
                    accuracy = correct / val_labels.size(0)  # Accuracy as a percentage
                    _print(str(accuracy), end=' ')
            accumPred = torch.cat(accumPred)
            all_labels = torch.cat(all_labels)
            all_embeddings = torch.cat(all_embeddings)
            predicted = []
            for x in accumPred:
                if globals.OOD_CLASS == 1:
                    x_pred = x[[i for i in range(x.size(0)) if (i + 1) % (globals.CLASSES_PER_ITER+1) != 0]]
                else:
                    x_pred = x
                x_pred = torch.softmax(x_pred, dim=-1)
                max = 0
                for (k, v) in enumerate(x_pred):
                    if v > max:
                        max = v
                        p = k
                predicted.append(p)
            predicted = torch.tensor(predicted).to(DEVICE)
            correct = (predicted == all_labels).sum().item()  # Count how many were correct
            accuracy = correct / all_labels.size(0)  # Accuracy as a percentage
            _print("Accuracy on tasks so far:", accuracy)
            embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
            _print("Average embedding drift based on centroids:", embedding_drift)
            total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 1000, task=i+1)
            _print("Total confusion", total_confusion)
            _print("Intra-phase confusion", intra_phase_confusion)
            _print("Per task confusions", per_task_confusion)
            if verbose and globals.dataset == 'mnist':
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(globals.CLASSES_PER_ITER*(i+1))))
        prevModel = copy.deepcopy(model)
        
    #[Denis] added code:
    [avg_att_diff,att_diffs,_,_,avg_att_spread,att_spreads]=Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", att_diffs)
    _print("Averaged SHAPC value (the smaller the better):", avg_att_diff)
    _print("Average attention spread values (ordered as tasks):", att_spreads)
    _print("Averaged attention spread value (the bigger the better):", avg_att_spread)
    Feature_Importance_Eval.Save_Random_Picture_Salency() #prints the salcency maps for 1 example by class (first row: image, second row: salency map after training, third row: salency map after training task where class is included)
    
    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread

In [5]:
def run_experiments(n_runs=globals.EXPERIMENT_N_RUNS, *args, **kwargs):
    verbose = kwargs.get('verbose', None)
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)
    def report_stats(data, name):
        #print(name, data)
        mean = statistics.mean(data)
        std = statistics.stdev(data)
        print(f"Mean " + name + f" across {n_runs} runs: {mean}")
        print(f"Standard deviation of " + name + f" across {n_runs} runs: {std}\n")
    accuracies = []
    total_confusions = []
    intra_phase_confusions = []
    per_task_confusions = []
    att_diffs = []
    embedding_drifts = []
    att_spreads = []
    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread = run_experiment(*args, **kwargs)
        accuracies.append(accuracy)
        total_confusions.append(total_confusion)
        intra_phase_confusions.append(intra_phase_confusion)
        per_task_confusions.append(per_task_confusion)
        att_diffs.append(avg_att_diff)
        embedding_drifts.append(embedding_drift)
        att_spreads.append(avg_att_spread)
        print(f"Run {r} finished with accuracy {accuracy}")

    report_stats(accuracies, "accuracy")
    report_stats(total_confusions, "total confusion")
    report_stats(intra_phase_confusions, "intra-phase confusion")
    report_stats(per_task_confusions, "per-task confusion")
    report_stats(embedding_drifts, "embedding drift")
    report_stats(att_diffs, "attention drift")
    report_stats(att_spreads, "attention spread")

In [6]:
run_experiments(verbose=True, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, ogd=False) # best version

Starting run 1.
0 1000 2000 3000 4000 5000 6000 7000 8000 9000 10000 11000 12000 13000 14000 15000 16000 17000 18000 19000 20000 21000 22000 23000 24000 25000 26000 27000 28000 29000 30000 31000 32000 33000 34000 35000 36000 37000 38000 39000 40000 41000 42000 43000 44000 45000 46000 47000 48000 49000 50000 51000 52000 53000 54000 55000 56000 57000 58000 59000 60000 61000 62000 63000 64000 65000 66000 67000 68000 69000 70000 71000 72000 73000 74000 75000 76000 77000 78000 79000 80000 81000 82000 83000 84000 85000 86000 87000 88000 89000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 

Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!
Epoch 0, CE Loss: 3.7142, center loss: 0.0000
Fraction of nonzero parameters 0.9999146269162125
Validation loss 3.6810053110122682 validation accuracy 0.0455 

Epoch 1, CE Loss: 3.6132, center loss: 0.0000
Fraction of nonzero parameters 0.9999145919272436
Validation loss 3.591951620578766 validation accuracy 0.0585 

Epoch 2, CE Loss: 3.4670, center loss: 0.0000
Fraction of nonzero parameters 0.9999145219493061
Validation loss 3.4247055888175963 validation accuracy 0.1025 

Epoch 3, CE Loss: 3.3028, center loss: 0.0000
Fraction of nonzero parameters 0.9999156066073378
Validation loss 3.2678787112236023 validation accuracy 0.137 

Epoch 4, CE Loss: 3.1299, center loss: 0.0000
Fraction of nonzero parameters 0.9999140670927121
Validation loss 3.125190496444702 validation accuracy 0.1625 

Epoch 5, CE Loss: 2.9665, center loss: 0.0000
Fraction of nonzero parameters 0.99991753100062
Validation loss 2.992747700214386 validation accuracy 0.187 

Epoch 6, CE Loss: 2.8215, center loss

Experiment Progress:  20%|██        | 1/5 [23:21<1:33:24, 1401.15s/it]

Total confusion 0.967853
Intra-phase confusion 0.0
Per task confusions 0.967853
CL TRAIN!!
Epoch 0  CELoss: 2.9253, KLLoss: 0.1786, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999277152777354
Validation losses: 2.6526939928531648 0.25693308711051943
Validation accuracy (for last task) 0.0595
Total validation accuracy 0.197


Epoch 1  CELoss: 2.4091, KLLoss: 0.2607, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999265116952459
Validation losses: 2.443448632955551 0.288641095161438
Validation accuracy (for last task) 0.166
Total validation accuracy 0.23325


Epoch 2  CELoss: 2.0880, KLLoss: 0.2630, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999270349919804
Validation losses: 2.3818488985300066 0.3327667087316513
Validati

Experiment Progress:  40%|████      | 2/5 [38:26<55:27, 1109.28s/it]  

Total confusion 0.98056025
Intra-phase confusion 0.49351124999999996
Per task confusions 0.96919475
CL TRAIN!!
Epoch 0  CELoss: 2.8200, KLLoss: 0.1616, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999329876064291
Validation losses: 2.564415621757507 0.22115917205810548
Validation accuracy (for last task) 0.1035
Total validation accuracy 0.18566666666666667


Epoch 1  CELoss: 2.2988, KLLoss: 0.2294, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999326397611472
Validation losses: 2.4113117814064027 0.25564190149307253
Validation accuracy (for last task) 0.1965
Total validation accuracy 0.19533333333333333


Epoch 2  CELoss: 1.9820, KLLoss: 0.2447, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999353877388747
Validation losse

Experiment Progress:  60%|██████    | 3/5 [52:13<32:41, 980.70s/it] 

Total confusion 0.9856675
Intra-phase confusion 0.6537978333333334
Per task confusions 0.9697598333333333
CL TRAIN!!
Epoch 0  CELoss: 2.8565, KLLoss: 0.1131, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999359574689615
Validation losses: 2.6522301375865935 0.17769975066184998
Validation accuracy (for last task) 0.075
Total validation accuracy 0.146625


Epoch 1  CELoss: 2.3762, KLLoss: 0.1777, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999382465602416
Validation losses: 2.4788750052452087 0.20217849612236022
Validation accuracy (for last task) 0.1765
Total validation accuracy 0.162625


Epoch 2  CELoss: 2.0743, KLLoss: 0.1928, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999408824835336
Validation losses: 2.45195852518

Experiment Progress:  80%|████████  | 4/5 [1:09:13<16:36, 996.07s/it]

Total confusion 0.988779625
Intra-phase confusion 0.738088125
Per task confusions 0.969997875
CL TRAIN!!
Epoch 0  CELoss: 2.5131, KLLoss: 0.1463, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999422645782847
Validation losses: 2.3082621335983275 0.19270626902580262
Validation accuracy (for last task) 0.137
Total validation accuracy 0.1196


Epoch 1  CELoss: 2.0487, KLLoss: 0.2030, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999422818695162
Validation losses: 2.157367479801178 0.2224252462387085
Validation accuracy (for last task) 0.249
Total validation accuracy 0.1308


Epoch 2  CELoss: 1.7780, KLLoss: 0.2079, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999432501784757
Validation losses: 2.020438474416733 0.234043174982

Experiment Progress: 100%|██████████| 5/5 [1:25:06<00:00, 1021.21s/it]

Total confusion 0.9892286
Intra-phase confusion 0.7816408
Per task confusions 0.9691332000000001





Average SHAPC values (ordered as tasks): [3.3029472622736475e-08, 3.0496576480132376e-08, 2.9247850735991676e-08, 2.5072801775216645e-08, 2.238449428156504e-16]
Averaged SHAPC value (the smaller the better): 2.3569340367584423e-08
Average attention spread values (ordered as tasks): [531.7690983835846, 538.2286958821611, 544.1302551116951, 543.7956164042156, 551.2464450632735]
Averaged attention spread value (the bigger the better): 541.8340221689859
Run 0 finished with accuracy 0.1191
Starting run 2.
0 1000 2000 3000 4000 5000 6000 7000 8000 9000 10000 11000 12000 13000 14000 15000 16000 17000 18000 19000 20000 21000 22000 23000 24000 25000 26000 27000 28000 29000 30000 31000 32000 33000 34000 35000 36000 37000 38000 39000 40000 41000 42000 43000 44000 45000 46000 47000 48000 49000 50000 51000 52000 53000 54000 55000 56000 57000 58000 59000 60000 61000 62000 63000 64000 65000 66000 67000 68000 69000 70000 71000 72000 73000 74000 75000 76000 77000 78000 79000 80000 81000 82000 83000 840

Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!
Epoch 0, CE Loss: 3.7150, center loss: 0.0000
Fraction of nonzero parameters 0.9999140845871965
Validation loss 3.68889297246933 validation accuracy 0.025 

Epoch 1, CE Loss: 3.6933, center loss: 0.0000
Fraction of nonzero parameters 0.9999155891128535
Validation loss 3.6336562395095826 validation accuracy 0.048 

Epoch 2, CE Loss: 3.5371, center loss: 0.0000
Fraction of nonzero parameters 0.9999147143886343
Validation loss 3.5381251096725466 validation accuracy 0.0645 

Epoch 3, CE Loss: 3.3904, center loss: 0.0000
Fraction of nonzero parameters 0.9999157465632129
Validation loss 3.3781723976135254 validation accuracy 0.0965 

Epoch 4, CE Loss: 3.2477, center loss: 0.0000
Fraction of nonzero parameters 0.9999179158792764
Validation loss 3.2737032413482665 validation accuracy 0.13 

Epoch 5, CE Loss: 3.0607, center loss: 0.0000
Fraction of nonzero parameters 0.9999169536826353
Validation loss 3.153542917966843 validation accuracy 0.148 

Epoch 6, CE Loss: 2.8920, center loss:

Experiment Progress:  20%|██        | 1/5 [23:16<1:33:04, 1396.23s/it]

Total confusion 0.967084
Intra-phase confusion 0.0
Per task confusions 0.967084
CL TRAIN!!
Epoch 0  CELoss: 2.9029, KLLoss: 0.1574, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999237731090018
Validation losses: 2.55069095492363 0.25484737753868103
Validation accuracy (for last task) 0.117
Total validation accuracy 0.21175


Epoch 1  CELoss: 2.4014, KLLoss: 0.2339, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999238428818997
Validation losses: 2.3997799456119537 0.28782689571380615
Validation accuracy (for last task) 0.211
Total validation accuracy 0.2555


Epoch 2  CELoss: 2.1027, KLLoss: 0.2388, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999262674901032
Validation losses: 2.377864217758179 0.3133918523788452
Validati

Experiment Progress:  40%|████      | 2/5 [38:21<55:22, 1107.50s/it]  

Total confusion 0.9795945
Intra-phase confusion 0.4942885
Per task confusions 0.968256
CL TRAIN!!
Epoch 0  CELoss: 2.7842, KLLoss: 0.1661, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999322049545447
Validation losses: 2.481087452173233 0.2102882146835327
Validation accuracy (for last task) 0.124
Total validation accuracy 0.188


Epoch 1  CELoss: 2.2839, KLLoss: 0.2339, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999289699934224
Validation losses: 2.362017822265625 0.2710171341896057
Validation accuracy (for last task) 0.2015
Total validation accuracy 0.1955


Epoch 2  CELoss: 1.9621, KLLoss: 0.2392, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999323267003934
Validation losses: 2.3100959599018096 0.32391855120658875
Va

Experiment Progress:  60%|██████    | 3/5 [52:06<32:37, 978.59s/it] 

Total confusion 0.9846785
Intra-phase confusion 0.6529546666666667
Per task confusions 0.9685643333333332
CL TRAIN!!
Epoch 0  CELoss: 2.8331, KLLoss: 0.1165, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999354372209432
Validation losses: 2.6483939707279207 0.1802607297897339
Validation accuracy (for last task) 0.0765
Total validation accuracy 0.153125


Epoch 1  CELoss: 2.3734, KLLoss: 0.1834, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999358013945561
Validation losses: 2.4572916686534882 0.19780558943748475
Validation accuracy (for last task) 0.1575
Total validation accuracy 0.166375


Epoch 2  CELoss: 2.0756, KLLoss: 0.1939, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999358881025591
Validation losses: 2.47232310175

Experiment Progress:  80%|████████  | 4/5 [1:08:55<16:30, 990.34s/it]

Total confusion 0.988391625
Intra-phase confusion 0.7374385
Per task confusions 0.969691125
CL TRAIN!!
Epoch 0  CELoss: 2.5242, KLLoss: 0.1368, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999395498549516
Validation losses: 2.287313610315323 0.2016712725162506
Validation accuracy (for last task) 0.153
Total validation accuracy 0.1265


Epoch 1  CELoss: 2.0301, KLLoss: 0.1970, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999408121148454
Validation losses: 2.1029998302459716 0.21369392275810242
Validation accuracy (for last task) 0.2645
Total validation accuracy 0.1356


Epoch 2  CELoss: 1.7561, KLLoss: 0.2023, L1Loss: 0.0000, EWCLoss: 0.0000, CenterLoss: 0.0000, InterCenterLoss: 0.0000, ParamReuseLoss: 0.0000
Fraction of nonzero parameters 0.9999394979812574
Validation losses: 2.0001137018203736 0.243707627058

Experiment Progress: 100%|██████████| 5/5 [1:24:36<00:00, 1015.39s/it]

Total confusion 0.989728
Intra-phase confusion 0.7840871
Per task confusions 0.9696906000000001





Average SHAPC values (ordered as tasks): [3.163390837546132e-08, 2.9355965743604884e-08, 2.9063340949875283e-08, 2.4709996164773428e-08, 0.0]
Averaged SHAPC value (the smaller the better): 2.2952642246742983e-08
Average attention spread values (ordered as tasks): [521.0138665186576, 528.6964583333341, 534.5057629648846, 534.8133495941155, 536.9082104848234]
Averaged attention spread value (the bigger the better): 531.1875295791631
Run 1 finished with accuracy 0.1228
Mean accuracy across 2 runs: 0.12095
Standard deviation of accuracy across 2 runs: 0.002616295090390232

Mean total confusion across 2 runs: 0.9894783
Standard deviation of total confusion across 2 runs: 0.00035312912652458924

Mean intra-phase confusion across 2 runs: 0.7828639500000001
Standard deviation of intra-phase confusion across 2 runs: 0.0017297953188166598

Mean per-task confusion across 2 runs: 0.9694119000000001
Standard deviation of per-task confusion across 2 runs: 0.0003941413198333322

Mean embedding drift 

In [7]:
run_experiments()

Starting run 1.
0 1000 2000 3000 4000 5000 6000 7000 8000 9000 10000 11000 12000 13000 14000 15000 16000 17000 18000 19000 20000 21000 22000 23000 24000 25000 26000 27000 28000 29000 30000 31000 32000 33000 34000 35000 36000 37000 38000 39000 40000 41000 42000 43000 44000 45000 46000 47000 48000 49000 50000 51000 52000 53000 54000 55000 56000 57000 58000 59000 60000 61000 62000 63000 64000 65000 66000 67000 68000 69000 70000 71000 72000 73000 74000 75000 76000 77000 78000 79000 80000 81000 82000 83000 84000 85000 86000 87000 88000 89000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 

Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!


Experiment Progress:  80%|████████  | 4/5 [32:38<07:15, 435.12s/it]

ALL EMBEDDINGS LEN 10000


Experiment Progress: 100%|██████████| 5/5 [40:17<00:00, 483.41s/it]


Run 0 finished with accuracy 0.0796
Starting run 2.
0 1000 2000 3000 4000 5000 6000 7000 8000 9000 10000 11000 12000 13000 14000 15000 16000 17000 18000 19000 20000 21000 22000 23000 24000 25000 26000 27000 28000 29000 30000 31000 32000 33000 34000 35000 36000 37000 38000 39000 40000 41000 42000 43000 44000 45000 46000 47000 48000 49000 50000 51000 52000 53000 54000 55000 56000 57000 58000 59000 60000 61000 62000 63000 64000 65000 66000 67000 68000 69000 70000 71000 72000 73000 74000 75000 76000 77000 78000 79000 80000 81000 82000 83000 84000 85000 86000 87000 88000 89000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 

Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!


Experiment Progress:  80%|████████  | 4/5 [32:43<07:16, 436.50s/it]

ALL EMBEDDINGS LEN 10000


Experiment Progress: 100%|██████████| 5/5 [40:25<00:00, 485.18s/it]


Run 1 finished with accuracy 0.0747
Mean accuracy across 2 runs: 0.07715
Standard deviation of accuracy across 2 runs: 0.003464823227814084

Mean total confusion across 2 runs: 0.9901413
Standard deviation of total confusion across 2 runs: 0.00022174868658004045

Mean intra-phase confusion across 2 runs: 0.7862998
Standard deviation of intra-phase confusion across 2 runs: 0.0013937074657186463

Mean per-task confusion across 2 runs: 0.97019425
Standard deviation of per-task confusion across 2 runs: 3.3021886681458536e-05

Mean embedding drift across 2 runs: 10.08632230758667
Standard deviation of embedding drift across 2 runs: 0.457977096813441

Mean attention drift across 2 runs: 2.377720081443435e-08
Standard deviation of attention drift across 2 runs: 6.49803816454002e-11

Mean attention spread across 2 runs: 542.7944798161824
Standard deviation of attention spread across 2 runs: 1.633229647082942



In [None]:
run_experiments(full_CE=False)

Starting run 1.
0 1000 2000 3000 4000 5000 6000 7000 8000 9000 10000 11000 12000 13000 14000 15000 16000 17000 18000 19000 20000 21000 22000 23000 24000 25000 26000 27000 28000 29000 30000 31000 32000 33000 34000 35000 36000 37000 38000 39000 40000 41000 42000 43000 44000 45000 46000 47000 48000 49000 50000 51000 52000 53000 54000 55000 56000 57000 58000 59000 60000 61000 62000 63000 64000 65000 66000 67000 68000 69000 70000 71000 72000 73000 74000 75000 76000 77000 78000 79000 80000 81000 82000 83000 84000 85000 86000 87000 88000 89000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 

Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!


Experiment Progress:  80%|████████  | 4/5 [29:57<06:20, 380.93s/it]

ALL EMBEDDINGS LEN 10000


Experiment Progress: 100%|██████████| 5/5 [39:44<00:00, 476.99s/it]


Run 0 finished with accuracy 0.1022
Starting run 2.
0 1000 2000 3000 4000 5000 6000 7000 8000 9000 10000 11000 12000 13000 14000 15000 16000 17000 18000 19000 20000 21000 22000 23000 24000 25000 26000 27000 28000 29000 30000 31000 32000 33000 34000 35000 36000 37000 38000 39000 40000 41000 42000 43000 44000 45000 46000 47000 48000 49000 50000 51000 52000 53000 54000 55000 56000 57000 58000 59000 60000 61000 62000 63000 64000 65000 66000 67000 68000 69000 70000 71000 72000 73000 74000 75000 76000 77000 78000 79000 80000 81000 82000 83000 84000 85000 86000 87000 88000 89000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 0 1000 2000 3000 4000 5000 6000 7000 8000 9000 

Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!


Experiment Progress:  40%|████      | 2/5 [20:31<28:32, 570.96s/it]

In [9]:
run_experiments(with_OOD=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.95s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [01:56<00:00, 23.32s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.94s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.83s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [01:50<00:00, 22.05s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [01:59<00:00, 23.83s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [01:48<00:00, 21.64s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [01:53<00:00, 22.61s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [01:52<00:00, 22.51s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [01:51<00:00, 22.28s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.94s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 5/5 [01:52<00:00, 22.52s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 5/5 [01:45<00:00, 21.05s/it]


Starting run 14.


Experiment Progress: 100%|██████████| 5/5 [02:02<00:00, 24.40s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 5/5 [01:52<00:00, 22.56s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 5/5 [01:49<00:00, 21.95s/it]


Starting run 17.


Experiment Progress: 100%|██████████| 5/5 [01:57<00:00, 23.54s/it]


Starting run 18.


Experiment Progress: 100%|██████████| 5/5 [01:51<00:00, 22.34s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.84s/it]


Starting run 20.


Experiment Progress: 100%|██████████| 5/5 [01:56<00:00, 23.27s/it]


Mean accuracy across 20 runs: 0.6283
Standard deviation of accuracy across 20 runs: 0.048455655256283446

Mean total confusion across 20 runs: 0.42272846
Standard deviation of total confusion across 20 runs: 0.013965500739751512

Mean intra-phase confusion across 20 runs: 0.420302175
Standard deviation of intra-phase confusion across 20 runs: 0.013710185111609733

Mean per-task confusion across 20 runs: 0.06710880139771606
Standard deviation of per-task confusion across 20 runs: 0.005308599754009476

Mean embedding drift across 20 runs: 5.311595153808594
Standard deviation of embedding drift across 20 runs: 0.26183589240492755

Mean attention drift across 20 runs: 2.481615933800046e-06
Standard deviation of attention drift across 20 runs: 2.374863145379374e-07

Mean attention spread across 20 runs: 48.717932207607355
Standard deviation of attention spread across 20 runs: 1.6846581863184822



In [10]:
run_experiments(kd_loss=1)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:39<00:00, 31.98s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:42<00:00, 32.45s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:45<00:00, 33.07s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:38<00:00, 31.61s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:37<00:00, 31.40s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [02:43<00:00, 32.67s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [02:44<00:00, 32.88s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [02:42<00:00, 32.52s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [02:39<00:00, 31.83s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [02:42<00:00, 32.41s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 5/5 [02:45<00:00, 33.07s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 5/5 [02:40<00:00, 32.04s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 5/5 [02:39<00:00, 31.97s/it]


Starting run 14.


Experiment Progress: 100%|██████████| 5/5 [02:42<00:00, 32.52s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 5/5 [02:39<00:00, 31.99s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 5/5 [02:38<00:00, 31.74s/it]


Starting run 17.


Experiment Progress: 100%|██████████| 5/5 [02:39<00:00, 31.91s/it]


Starting run 18.


Experiment Progress: 100%|██████████| 5/5 [02:35<00:00, 31.20s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 5/5 [02:40<00:00, 32.15s/it]


Starting run 20.


Experiment Progress: 100%|██████████| 5/5 [02:45<00:00, 33.09s/it]


Mean accuracy across 20 runs: 0.198135
Standard deviation of accuracy across 20 runs: 0.0023528873639357285

Mean total confusion across 20 runs: 0.35109726
Standard deviation of total confusion across 20 runs: 0.014281296893541353

Mean intra-phase confusion across 20 runs: 0.343802215
Standard deviation of intra-phase confusion across 20 runs: 0.013427579682730567

Mean per-task confusion across 20 runs: 0.07598414344695434
Standard deviation of per-task confusion across 20 runs: 0.008372468915565324

Mean embedding drift across 20 runs: 11.377571392059327
Standard deviation of embedding drift across 20 runs: 0.4583113572741939

Mean attention drift across 20 runs: 3.4306784140112743e-06
Standard deviation of attention drift across 20 runs: 5.925314459883529e-07

Mean attention spread across 20 runs: 70.53680304246653
Standard deviation of attention spread across 20 runs: 5.56448982202968



In [11]:
run_experiments(ogd=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:17<00:00, 27.46s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:21<00:00, 28.23s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:22<00:00, 28.46s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:16<00:00, 27.39s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:16<00:00, 27.28s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [02:21<00:00, 28.32s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [02:19<00:00, 27.93s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [02:19<00:00, 27.89s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [02:22<00:00, 28.48s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [02:16<00:00, 27.37s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 5/5 [02:34<00:00, 30.96s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 5/5 [02:23<00:00, 28.78s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 5/5 [02:23<00:00, 28.63s/it]


Starting run 14.


Experiment Progress: 100%|██████████| 5/5 [02:20<00:00, 28.20s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 5/5 [02:26<00:00, 29.21s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 5/5 [02:21<00:00, 28.37s/it]


Starting run 17.


Experiment Progress: 100%|██████████| 5/5 [02:31<00:00, 30.33s/it]


Starting run 18.


Experiment Progress: 100%|██████████| 5/5 [02:23<00:00, 28.75s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 5/5 [02:25<00:00, 29.05s/it]


Starting run 20.


Experiment Progress: 100%|██████████| 5/5 [02:28<00:00, 29.62s/it]


Mean accuracy across 20 runs: 0.209255
Standard deviation of accuracy across 20 runs: 0.011154440937652544

Mean total confusion across 20 runs: 0.396663665
Standard deviation of total confusion across 20 runs: 0.011422315999429186

Mean intra-phase confusion across 20 runs: 0.392806475
Standard deviation of intra-phase confusion across 20 runs: 0.011228882843006858

Mean per-task confusion across 20 runs: 0.06895191923239256
Standard deviation of per-task confusion across 20 runs: 0.004630687357285916

Mean embedding drift across 20 runs: 10.525574493408204
Standard deviation of embedding drift across 20 runs: 0.536590333801918

Mean attention drift across 20 runs: 3.5386282061989125e-06
Standard deviation of attention drift across 20 runs: 3.211259395193246e-07

Mean attention spread across 20 runs: 52.21747352882258
Standard deviation of attention spread across 20 runs: 2.8498456039367315



In [12]:
run_experiments(with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, ogd=True) # best version

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [06:10<00:00, 74.01s/it] 


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [06:33<00:00, 78.78s/it] 


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [06:23<00:00, 76.72s/it] 


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [05:45<00:00, 69.09s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [06:08<00:00, 73.77s/it] 


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [06:05<00:00, 73.20s/it] 


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [06:27<00:00, 77.55s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [06:14<00:00, 74.83s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [06:22<00:00, 76.54s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [05:28<00:00, 65.77s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 5/5 [06:11<00:00, 74.39s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 5/5 [06:11<00:00, 74.22s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 5/5 [06:35<00:00, 79.09s/it] 


Starting run 14.


Experiment Progress: 100%|██████████| 5/5 [05:39<00:00, 67.81s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 5/5 [06:12<00:00, 74.42s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 5/5 [06:28<00:00, 77.76s/it] 


Starting run 17.


Experiment Progress: 100%|██████████| 5/5 [06:29<00:00, 77.97s/it] 


Starting run 18.


Experiment Progress: 100%|██████████| 5/5 [06:33<00:00, 78.76s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 5/5 [06:43<00:00, 80.66s/it] 


Starting run 20.


Experiment Progress: 100%|██████████| 5/5 [05:52<00:00, 70.54s/it]


Mean accuracy across 20 runs: 0.871815
Standard deviation of accuracy across 20 runs: 0.01575520882956357

Mean total confusion across 20 runs: 0.316397895
Standard deviation of total confusion across 20 runs: 0.007647750371082655

Mean intra-phase confusion across 20 runs: 0.31410229
Standard deviation of intra-phase confusion across 20 runs: 0.00771135210363335

Mean per-task confusion across 20 runs: 0.05065046355999813
Standard deviation of per-task confusion across 20 runs: 0.0017025988970819426

Mean embedding drift across 20 runs: 5.121034741401672
Standard deviation of embedding drift across 20 runs: 0.297980671060721

Mean attention drift across 20 runs: 1.3957485409985412e-06
Standard deviation of attention drift across 20 runs: 8.139538808058916e-08

Mean attention spread across 20 runs: 59.9870614130834
Standard deviation of attention spread across 20 runs: 2.964589650580979



In [13]:
run_experiments(with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, ogd=False)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [05:45<00:00, 69.02s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [05:47<00:00, 69.59s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [06:15<00:00, 75.10s/it] 


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [06:12<00:00, 74.51s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [05:21<00:00, 64.22s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [05:17<00:00, 63.54s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [05:42<00:00, 68.40s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [05:51<00:00, 70.29s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [05:23<00:00, 64.60s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [05:30<00:00, 66.15s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 5/5 [05:43<00:00, 68.67s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 5/5 [05:26<00:00, 65.37s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 5/5 [05:04<00:00, 60.98s/it]


Starting run 14.


Experiment Progress: 100%|██████████| 5/5 [05:54<00:00, 70.86s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 5/5 [05:26<00:00, 65.21s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 5/5 [05:43<00:00, 68.69s/it]


Starting run 17.


Experiment Progress: 100%|██████████| 5/5 [05:21<00:00, 64.28s/it]


Starting run 18.


Experiment Progress: 100%|██████████| 5/5 [05:23<00:00, 64.66s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 5/5 [05:31<00:00, 66.36s/it]


Starting run 20.


Experiment Progress: 100%|██████████| 5/5 [05:06<00:00, 61.35s/it]


Mean accuracy across 20 runs: 0.858475
Standard deviation of accuracy across 20 runs: 0.023806497896693323

Mean total confusion across 20 runs: 0.31626039499999997
Standard deviation of total confusion across 20 runs: 0.008508850530599231

Mean intra-phase confusion across 20 runs: 0.313983085
Standard deviation of intra-phase confusion across 20 runs: 0.008481222010528212

Mean per-task confusion across 20 runs: 0.05121441125292367
Standard deviation of per-task confusion across 20 runs: 0.0026791838663664923

Mean embedding drift across 20 runs: 5.118687725067138
Standard deviation of embedding drift across 20 runs: 0.19137689914486014

Mean attention drift across 20 runs: 1.4939435838786699e-06
Standard deviation of attention drift across 20 runs: 8.791844230266099e-08

Mean attention spread across 20 runs: 58.78804669647447
Standard deviation of attention spread across 20 runs: 2.3469737513292404



In [14]:
run_experiments(with_dropout=False, kd_loss=1, full_CE=False, with_OOD=True, ogd=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:22<00:00, 40.58s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:24<00:00, 40.89s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:37<00:00, 43.58s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [03:32<00:00, 42.46s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:26<00:00, 41.29s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [03:49<00:00, 45.95s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [03:29<00:00, 41.91s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [03:23<00:00, 40.77s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [03:24<00:00, 40.91s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [03:22<00:00, 40.43s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 5/5 [03:30<00:00, 42.13s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 5/5 [03:19<00:00, 39.96s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 5/5 [03:24<00:00, 40.94s/it]


Starting run 14.


Experiment Progress: 100%|██████████| 5/5 [03:28<00:00, 41.69s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 5/5 [03:24<00:00, 40.82s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 5/5 [03:31<00:00, 42.38s/it]


Starting run 17.


Experiment Progress: 100%|██████████| 5/5 [03:31<00:00, 42.30s/it]


Starting run 18.


Experiment Progress: 100%|██████████| 5/5 [03:18<00:00, 39.80s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 5/5 [03:29<00:00, 41.91s/it]


Starting run 20.


Experiment Progress: 100%|██████████| 5/5 [03:28<00:00, 41.65s/it]


Mean accuracy across 20 runs: 0.81172
Standard deviation of accuracy across 20 runs: 0.033084828769948955

Mean total confusion across 20 runs: 0.347483375
Standard deviation of total confusion across 20 runs: 0.007654212741523459

Mean intra-phase confusion across 20 runs: 0.3439667
Standard deviation of intra-phase confusion across 20 runs: 0.007557838211454804

Mean per-task confusion across 20 runs: 0.06472576782833862
Standard deviation of per-task confusion across 20 runs: 0.003323329509736449

Mean embedding drift across 20 runs: 5.184679293632508
Standard deviation of embedding drift across 20 runs: 0.26218339714907

Mean attention drift across 20 runs: 7.776475092738062e-07
Standard deviation of attention drift across 20 runs: 1.5827700630732783e-07

Mean attention spread across 20 runs: 71.38503082507216
Standard deviation of attention spread across 20 runs: 5.877301429072464



In [15]:
run_experiments(with_dropout=True, kd_loss=1, full_CE=False, with_OOD=False, ogd=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [04:37<00:00, 55.51s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [04:22<00:00, 52.54s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [04:02<00:00, 48.54s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [04:10<00:00, 50.14s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [04:06<00:00, 49.36s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [04:28<00:00, 53.77s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [04:34<00:00, 54.97s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [04:24<00:00, 52.85s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [04:20<00:00, 52.04s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [04:27<00:00, 53.54s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 5/5 [04:32<00:00, 54.51s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 5/5 [04:10<00:00, 50.08s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 5/5 [04:51<00:00, 58.23s/it]


Starting run 14.


Experiment Progress: 100%|██████████| 5/5 [04:31<00:00, 54.37s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 5/5 [04:15<00:00, 51.06s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 5/5 [04:28<00:00, 53.68s/it]


Starting run 17.


Experiment Progress: 100%|██████████| 5/5 [04:19<00:00, 51.88s/it]


Starting run 18.


Experiment Progress: 100%|██████████| 5/5 [04:08<00:00, 49.73s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 5/5 [04:33<00:00, 54.73s/it]


Starting run 20.


Experiment Progress: 100%|██████████| 5/5 [04:34<00:00, 54.94s/it]


Mean accuracy across 20 runs: 0.76306
Standard deviation of accuracy across 20 runs: 0.022815239597574616

Mean total confusion across 20 runs: 0.38624753500000003
Standard deviation of total confusion across 20 runs: 0.007639918549667443

Mean intra-phase confusion across 20 runs: 0.384547365
Standard deviation of intra-phase confusion across 20 runs: 0.007622410869291255

Mean per-task confusion across 20 runs: 0.05615223586213179
Standard deviation of per-task confusion across 20 runs: 0.0021236330593199524

Mean embedding drift across 20 runs: 4.436177814006806
Standard deviation of embedding drift across 20 runs: 0.20393670061664768

Mean attention drift across 20 runs: 1.288934924350007e-06
Standard deviation of attention drift across 20 runs: 7.783110215203474e-08

Mean attention spread across 20 runs: 49.28415918543767
Standard deviation of attention spread across 20 runs: 1.5464254827885922



In [16]:
# training on entire dataset
globals.ITERATIONS = 1
globals.CLASSES_PER_ITER = 10
run_experiments()

Starting run 1.


Experiment Progress: 100%|██████████| 1/1 [03:34<00:00, 214.20s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 1/1 [03:31<00:00, 211.66s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 1/1 [03:31<00:00, 211.63s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 1/1 [03:28<00:00, 208.76s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 1/1 [03:24<00:00, 204.90s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 1/1 [04:30<00:00, 270.70s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 1/1 [03:24<00:00, 204.43s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 1/1 [03:23<00:00, 203.38s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 1/1 [04:00<00:00, 240.87s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 1/1 [03:21<00:00, 201.48s/it]


Starting run 11.


Experiment Progress: 100%|██████████| 1/1 [03:49<00:00, 229.31s/it]


Starting run 12.


Experiment Progress: 100%|██████████| 1/1 [03:25<00:00, 205.38s/it]


Starting run 13.


Experiment Progress: 100%|██████████| 1/1 [03:25<00:00, 205.48s/it]


Starting run 14.


Experiment Progress: 100%|██████████| 1/1 [03:24<00:00, 204.68s/it]


Starting run 15.


Experiment Progress: 100%|██████████| 1/1 [03:37<00:00, 217.50s/it]


Starting run 16.


Experiment Progress: 100%|██████████| 1/1 [03:28<00:00, 208.54s/it]


Starting run 17.


Experiment Progress: 100%|██████████| 1/1 [03:57<00:00, 237.87s/it]


Starting run 18.


Experiment Progress: 100%|██████████| 1/1 [04:03<00:00, 243.11s/it]


Starting run 19.


Experiment Progress: 100%|██████████| 1/1 [04:00<00:00, 240.77s/it]


Starting run 20.


Experiment Progress: 100%|██████████| 1/1 [03:54<00:00, 234.82s/it]


Mean accuracy across 20 runs: 0.989975
Standard deviation of accuracy across 20 runs: 0.0010577408895725271

Mean total confusion across 20 runs: 0.20250792
Standard deviation of total confusion across 20 runs: 0.009379041777932884

Mean intra-phase confusion across 20 runs: 0.0
Standard deviation of intra-phase confusion across 20 runs: 0.0

Mean per-task confusion across 20 runs: 0.20250792
Standard deviation of per-task confusion across 20 runs: 0.009379041777932884

Mean embedding drift across 20 runs: 7.819666097930168e-06
Standard deviation of embedding drift across 20 runs: 4.108424095086772e-07

Mean attention drift across 20 runs: 1.5143667320842182e-20
Standard deviation of attention drift across 20 runs: 1.267635715636015e-21

Mean attention spread across 20 runs: 52.118767505178454
Standard deviation of attention spread across 20 runs: 2.053219975656019



In [2]:
torch.cuda.empty_cache()