In [181]:
import torch
import torchvision
from torchvision import transforms, datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
torch.cuda.empty_cache()
import sys
sys.path.append('./../src')
import globals
from model import MnistCNN, TinyImageNetCNN, Cifar10CNN
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from data_utils import initialize_data
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [182]:
import submitit
import os

# Execution mode switch read by `run_experiments`: True runs every experiment
# in this kernel, False submits each run through the submitit executor below.
RUN_LOCALLY = True

# Slurm partition passed to the executor as `slurm_partition`.
# NOTE(review): left empty here - presumably has to be filled in before cluster runs; confirm.
partition = ''

# Resource request for one submitted job.
timelimit_hours = 8
num_gpus_per_node = 1
num_jobs = 1

# The notebook is assumed to be started from a sub-directory of the project
# root (TODO confirm): logs go to <parent>/submitit_logs and <parent>/src is
# appended to PYTHONPATH on the worker.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
log_path = os.path.join(parent_dir, 'submitit_logs')
run_path = os.path.join(parent_dir, 'src')


# Executor used by `run_experiments` when RUN_LOCALLY is False.
ex_parallel = submitit.AutoExecutor(folder=log_path)
ex_parallel.update_parameters(
        slurm_signal_delay_s=180,
        # one task per requested GPU
        tasks_per_node=num_gpus_per_node,
        nodes=1,
        slurm_partition=partition,
        # submitit expects the time limit in minutes
        timeout_min=int(timelimit_hours*60),
        cpus_per_task=8,
        slurm_gres=f'gpu:{num_gpus_per_node}',
        # make the project's `src` modules importable inside the job
        slurm_setup = [f'export PYTHONPATH="${{PYTHONPATH}}:{run_path}"'],
    )
SEED = globals.SEED
# Locally the device chosen by the project's `globals` module is used; for
# cluster runs the first CUDA device is hard-coded.
DEVICE = globals.DEVICE if RUN_LOCALLY else torch.device("cuda:0")
# Notebook-level aliases of the data objects held by `globals`, bound when this
# cell runs. `evalModel` and `run_experiment` read `globals.*` directly instead.
# NOTE(review): presumably these aliases go stale once `initialize_data` rebinds
# the attributes - verify before relying on them.
full_trainset = globals.full_trainset
trainset = globals.trainset
testset = globals.testset
trainloaders = globals.trainloaders
valloaders = globals.valloaders
testloaders = globals.testloaders

In [183]:
def evalModel(model, attr_device, i, only_pred = False):
    """Predict the classes of the test data of tasks ``0..i`` with ``model``.

    For every task index up to and including ``i`` the matching loader in
    ``globals.testloaders`` is read in full, moved to ``attr_device`` and
    passed through ``model.get_pred_and_embeddings`` without gradient tracking.

    Args:
        model: network exposing ``eval()``, ``train()`` and
            ``get_pred_and_embeddings``; it is switched to eval mode for each
            forward pass and back to train mode right after it.
        attr_device: device the inputs and labels are moved to.
        i: index of the last task to include.
        only_pred: when True, labels and embeddings are not collected.

    Returns:
        Tuple ``(predicted, all_labels, all_embeddings)``. ``predicted`` holds
        the argmax over the (possibly filtered) output columns. The other two
        are concatenated tensors, or empty lists when ``only_pred`` is True.
    """
    scores_per_task = []
    labels_per_task = []
    embeddings_per_task = []
    with torch.no_grad():
        for task_idx in range(i + 1):
            task_loader = globals.testloaders[task_idx]
            task_labels = get_labels(task_loader).to(attr_device)
            model.eval()
            task_inputs = get_features(task_loader).to(attr_device)
            task_scores, task_embeddings = model.get_pred_and_embeddings(task_inputs)
            model.train()
            scores_per_task.append(task_scores)
            if only_pred:
                continue
            embeddings_per_task.append(task_embeddings)
            labels_per_task.append(task_labels)

    scores = torch.cat(scores_per_task)
    if only_pred:
        # Nothing was collected: hand back the (empty) lists unchanged.
        labels, embeddings = labels_per_task, embeddings_per_task
    else:
        labels = torch.cat(labels_per_task)
        embeddings = torch.cat(embeddings_per_task)

    if globals.OOD_CLASS == 1:
        # Drop every (CLASSES_PER_ITER + 1)-th output column - presumably the
        # extra out-of-distribution output of each task - before the argmax.
        block_width = globals.CLASSES_PER_ITER + 1
        kept_columns = [col for col in range(scores.size(1)) if (col + 1) % block_width != 0]
        scores = scores[:, kept_columns]

    return torch.argmax(scores, dim=1), labels, embeddings

In [184]:
def run_experiment(
        verbose = False,
        full_CE = True,
        with_OOD = False,
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False,
        optimiser_type = 'sgd',
        dataset = 'mnist',
        joint_training = False,
        save_saliency=False,
        plotting=False,
        lr=0.003
        ):
    """Train a model task by task on ``dataset`` and measure the final model.

    The function configures the project's ``globals`` module for the chosen
    dataset, trains one model per task (each new model is initialised from the
    previous one via ``copyPrev``), evaluates the last model and - unless
    ``joint_training`` is set - retrains only the output layer on the whole
    training set to estimate the output-layer bias.

    Args:
        verbose: print progress and evaluate after every task instead of only
            after the last one; also forwarded to the training functions.
        full_CE: forwarded to ``train_model_CL``.
        with_OOD: enable the OOD setup through ``globals.toggle_OOD('jigsaw')``,
            otherwise ``globals.disable_OOD()`` is called.
        kd_loss: forwarded to ``train_model_CL``.
        stopOnValAcc: forwarded to ``train_model_CL``.
        epochs: forwarded to the training functions.
        with_dropout: stored in ``globals.WITH_DROPOUT``.
        optimiser_type: forwarded to the per-task training calls (the
            output-layer probe always uses ``'sgd'``).
        dataset: one of ``'mnist'``, ``'cifar10'``, ``'tiny_imagenet'``.
        joint_training: train on all classes in a single task.
        save_saliency: call ``Save_Random_Picture_Salency`` on the feature
            attribution helper after training.
        plotting: forwarded to the training functions.
        lr: forwarded to the per-task training calls (the output-layer probe
            does not receive it and uses ``train_model``'s own default).

    Returns:
        Tuple ``(accuracy, total_confusion, intra_phase_confusion,
        per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread,
        output_bias)``; ``output_bias`` is 0 for joint training.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.

    Side effects:
        Overwrites ``globals.CLASSES_PER_ITER``, ``globals.ITERATIONS`` and
        ``globals.WITH_DROPOUT``. After a non-joint run ``globals.ITERATIONS``
        is left at 1 and ``globals.CLASSES_PER_ITER`` at the total class count.
    """
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    # Single source of truth for the dataset-specific settings that were
    # previously spread over several if/elif chains.
    dataset_settings = {
        'tiny_imagenet': {'ogd_basis_size': None, 'patience': 6, 'total_classes': 200,
                          'classes_per_task': 40, 'attribute_on_cpu': True},
        'mnist': {'ogd_basis_size': 200, 'patience': 2, 'total_classes': 10,
                  'classes_per_task': 2, 'attribute_on_cpu': False},
        'cifar10': {'ogd_basis_size': 50, 'patience': 4, 'total_classes': 10,
                    'classes_per_task': 2, 'attribute_on_cpu': True},
    }
    # Fail early and clearly; an unknown name used to surface much later as an
    # unrelated "variable referenced before assignment" error.
    if dataset not in dataset_settings:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(dataset_settings)}"
        )
    settings = dataset_settings[dataset]
    model_cls = {
        'tiny_imagenet': TinyImageNetCNN,
        'mnist': MnistCNN,
        'cifar10': Cifar10CNN,
    }[dataset]

    ogd_basis_size = settings['ogd_basis_size']
    patience = settings['patience']
    ood_method = 'jigsaw'
    if joint_training:
        globals.CLASSES_PER_ITER = settings['total_classes']
        globals.ITERATIONS = 1
    else:
        globals.CLASSES_PER_ITER = settings['classes_per_task']
        globals.ITERATIONS = 5

    # The globals above must be set before the data is (re)built.
    if with_OOD:
        globals.toggle_OOD(ood_method)
    else:
        globals.disable_OOD()
    initialize_data(dataset)
    prevModel = None

    globals.WITH_DROPOUT = with_dropout

    #[Denis] added code:
    if settings['attribute_on_cpu']:
        attr_device = 'cpu' # CPU has high memory in the cluster
    else:
        attr_device = globals.DEVICE
    Feature_Importance_Eval=Feature_Importance_Evaluations(globals.testloaders, attr_device)

    for i in tqdm(range(globals.ITERATIONS), desc="Experiment Progress"):
        # The output width grows with every task (plus one slot per task when OOD is on).
        model = model_cls((i+1)*(globals.CLASSES_PER_ITER+globals.OOD_CLASS))
        if prevModel is not None:
            with torch.no_grad():
                model.copyPrev(prevModel)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        # Explicit None check: do not depend on the truthiness of a model object.
        if prevModel is not None:
            _print("CL TRAIN!!")
            model = train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                kd_loss=kd_loss,
                stopOnLoss=None,
                stopOnValAcc = stopOnValAcc,
                full_CE=full_CE,
                optimiser_type=optimiser_type,
                plotting=plotting,
                patience=patience,
                ogd_basis_size=ogd_basis_size,
                lr=lr
                )
        else:
            _print("TRAINING!")
            model = train_model(
                model,
                train_loader,
                val_loader,
                verbose,
                epochs=epochs,
                stopOnLoss=None,
                optimiser_type=optimiser_type,
                plotting=plotting,
                patience=patience,
                ogd_basis_size=ogd_basis_size,
                lr=lr
                )

        #[Denis] added code:
        Feature_Importance_Eval.Task_Feature_Attribution(model, i)

        # Always evaluate after the last task: the metrics computed here are
        # the function's return values.
        if verbose or i == globals.ITERATIONS-1:
            if settings['attribute_on_cpu']:
                model.to(attr_device)
            _print("Starting evaluation")
            _print("ITERATION", i+1)

            predicted, all_labels, all_embeddings = evalModel(model, attr_device, i)
            if settings['attribute_on_cpu']:
                model.to(DEVICE)
            predicted = predicted.to(attr_device)
            correct = (predicted == all_labels).sum().item()  # Count how many were correct
            accuracy = correct / all_labels.size(0)  # Accuracy as a fraction in [0, 1]
            _print("Accuracy on tasks so far:", accuracy)
            embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
            _print("Average embedding drift based on centroids:", embedding_drift)
            total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 300, task=i+1)
            _print("Total confusion", total_confusion)
            _print("Intra-phase confusion", intra_phase_confusion)
            _print("Per task confusions", per_task_confusion)
            if verbose and dataset == 'mnist':
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(globals.CLASSES_PER_ITER*(i+1))))
        prevModel = model

    #[Denis] added code:
    [avg_att_diff,att_diffs,_,_,avg_att_spread,att_spreads]=Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", att_diffs)
    _print("Averaged SHAPC value (the smaller the better):", avg_att_diff)
    _print("Average attention spread values (ordered as tasks):", att_spreads)
    _print("Averaged attention spread value (the bigger the better):", avg_att_spread)
    if save_saliency:
        # Saves saliency maps for one example per class (first row: image, second row: saliency map after training, third row: saliency map after training the task that includes the class)
        Feature_Importance_Eval.Save_Random_Picture_Salency()
    if not joint_training:
        _print("Training output layer to check for bias...")
        model.output_layer.reset_parameters()
        # Treat the whole dataset as a single task for the probe below.
        globals.ITERATIONS = 1
        globals.CLASSES_PER_ITER = settings['total_classes']
        model = train_model(
                    model,
                    DataLoader(globals.trainset, batch_size=globals.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers = 0),
                    DataLoader(globals.valset, batch_size=100, shuffle=False),
                    verbose,
                    epochs=epochs,
                    stopOnLoss=None,
                    optimiser_type='sgd',
                    plotting=plotting,
                    patience=patience,
                    ogd_basis_size=ogd_basis_size,
                    only_output_layer=True
                    )
        model.to(attr_device)
        # `i` still holds the index of the last task, so all test loaders are
        # covered; `all_labels` and `accuracy` come from the final evaluation above.
        predicted, _, _ = evalModel(model, attr_device, i, True)
        predicted = predicted.to(attr_device)
        correct = (predicted == all_labels).sum().item()  # Count how many were correct
        unbiased_output_accuracy = correct / all_labels.size(0)  # Accuracy as a fraction in [0, 1]
        output_bias = unbiased_output_accuracy-accuracy
    else:
        output_bias = 0
    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread, output_bias

In [185]:
def run_experiments(n_runs=globals.EXPERIMENT_N_RUNS, *args, **kwargs):
    """Start ``n_runs`` independent repetitions of ``run_experiment``.

    Args:
        n_runs: number of repetitions.
        *args, **kwargs: forwarded unchanged to ``run_experiment``.

    Returns:
        A list with one entry per run. With ``RUN_LOCALLY`` set, each entry is
        the result tuple of ``run_experiment`` (runs execute one after another
        in this process). Otherwise each entry is the job object returned by
        ``ex_parallel.submit``. ``report_performance`` accepts both.
    """
    jobs = []
    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        if RUN_LOCALLY:
            jobs.append(run_experiment(*args, **kwargs))
        else:  # Run on Cluster
            jobs.append(ex_parallel.submit(run_experiment, *args, **kwargs))
    return jobs

def report_performance(jobs):
    """Print per-run accuracy and the mean / standard deviation of every metric.

    Args:
        jobs: the list returned by ``run_experiments``. Each entry is either
            the 8-tuple returned by ``run_experiment`` or a job-like object
            whose ``result()`` method returns that tuple (any submitit job
            type, not only Slurm jobs).

    Returns:
        None. Everything is written to stdout. With a single run the standard
        deviation is reported as ``nan``; with no runs only a short notice is
        printed.
    """
    def report_stats(data, name):
        mean = statistics.mean(data)
        # statistics.stdev needs at least two samples; a single run has no spread to report.
        std = statistics.stdev(data) if len(data) > 1 else float('nan')
        print(f"Mean {name} across {len(jobs)} runs: {mean}")
        print(f"Standard deviation of {name} across {len(jobs)} runs: {std}\n")

    if not jobs:
        print("No runs to report.")
        return

    accuracies = []
    total_confusions = []
    intra_phase_confusions = []
    per_task_confusions = []
    att_diffs = []
    embedding_drifts = []
    att_spreads = []
    output_biases = []
    for i, job in enumerate(jobs):
        # Duck-typed: cluster jobs expose result(), local runs are already the
        # result tuple. Fetch first so "Finished" is only printed once the run
        # has actually completed.
        fetch_result = getattr(job, 'result', None)
        outcome = fetch_result() if callable(fetch_result) else job
        print(f"Finished run {i+1}.")
        (accuracy, total_confusion, intra_phase_confusion, per_task_confusion,
         embedding_drift, avg_att_diff, avg_att_spread, output_bias) = outcome
        accuracies.append(accuracy)
        total_confusions.append(total_confusion)
        intra_phase_confusions.append(intra_phase_confusion)
        per_task_confusions.append(per_task_confusion)
        att_diffs.append(avg_att_diff)
        embedding_drifts.append(embedding_drift)
        att_spreads.append(avg_att_spread)
        output_biases.append(output_bias)
        # Same 1-based numbering as the "Finished run" line above.
        print(f"Run {i+1} finished with accuracy {accuracy}")

    report_stats(accuracies, "accuracy")
    report_stats(total_confusions, "total confusion")
    report_stats(intra_phase_confusions, "intra-phase confusion")
    report_stats(per_task_confusions, "per-task confusion")
    report_stats(output_biases, "output bias")
    report_stats(embedding_drifts, "embedding drift")
    report_stats(att_diffs, "attention drift")
    report_stats(att_spreads, "attention spread")

## Experiments for MNIST:

In [None]:
ex1 = run_experiments(dataset='mnist', n_runs=5)

Starting run 2.


Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

In [64]:
report_performance(ex1)

Finished run 1.
Run 0 finished with accuracy 0.1965
Finished run 2.
Run 1 finished with accuracy 0.1975
Finished run 3.
Run 2 finished with accuracy 0.1974
Finished run 4.
Run 3 finished with accuracy 0.1966
Finished run 5.
Run 4 finished with accuracy 0.1971
Mean accuracy across 5 runs: 0.19702
Standard deviation of accuracy across 5 runs: 0.00045497252664309203

Mean total confusion across 5 runs: 0.30151626666666664
Standard deviation of total confusion across 5 runs: 0.01638281958061891

Mean intra-phase confusion across 5 runs: 0.2954242
Standard deviation of intra-phase confusion across 5 runs: 0.01568106066175937

Mean per-task confusion across 5 runs: 0.028010427293363522
Standard deviation of per-task confusion across 5 runs: 0.004846384929706976

Mean output bias across 5 runs: 0.78506
Standard deviation of output bias across 5 runs: 0.0013557285864066026

Mean embedding drift across 5 runs: 15.751250457763671
Standard deviation of embedding drift across 5 runs: 0.74161578186

In [26]:
ex2 = run_experiments(dataset='mnist', n_runs=5, full_CE=False)# baseline with special CE

In [65]:
report_performance(ex2)

Finished run 1.
Run 0 finished with accuracy 0.6791
Finished run 2.
Run 1 finished with accuracy 0.6323
Finished run 3.
Run 2 finished with accuracy 0.6481
Finished run 4.
Run 3 finished with accuracy 0.6545
Finished run 5.
Run 4 finished with accuracy 0.6707
Mean accuracy across 5 runs: 0.65694
Standard deviation of accuracy across 5 runs: 0.018517775244342945

Mean total confusion across 5 runs: 0.26517653333333335
Standard deviation of total confusion across 5 runs: 0.010413654283359608

Mean intra-phase confusion across 5 runs: 0.2631304666666667
Standard deviation of intra-phase confusion across 5 runs: 0.010266379819044758

Mean per-task confusion across 5 runs: 0.013770140119461387
Standard deviation of per-task confusion across 5 runs: 0.0011526582073460923

Mean output bias across 5 runs: 0.32206
Standard deviation of output bias across 5 runs: 0.014791991076254755

Mean embedding drift across 5 runs: 8.187272357940675
Standard deviation of embedding drift across 5 runs: 0.638

In [27]:
ex3 = run_experiments(dataset='mnist', n_runs=5, with_OOD=True) # baseline with OOD

In [66]:
report_performance(ex3)

Finished run 1.
Run 0 finished with accuracy 0.1971
Finished run 2.
Run 1 finished with accuracy 0.1973
Finished run 3.
Run 2 finished with accuracy 0.1959
Finished run 4.
Run 3 finished with accuracy 0.1973
Finished run 5.
Run 4 finished with accuracy 0.1981
Mean accuracy across 5 runs: 0.19714
Standard deviation of accuracy across 5 runs: 0.0007924645102463614

Mean total confusion across 5 runs: 0.21946179999999998
Standard deviation of total confusion across 5 runs: 0.009266260748303771

Mean intra-phase confusion across 5 runs: 0.2134478
Standard deviation of intra-phase confusion across 5 runs: 0.0077643381223242445

Mean per-task confusion across 5 runs: 0.023137949450076273
Standard deviation of per-task confusion across 5 runs: 0.002563727416715811

Mean output bias across 5 runs: 0.78676
Standard deviation of output bias across 5 runs: 0.001772850811546226

Mean embedding drift across 5 runs: 18.92416915893555
Standard deviation of embedding drift across 5 runs: 0.83469436214

In [28]:
ex4 = run_experiments(dataset='mnist', n_runs=5, with_dropout=True) # baseline with dropout

In [67]:
report_performance(ex4)

Finished run 1.
Run 0 finished with accuracy 0.1979
Finished run 2.
Run 1 finished with accuracy 0.1973
Finished run 3.
Run 2 finished with accuracy 0.1967
Finished run 4.
Run 3 finished with accuracy 0.197
Finished run 5.
Run 4 finished with accuracy 0.1974
Mean accuracy across 5 runs: 0.19726
Standard deviation of accuracy across 5 runs: 0.00045055521304274376

Mean total confusion across 5 runs: 0.2983267333333333
Standard deviation of total confusion across 5 runs: 0.026113845311762588

Mean intra-phase confusion across 5 runs: 0.2877751333333333
Standard deviation of intra-phase confusion across 5 runs: 0.023466952600767463

Mean per-task confusion across 5 runs: 0.04347928433319726
Standard deviation of per-task confusion across 5 runs: 0.010389888824512515

Mean output bias across 5 runs: 0.78876
Standard deviation of output bias across 5 runs: 0.00037815340802378457

Mean embedding drift across 5 runs: 12.544556808471679
Standard deviation of embedding drift across 5 runs: 0.27

In [29]:
ex5 = run_experiments(dataset='mnist', n_runs=5, kd_loss=1) # baseline with KD loss

In [68]:
report_performance(ex5)

Finished run 1.
Run 0 finished with accuracy 0.1974
Finished run 2.
Run 1 finished with accuracy 0.1942
Finished run 3.
Run 2 finished with accuracy 0.1973
Finished run 4.
Run 3 finished with accuracy 0.1965
Finished run 5.
Run 4 finished with accuracy 0.1956
Mean accuracy across 5 runs: 0.1962
Standard deviation of accuracy across 5 runs: 0.0013322912594474182

Mean total confusion across 5 runs: 0.2786928
Standard deviation of total confusion across 5 runs: 0.012375638247514113

Mean intra-phase confusion across 5 runs: 0.2759276
Standard deviation of intra-phase confusion across 5 runs: 0.012122452019565444

Mean per-task confusion across 5 runs: 0.018358304031778457
Standard deviation of per-task confusion across 5 runs: 0.0022142367376584097

Mean output bias across 5 runs: 0.78626
Standard deviation of output bias across 5 runs: 0.0019346834366376297

Mean embedding drift across 5 runs: 18.043492126464844
Standard deviation of embedding drift across 5 runs: 1.1325006635331765

Me

In [176]:
ex7 = run_experiments(dataset='mnist', n_runs=5, optimiser_type='adam') # baseline with Adam

In [214]:
report_performance(ex7)

Finished run 1.
Run 0 finished with accuracy 0.1967
Finished run 2.
Run 1 finished with accuracy 0.1979
Finished run 3.
Run 2 finished with accuracy 0.1963
Finished run 4.
Run 3 finished with accuracy 0.1976
Finished run 5.
Run 4 finished with accuracy 0.1976
Mean accuracy across 5 runs: 0.19722
Standard deviation of accuracy across 5 runs: 0.0006833739825307897

Mean total confusion across 5 runs: 0.26263453333333336
Standard deviation of total confusion across 5 runs: 0.022666385567619712

Mean intra-phase confusion across 5 runs: 0.2583906
Standard deviation of intra-phase confusion across 5 runs: 0.02094580774469625

Mean per-task confusion across 5 runs: 0.018907968423044322
Standard deviation of per-task confusion across 5 runs: 0.005712710005354771

Mean output bias across 5 runs: 0.78828
Standard deviation of output bias across 5 runs: 0.0023477648945326632

Mean embedding drift across 5 runs: 19.7469783782959
Standard deviation of embedding drift across 5 runs: 1.3269651164639

In [31]:
# Uses `run_experiments` (local execution is selected through RUN_LOCALLY); no `run_experiments_locally` is defined in this notebook.
ex8 = run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True) # all improvements

In [70]:
report_performance(ex8)

Finished run 1.
Run 0 finished with accuracy 0.8625
Finished run 2.
Run 1 finished with accuracy 0.8783
Finished run 3.
Run 2 finished with accuracy 0.8472
Finished run 4.
Run 3 finished with accuracy 0.8676
Finished run 5.
Run 4 finished with accuracy 0.8712
Mean accuracy across 5 runs: 0.86536
Standard deviation of accuracy across 5 runs: 0.011667604724192543

Mean total confusion across 5 runs: 0.16521786666666669
Standard deviation of total confusion across 5 runs: 0.007124483687335604

Mean intra-phase confusion across 5 runs: 0.1628992
Standard deviation of intra-phase confusion across 5 runs: 0.006968148325695195

Mean per-task confusion across 5 runs: 0.010805132892886463
Standard deviation of per-task confusion across 5 runs: 0.0006999352199073621

Mean output bias across 5 runs: 0.12109999999999999
Standard deviation of output bias across 5 runs: 0.012728511303369314

Mean embedding drift across 5 runs: 6.2946525573730465
Standard deviation of embedding drift across 5 runs: 0

In [32]:
ex9 = run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=True, with_OOD=True) # all improvements without special CE

In [71]:
report_performance(ex9)

Finished run 1.
Run 0 finished with accuracy 0.2068
Finished run 2.
Run 1 finished with accuracy 0.2016
Finished run 3.
Run 2 finished with accuracy 0.1987
Finished run 4.
Run 3 finished with accuracy 0.198
Finished run 5.
Run 4 finished with accuracy 0.1991
Mean accuracy across 5 runs: 0.20084
Standard deviation of accuracy across 5 runs: 0.0035976381140965298

Mean total confusion across 5 runs: 0.1945109333333333
Standard deviation of total confusion across 5 runs: 0.012306290075043356

Mean intra-phase confusion across 5 runs: 0.19035033333333334
Standard deviation of intra-phase confusion across 5 runs: 0.011988576310156464

Mean per-task confusion across 5 runs: 0.01745498423588679
Standard deviation of per-task confusion across 5 runs: 0.0023845686078281946

Mean output bias across 5 runs: 0.78404
Standard deviation of output bias across 5 runs: 0.0026922109872741895

Mean embedding drift across 5 runs: 15.807488250732423
Standard deviation of embedding drift across 5 runs: 0.92

In [33]:
ex10 = run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=False) # all improvements without OOD

In [72]:
report_performance(ex10)

Finished run 1.
Run 0 finished with accuracy 0.7781
Finished run 2.
Run 1 finished with accuracy 0.718
Finished run 3.
Run 2 finished with accuracy 0.7602
Finished run 4.
Run 3 finished with accuracy 0.7721
Finished run 5.
Run 4 finished with accuracy 0.7554
Mean accuracy across 5 runs: 0.75676
Standard deviation of accuracy across 5 runs: 0.023487933072111746

Mean total confusion across 5 runs: 0.225098
Standard deviation of total confusion across 5 runs: 0.007906878380386412

Mean intra-phase confusion across 5 runs: 0.2237734666666667
Standard deviation of intra-phase confusion across 5 runs: 0.00776819161567078

Mean per-task confusion across 5 runs: 0.010401042402823463
Standard deviation of per-task confusion across 5 runs: 0.0010261922031812266

Mean output bias across 5 runs: 0.22828
Standard deviation of output bias across 5 runs: 0.02330272516252126

Mean embedding drift across 5 runs: 5.013447093963623
Standard deviation of embedding drift across 5 runs: 0.10689131038691713

In [34]:
ex11 = run_experiments(dataset='mnist', n_runs=5, with_dropout=False, kd_loss=1, full_CE=False, with_OOD=True) # all improvements without dropouts

In [73]:
report_performance(ex11)

Finished run 1.
Run 0 finished with accuracy 0.8262
Finished run 2.
Run 1 finished with accuracy 0.7848
Finished run 3.
Run 2 finished with accuracy 0.7836
Finished run 4.
Run 3 finished with accuracy 0.825
Finished run 5.
Run 4 finished with accuracy 0.7412
Mean accuracy across 5 runs: 0.79216
Standard deviation of accuracy across 5 runs: 0.03521914252221369

Mean total confusion across 5 runs: 0.20879066666666665
Standard deviation of total confusion across 5 runs: 0.007163731619918676

Mean intra-phase confusion across 5 runs: 0.20517813333333332
Standard deviation of intra-phase confusion across 5 runs: 0.006934348926419376

Mean per-task confusion across 5 runs: 0.01508101297714759
Standard deviation of per-task confusion across 5 runs: 0.0012527355734396797

Mean output bias across 5 runs: 0.18968000000000002
Standard deviation of output bias across 5 runs: 0.03315692989406592

Mean embedding drift across 5 runs: 8.123602676391602
Standard deviation of embedding drift across 5 ru

In [35]:
ex12 = run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=0, full_CE=False, with_OOD=True) # all improvements without kd loss

In [74]:
report_performance(ex12)

Finished run 1.
Run 0 finished with accuracy 0.6349
Finished run 2.
Run 1 finished with accuracy 0.7466
Finished run 3.
Run 2 finished with accuracy 0.7567
Finished run 4.
Run 3 finished with accuracy 0.6567
Finished run 5.
Run 4 finished with accuracy 0.6589
Mean accuracy across 5 runs: 0.69076
Standard deviation of accuracy across 5 runs: 0.056483785283920215

Mean total confusion across 5 runs: 0.1931168
Standard deviation of total confusion across 5 runs: 0.01320736273531631

Mean intra-phase confusion across 5 runs: 0.19061653333333334
Standard deviation of intra-phase confusion across 5 runs: 0.013116010363928017

Mean per-task confusion across 5 runs: 0.013589046252764705
Standard deviation of per-task confusion across 5 runs: 0.0013523568370232314

Mean output bias across 5 runs: 0.29632
Standard deviation of output bias across 5 runs: 0.05571572488983698

Mean embedding drift across 5 runs: 6.830794620513916
Standard deviation of embedding drift across 5 runs: 0.28650523230271

In [200]:
ex14 = run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='adam') # all improvements with Adam

In [244]:
report_performance(ex14)

Finished run 1.
Run 0 finished with accuracy 0.7221
Finished run 2.
Run 1 finished with accuracy 0.7089
Finished run 3.
Run 2 finished with accuracy 0.7229
Finished run 4.
Run 3 finished with accuracy 0.6543
Finished run 5.
Run 4 finished with accuracy 0.6454
Mean accuracy across 5 runs: 0.69072
Standard deviation of accuracy across 5 runs: 0.03785197484940515

Mean total confusion across 5 runs: 0.16659393333333333
Standard deviation of total confusion across 5 runs: 0.008450156535039282

Mean intra-phase confusion across 5 runs: 0.1653022
Standard deviation of intra-phase confusion across 5 runs: 0.008425832272112783

Mean per-task confusion across 5 runs: 0.00689000241771053
Standard deviation of per-task confusion across 5 runs: 0.0008540630579552725

Mean output bias across 5 runs: 0.29536
Standard deviation of output bias across 5 runs: 0.03750050666324391

Mean embedding drift across 5 runs: 9.341233158111573
Standard deviation of embedding drift across 5 runs: 1.336539008371681

In [37]:
ex15 = run_experiments(dataset='mnist', n_runs=5, with_dropout=True, joint_training=True) # joint training

In [76]:
report_performance(ex15)

Finished run 1.
Run 0 finished with accuracy 0.9905
Finished run 2.
Run 1 finished with accuracy 0.9915
Finished run 3.
Run 2 finished with accuracy 0.9913
Finished run 4.
Run 3 finished with accuracy 0.9906
Finished run 5.
Run 4 finished with accuracy 0.9913
Mean accuracy across 5 runs: 0.99104
Standard deviation of accuracy across 5 runs: 0.00045607017003963297

Mean total confusion across 5 runs: 0.049904933333333325
Standard deviation of total confusion across 5 runs: 0.002843215036616902

Mean intra-phase confusion across 5 runs: 0.0
Standard deviation of intra-phase confusion across 5 runs: 0.0

Mean per-task confusion across 5 runs: 0.049904933333333325
Standard deviation of per-task confusion across 5 runs: 0.002843215036616902

Mean output bias across 5 runs: 0
Standard deviation of output bias across 5 runs: 0.0

Mean embedding drift across 5 runs: 6.624188972637057e-06
Standard deviation of embedding drift across 5 runs: 2.3252230501740707e-07

Mean attention drift across 5 

## Experiments for CIFAR10:

In [13]:
ex_cf_1 = run_experiments(dataset='cifar10', n_runs=3) # baseline

Starting run 1.
Starting run 2.
Starting run 3.


In [23]:
report_performance(ex_cf_1)

Finished run 1.
Run 0 finished with accuracy 0.1879
Finished run 2.
Run 1 finished with accuracy 0.1921
Finished run 3.
Run 2 finished with accuracy 0.1862
Mean accuracy across 3 runs: 0.18873333333333334
Standard deviation of accuracy across 3 runs: 0.0030369941279714866

Mean total confusion across 3 runs: 0.8150898888888889
Standard deviation of total confusion across 3 runs: 0.00590723881101571

Mean intra-phase confusion across 3 runs: 0.735893
Standard deviation of intra-phase confusion across 3 runs: 0.006888937323306408

Mean per-task confusion across 3 runs: 0.35517077777777784
Standard deviation of per-task confusion across 3 runs: 0.013072628598773894

Mean output bias across 3 runs: 0.13786666666666667
Standard deviation of output bias across 3 runs: 0.027837085575421364

Mean embedding drift across 3 runs: 5.611140727996826
Standard deviation of embedding drift across 3 runs: 0.42372610870980176

Mean attention drift across 3 runs: 3.426761663301175e-07
Standard deviation 

In [39]:
ex_cf_2 = run_experiments(dataset='cifar10', n_runs=3, full_CE=False) # baseline with special CE

Starting run 1.
Starting run 2.
Starting run 3.


In [78]:
report_performance(ex_cf_2)

Finished run 1.
Run 0 finished with accuracy 0.1581
Finished run 2.
Run 1 finished with accuracy 0.0577
Finished run 3.
Run 2 finished with accuracy 0.1615
Mean accuracy across 3 runs: 0.12576666666666667
Standard deviation of accuracy across 3 runs: 0.058971970743170295

Mean total confusion across 3 runs: 0.8136176666666667
Standard deviation of total confusion across 3 runs: 0.0011870364124340572

Mean intra-phase confusion across 3 runs: 0.7345164444444444
Standard deviation of intra-phase confusion across 3 runs: 0.005755544827317365

Mean per-task confusion across 3 runs: 0.348965
Standard deviation of per-task confusion across 3 runs: 0.007998210862436688

Mean output bias across 3 runs: 0.25053333333333333
Standard deviation of output bias across 3 runs: 0.07827888178387153

Mean embedding drift across 3 runs: 3.3491338888804116
Standard deviation of embedding drift across 3 runs: 0.4337011301139496

Mean attention drift across 3 runs: 3.1004676443776226e-07
Standard deviation 

In [186]:
ex_cf_3 = run_experiments(dataset='cifar10', n_runs=3, with_OOD=True) # baseline with OOD

Starting run 1.
Starting run 2.
Starting run 3.


In [203]:
report_performance(ex_cf_3)

Finished run 1.
Run 0 finished with accuracy 0.1896
Finished run 2.
Run 1 finished with accuracy 0.1903
Finished run 3.
Run 2 finished with accuracy 0.189
Mean accuracy across 3 runs: 0.18963333333333332
Standard deviation of accuracy across 3 runs: 0.0006506407098647691

Mean total confusion across 3 runs: 0.7954076666666667
Standard deviation of total confusion across 3 runs: 0.012795435058384411

Mean intra-phase confusion across 3 runs: 0.7152497777777779
Standard deviation of intra-phase confusion across 3 runs: 0.010628348833470129

Mean per-task confusion across 3 runs: 0.3371353333333333
Standard deviation of per-task confusion across 3 runs: 0.009632010076360584

Mean output bias across 3 runs: 0.21576666666666666
Standard deviation of output bias across 3 runs: 0.028889675202974047

Mean embedding drift across 3 runs: 7.078327337900798
Standard deviation of embedding drift across 3 runs: 0.12330433918255805

Mean attention drift across 3 runs: 2.899680300302521e-07
Standard d

In [41]:
ex_cf_4 = run_experiments(dataset='cifar10', n_runs=3, with_dropout=True) # baseline with dropout

Starting run 1.
Starting run 2.
Starting run 3.


In [80]:
report_performance(ex_cf_4)

Finished run 1.
Run 0 finished with accuracy 0.1903
Finished run 2.
Run 1 finished with accuracy 0.1898
Finished run 3.
Run 2 finished with accuracy 0.1926
Mean accuracy across 3 runs: 0.1909
Standard deviation of accuracy across 3 runs: 0.001493318452306806

Mean total confusion across 3 runs: 0.7926706666666666
Standard deviation of total confusion across 3 runs: 0.013702840609807063

Mean intra-phase confusion across 3 runs: 0.7151139999999999
Standard deviation of intra-phase confusion across 3 runs: 0.011003860246901203

Mean per-task confusion across 3 runs: 0.3405243333333334
Standard deviation of per-task confusion across 3 runs: 0.013260845779126543

Mean output bias across 3 runs: 0.19856666666666667
Standard deviation of output bias across 3 runs: 0.04019966832367319

Mean embedding drift across 3 runs: 4.00794514020284
Standard deviation of embedding drift across 3 runs: 0.07444926386411467

Mean attention drift across 3 runs: 3.5139745852272137e-07
Standard deviation of at

In [42]:
# CIFAR-10: baseline with KD loss (3 runs).
ex_cf_5 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    kd_loss=1,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [81]:
# Report ex_cf_5 (CIFAR-10 baseline with KD loss): per-run accuracy, then the
# across-run mean and standard deviation of each metric.
report_performance(ex_cf_5)

Finished run 1.
Run 0 finished with accuracy 0.1909
Finished run 2.
Run 1 finished with accuracy 0.1909
Finished run 3.
Run 2 finished with accuracy 0.1914
Mean accuracy across 3 runs: 0.19106666666666666
Standard deviation of accuracy across 3 runs: 0.00028867513459481317

Mean total confusion across 3 runs: 0.7290466666666666
Standard deviation of total confusion across 3 runs: 0.013700385773157414

Mean intra-phase confusion across 3 runs: 0.6753583333333333
Standard deviation of intra-phase confusion across 3 runs: 0.013727380198388574

Mean per-task confusion across 3 runs: 0.23976144444444442
Standard deviation of per-task confusion across 3 runs: 0.016109701191564074

Mean output bias across 3 runs: 0.31960000000000005
Standard deviation of output bias across 3 runs: 0.023900627606822412

Mean embedding drift across 3 runs: 7.149629592895508
Standard deviation of embedding drift across 3 runs: 0.2568053278392342

Mean attention drift across 3 runs: 2.231473689681945e-07
Standard

In [177]:
# CIFAR-10: baseline trained with Adam at lr=0.001 (3 runs).
ex_cf_7 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    optimiser_type='adam',
    lr=0.001,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [180]:
# Report ex_cf_7 (CIFAR-10 baseline with Adam): per-run accuracy, then the
# across-run mean and standard deviation of each metric.
report_performance(ex_cf_7)

Finished run 1.
Run 0 finished with accuracy 0.1884
Finished run 2.
Run 1 finished with accuracy 0.1884
Finished run 3.
Run 2 finished with accuracy 0.186
Mean accuracy across 3 runs: 0.18760000000000002
Standard deviation of accuracy across 3 runs: 0.0013856406460551094

Mean total confusion across 3 runs: 0.8075774444444445
Standard deviation of total confusion across 3 runs: 0.006486076457675674

Mean intra-phase confusion across 3 runs: 0.7307914444444444
Standard deviation of intra-phase confusion across 3 runs: 0.0028884119499463365

Mean per-task confusion across 3 runs: 0.3462801111111111
Standard deviation of per-task confusion across 3 runs: 0.011423664523617113

Mean output bias across 3 runs: 0.1556
Standard deviation of output bias across 3 runs: 0.04400136361523356

Mean embedding drift across 3 runs: 26.946285247802734
Standard deviation of embedding drift across 3 runs: 10.425892047866796

Mean attention drift across 3 runs: 3.904787418643233e-07
Standard deviation of a

In [187]:
# CIFAR-10: all improvements combined -- dropout, KD loss, special CE
# (full_CE=False) and OOD -- trained with SGD (3 runs).
ex_cf_8 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    with_OOD=True,
    optimiser_type='sgd',
)

Starting run 1.
Starting run 2.
Starting run 3.


In [208]:
# Report ex_cf_8 (CIFAR-10, all improvements): per-run accuracy, then the
# across-run mean and standard deviation of each metric.
report_performance(ex_cf_8)

Finished run 1.
Run 0 finished with accuracy 0.4094
Finished run 2.
Run 1 finished with accuracy 0.3506
Finished run 3.
Run 2 finished with accuracy 0.411
Mean accuracy across 3 runs: 0.3903333333333333
Standard deviation of accuracy across 3 runs: 0.034419374389046235

Mean total confusion across 3 runs: 0.6786732222222222
Standard deviation of total confusion across 3 runs: 0.008851295387514566

Mean intra-phase confusion across 3 runs: 0.6425097777777777
Standard deviation of intra-phase confusion across 3 runs: 0.009202880408590311

Mean per-task confusion across 3 runs: 0.16869922222222222
Standard deviation of per-task confusion across 3 runs: 0.0013186683659128485

Mean output bias across 3 runs: 0.185
Standard deviation of output bias across 3 runs: 0.027904838290160373

Mean embedding drift across 3 runs: 2.7267093658447266
Standard deviation of embedding drift across 3 runs: 0.18285992033846996

Mean attention drift across 3 runs: 1.351646285622602e-07
Standard deviation of a

In [188]:
# CIFAR-10 ablation: all improvements except the special CE
# (full_CE=True), trained with SGD (3 runs).
ex_cf_9 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=True,
    with_OOD=True,
    optimiser_type='sgd',
)

Starting run 1.
Starting run 2.
Starting run 3.


In [205]:
# Report ex_cf_9 (CIFAR-10, all improvements without special CE): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_9)

Finished run 1.
Run 0 finished with accuracy 0.1924
Finished run 2.
Run 1 finished with accuracy 0.1923
Finished run 3.
Run 2 finished with accuracy 0.1932
Mean accuracy across 3 runs: 0.19263333333333332
Standard deviation of accuracy across 3 runs: 0.0004932882862316341

Mean total confusion across 3 runs: 0.6966635555555556
Standard deviation of total confusion across 3 runs: 0.009691097554705303

Mean intra-phase confusion across 3 runs: 0.6490783333333333
Standard deviation of intra-phase confusion across 3 runs: 0.007695207195246763

Mean per-task confusion across 3 runs: 0.21338422222222225
Standard deviation of per-task confusion across 3 runs: 0.02077878608456589

Mean output bias across 3 runs: 0.36346666666666666
Standard deviation of output bias across 3 runs: 0.019119710597530808

Mean embedding drift across 3 runs: 6.24628225962321
Standard deviation of embedding drift across 3 runs: 0.6229272377467012

Mean attention drift across 3 runs: 2.484807036566394e-07
Standard de

In [46]:
# CIFAR-10 ablation: all improvements except OOD (with_OOD=False),
# trained with SGD (3 runs).
ex_cf_10 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    with_OOD=False,
    optimiser_type='sgd',
)

Starting run 1.
Starting run 2.
Starting run 3.


In [91]:
# Report ex_cf_10 (CIFAR-10, all improvements without OOD): per-run accuracy,
# then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_10)

Finished run 1.
Run 0 finished with accuracy 0.3335
Finished run 2.
Run 1 finished with accuracy 0.3889
Finished run 3.
Run 2 finished with accuracy 0.3643
Mean accuracy across 3 runs: 0.36223333333333335
Standard deviation of accuracy across 3 runs: 0.027757761677291874

Mean total confusion across 3 runs: 0.6878068888888889
Standard deviation of total confusion across 3 runs: 0.0208423163895351

Mean intra-phase confusion across 3 runs: 0.6543283333333333
Standard deviation of intra-phase confusion across 3 runs: 0.018454712942166742

Mean per-task confusion across 3 runs: 0.16579566666666667
Standard deviation of per-task confusion across 3 runs: 0.009867631535029639

Mean output bias across 3 runs: 0.2186333333333333
Standard deviation of output bias across 3 runs: 0.007200231477760552

Mean embedding drift across 3 runs: 2.03570028146108
Standard deviation of embedding drift across 3 runs: 0.2500934602964982

Mean attention drift across 3 runs: 1.2656743829744255e-07
Standard devi

In [189]:
# CIFAR-10 ablation: all improvements except dropout (with_dropout=False),
# trained with SGD (3 runs).
ex_cf_11 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=False,
    kd_loss=1,
    full_CE=False,
    with_OOD=True,
    optimiser_type='sgd',
)

Starting run 1.
Starting run 2.
Starting run 3.


In [209]:
# Report ex_cf_11 (CIFAR-10, all improvements without dropout): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_11)

Finished run 1.
Run 0 finished with accuracy 0.4584
Finished run 2.
Run 1 finished with accuracy 0.4085
Finished run 3.
Run 2 finished with accuracy 0.4875
Mean accuracy across 3 runs: 0.4514666666666666
Standard deviation of accuracy across 3 runs: 0.03995376494566356

Mean total confusion across 3 runs: 0.6546418888888889
Standard deviation of total confusion across 3 runs: 0.005383163437291571

Mean intra-phase confusion across 3 runs: 0.6179250000000001
Standard deviation of intra-phase confusion across 3 runs: 0.0024396887779660195

Mean per-task confusion across 3 runs: 0.1690621111111111
Standard deviation of per-task confusion across 3 runs: 0.01101647221882524

Mean output bias across 3 runs: 0.1623666666666667
Standard deviation of output bias across 3 runs: 0.03730017873058159

Mean embedding drift across 3 runs: 4.669615427652995
Standard deviation of embedding drift across 3 runs: 0.2838798655521294

Mean attention drift across 3 runs: 6.844182714947541e-08
Standard deviat

In [190]:
# CIFAR-10 ablation: all improvements except the KD loss (kd_loss=0),
# trained with SGD (3 runs).
ex_cf_12 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=True,
    kd_loss=0,
    full_CE=False,
    with_OOD=True,
    optimiser_type='sgd',
)

Starting run 1.
Starting run 2.
Starting run 3.


In [219]:
# Report ex_cf_12 (CIFAR-10, all improvements without KD loss): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_12)

Finished run 1.
Run 0 finished with accuracy 0.1616
Finished run 2.
Run 1 finished with accuracy 0.0545
Finished run 3.
Run 2 finished with accuracy 0.0549
Mean accuracy across 3 runs: 0.09033333333333333
Standard deviation of accuracy across 3 runs: 0.06171906782618588

Mean total confusion across 3 runs: 0.8054283333333333
Standard deviation of total confusion across 3 runs: 0.004481763802095938

Mean intra-phase confusion across 3 runs: 0.7300105555555556
Standard deviation of intra-phase confusion across 3 runs: 0.004231749839452171

Mean per-task confusion across 3 runs: 0.3399215555555556
Standard deviation of per-task confusion across 3 runs: 0.01223262827365733

Mean output bias across 3 runs: 0.2692333333333333
Standard deviation of output bias across 3 runs: 0.09019120430138038

Mean embedding drift across 3 runs: 3.1476405461629233
Standard deviation of embedding drift across 3 runs: 0.3724409874625537

Mean attention drift across 3 runs: 3.0758886688389e-07
Standard deviati

In [199]:
# CIFAR-10: all improvements, but trained with Adam at lr=0.001
# instead of SGD (3 runs).
ex_cf_14 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    with_OOD=True,
    optimiser_type='adam',
    lr=0.001,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [233]:
# Report ex_cf_14 (CIFAR-10, all improvements with Adam): per-run accuracy,
# then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_14)

Finished run 1.
Run 0 finished with accuracy 0.3494
Finished run 2.
Run 1 finished with accuracy 0.3875
Finished run 3.
Run 2 finished with accuracy 0.3646
Mean accuracy across 3 runs: 0.36716666666666664
Standard deviation of accuracy across 3 runs: 0.0191792422512813

Mean total confusion across 3 runs: 0.7250651111111112
Standard deviation of total confusion across 3 runs: 0.007741162712935783

Mean intra-phase confusion across 3 runs: 0.6788436666666666
Standard deviation of intra-phase confusion across 3 runs: 0.007391588582826914

Mean per-task confusion across 3 runs: 0.21498788888888887
Standard deviation of per-task confusion across 3 runs: 0.005980634120348888

Mean output bias across 3 runs: 0.1579
Standard deviation of output bias across 3 runs: 0.018730456481356787

Mean embedding drift across 3 runs: 10.01424757639567
Standard deviation of embedding drift across 3 runs: 0.9237363654591716

Mean attention drift across 3 runs: 3.1909643724456296e-07
Standard deviation of at

In [117]:
# CIFAR-10: joint training with dropout (3 runs).
ex_cf_15 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=True,
    joint_training=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [145]:
# Report ex_cf_15 (CIFAR-10 joint training, with dropout): per-run accuracy,
# then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_15)

Finished run 1.
Run 0 finished with accuracy 0.8198
Finished run 2.
Run 1 finished with accuracy 0.8164
Finished run 3.
Run 2 finished with accuracy 0.8237
Mean accuracy across 3 runs: 0.8199666666666666
Standard deviation of accuracy across 3 runs: 0.003652852766446143

Mean total confusion across 3 runs: 0.29480211111111115
Standard deviation of total confusion across 3 runs: 0.004794796141342918

Mean intra-phase confusion across 3 runs: 0.0
Standard deviation of intra-phase confusion across 3 runs: 0.0

Mean per-task confusion across 3 runs: 0.29480211111111115
Standard deviation of per-task confusion across 3 runs: 0.004794796141342918

Mean output bias across 3 runs: 0
Standard deviation of output bias across 3 runs: 0.0

Mean embedding drift across 3 runs: 2.9902341793786036e-06
Standard deviation of embedding drift across 3 runs: 9.015204393919827e-08

Mean attention drift across 3 runs: 0.0
Standard deviation of attention drift across 3 runs: 0.0

Mean attention spread across 

In [247]:
# CIFAR-10: joint training without dropout (3 runs).
ex_cf_15_2 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    with_dropout=False,
    joint_training=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [249]:
# Report ex_cf_15_2 (CIFAR-10 joint training, without dropout): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_15_2)

Finished run 1.
Run 0 finished with accuracy 0.7772
Finished run 2.
Run 1 finished with accuracy 0.7882
Finished run 3.
Run 2 finished with accuracy 0.7725
Mean accuracy across 3 runs: 0.7793
Standard deviation of accuracy across 3 runs: 0.008057915363169332

Mean total confusion across 3 runs: 0.4202355555555556
Standard deviation of total confusion across 3 runs: 0.02317752783608707

Mean intra-phase confusion across 3 runs: 0.0
Standard deviation of intra-phase confusion across 3 runs: 0.0

Mean per-task confusion across 3 runs: 0.4202355555555556
Standard deviation of per-task confusion across 3 runs: 0.02317752783608707

Mean output bias across 3 runs: 0
Standard deviation of output bias across 3 runs: 0.0

Mean embedding drift across 3 runs: 5.864368328426887e-06
Standard deviation of embedding drift across 3 runs: 6.625277733125135e-07

Mean attention drift across 3 runs: 0.0
Standard deviation of attention drift across 3 runs: 0.0

Mean attention spread across 3 runs: 134.20428

In [51]:
# CIFAR-10 special case: KD loss together with the special CE (full_CE=False).
# This run will NOT be added to the table in the report.
# It is used to analyse why the special CE doesn't work on CIFAR-10: the
# problem is domain shift, and this run shows that the KD loss stabilises it.
ex_cf_16 = run_experiments(
    dataset='cifar10',
    n_runs=3,
    kd_loss=1,
    full_CE=False,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [100]:
# Report ex_cf_16 (CIFAR-10 special case: KD loss + special CE): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(ex_cf_16)

Finished run 1.
Run 0 finished with accuracy 0.4343
Finished run 2.
Run 1 finished with accuracy 0.4133
Finished run 3.
Run 2 finished with accuracy 0.3978
Mean accuracy across 3 runs: 0.41513333333333335
Standard deviation of accuracy across 3 runs: 0.0183189337389853

Mean total confusion across 3 runs: 0.6591074444444444
Standard deviation of total confusion across 3 runs: 0.0010595532252024884

Mean intra-phase confusion across 3 runs: 0.6233385555555556
Standard deviation of intra-phase confusion across 3 runs: 0.002438204287615822

Mean per-task confusion across 3 runs: 0.16436333333333333
Standard deviation of per-task confusion across 3 runs: 0.006785838497275858

Mean output bias across 3 runs: 0.18649999999999997
Standard deviation of output bias across 3 runs: 0.024874283909290733

Mean embedding drift across 3 runs: 4.089497168858846
Standard deviation of embedding drift across 3 runs: 0.5539006621653072

Mean attention drift across 3 runs: 8.705355532144439e-08
Standard de

## Experiments for Tiny ImageNet:

In [15]:
# Tiny ImageNet: baseline (3 runs).
in_ex_1 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [101]:
# Report in_ex_1 (Tiny ImageNet baseline): per-run accuracy, then the
# across-run mean and standard deviation of each metric.
report_performance(in_ex_1)

Finished run 1.
Run 0 finished with accuracy 0.0634
Finished run 2.
Run 1 finished with accuracy 0.0628
Finished run 3.
Run 2 finished with accuracy 0.0594
Mean accuracy across 3 runs: 0.06186666666666667
Standard deviation of accuracy across 3 runs: 0.0021571586249817887

Mean total confusion across 3 runs: 0.9915338888888888
Standard deviation of total confusion across 3 runs: 7.065513532751243e-05

Mean intra-phase confusion across 3 runs: 0.7913187777777778
Standard deviation of intra-phase confusion across 3 runs: 0.00035209694582493447

Mean per-task confusion across 3 runs: 0.9691908888888889
Standard deviation of per-task confusion across 3 runs: 0.00022921056145273696

Mean output bias across 3 runs: 0.050666666666666665
Standard deviation of output bias across 3 runs: 0.005550075074567313

Mean embedding drift across 3 runs: 31.5178165435791
Standard deviation of embedding drift across 3 runs: 0.8299934470172012

Mean attention drift across 3 runs: 2.513306655880972e-08
Stand

In [52]:
# Tiny ImageNet: baseline with the special CE (full_CE=False) (3 runs).
in_ex_2 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    full_CE=False,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [102]:
# Report in_ex_2 (Tiny ImageNet baseline with special CE): per-run accuracy,
# then the across-run mean and standard deviation of each metric.
report_performance(in_ex_2)

Finished run 1.
Run 0 finished with accuracy 0.0907
Finished run 2.
Run 1 finished with accuracy 0.0767
Finished run 3.
Run 2 finished with accuracy 0.0877
Mean accuracy across 3 runs: 0.08503333333333334
Standard deviation of accuracy across 3 runs: 0.007371114795831992

Mean total confusion across 3 runs: 0.9919312222222222
Standard deviation of total confusion across 3 runs: 8.691588100206384e-05

Mean intra-phase confusion across 3 runs: 0.7929924444444444
Standard deviation of intra-phase confusion across 3 runs: 0.0003788705925852719

Mean per-task confusion across 3 runs: 0.9698965555555555
Standard deviation of per-task confusion across 3 runs: 3.398420112231607e-05

Mean output bias across 3 runs: 0.0315
Standard deviation of output bias across 3 runs: 0.009183136719008379

Mean embedding drift across 3 runs: 28.166884740193684
Standard deviation of embedding drift across 3 runs: 0.8096294696119921

Mean attention drift across 3 runs: 2.139450292889163e-08
Standard deviation o

In [191]:
# Tiny ImageNet: baseline with OOD (3 runs).
in_ex_3 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_OOD=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [223]:
# Report in_ex_3 (Tiny ImageNet baseline with OOD): per-run accuracy, then
# the across-run mean and standard deviation of each metric.
report_performance(in_ex_3)

Finished run 1.
Run 0 finished with accuracy 0.0618
Finished run 2.
Run 1 finished with accuracy 0.065
Finished run 3.
Run 2 finished with accuracy 0.0681
Mean accuracy across 3 runs: 0.06496666666666667
Standard deviation of accuracy across 3 runs: 0.0031501322723551327

Mean total confusion across 3 runs: 0.9915974444444444
Standard deviation of total confusion across 3 runs: 0.00020883495166529553

Mean intra-phase confusion across 3 runs: 0.7915796666666667
Standard deviation of intra-phase confusion across 3 runs: 0.0007766400567687742

Mean per-task confusion across 3 runs: 0.9691761111111111
Standard deviation of per-task confusion across 3 runs: 0.00029397474066351605

Mean output bias across 3 runs: 0.048600000000000004
Standard deviation of output bias across 3 runs: 0.0019974984355438197

Mean embedding drift across 3 runs: 30.93061065673828
Standard deviation of embedding drift across 3 runs: 1.1639619282451465

Mean attention drift across 3 runs: 2.6318439145178726e-08
Sta

In [54]:
# Tiny ImageNet: baseline with dropout (3 runs).
in_ex_4 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [104]:
# Report in_ex_4 (Tiny ImageNet baseline with dropout): per-run accuracy,
# then the across-run mean and standard deviation of each metric.
report_performance(in_ex_4)

Finished run 1.
Run 0 finished with accuracy 0.0931
Finished run 2.
Run 1 finished with accuracy 0.0926
Finished run 3.
Run 2 finished with accuracy 0.0909
Mean accuracy across 3 runs: 0.0922
Standard deviation of accuracy across 3 runs: 0.0011532562594670837

Mean total confusion across 3 runs: 0.983536
Standard deviation of total confusion across 3 runs: 0.0005849770175922516

Mean intra-phase confusion across 3 runs: 0.7669123333333333
Standard deviation of intra-phase confusion across 3 runs: 0.002596272991124836

Mean per-task confusion across 3 runs: 0.9555955555555555
Standard deviation of per-task confusion across 3 runs: 0.0005886268704303028

Mean output bias across 3 runs: 0.14406666666666665
Standard deviation of output bias across 3 runs: 0.0028041635710730774

Mean embedding drift across 3 runs: 9.662531534830729
Standard deviation of embedding drift across 3 runs: 0.38065871027636633

Mean attention drift across 3 runs: 2.2121093413722073e-08
Standard deviation of attent

In [55]:
# Tiny ImageNet: baseline with KD loss (3 runs).
in_ex_5 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    kd_loss=1,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [105]:
# Report in_ex_5 (Tiny ImageNet baseline with KD loss): per-run accuracy,
# then the across-run mean and standard deviation of each metric.
report_performance(in_ex_5)

Finished run 1.
Run 0 finished with accuracy 0.073
Finished run 2.
Run 1 finished with accuracy 0.0721
Finished run 3.
Run 2 finished with accuracy 0.0748
Mean accuracy across 3 runs: 0.0733
Standard deviation of accuracy across 3 runs: 0.0013747727084867567

Mean total confusion across 3 runs: 0.9870364444444444
Standard deviation of total confusion across 3 runs: 0.0004383674680414145

Mean intra-phase confusion across 3 runs: 0.7816603333333333
Standard deviation of intra-phase confusion across 3 runs: 0.0010330113476842643

Mean per-task confusion across 3 runs: 0.960741888888889
Standard deviation of per-task confusion across 3 runs: 0.0005856598684233329

Mean output bias across 3 runs: 0.09
Standard deviation of output bias across 3 runs: 0.0071526218968990625

Mean embedding drift across 3 runs: 18.61656316121419
Standard deviation of embedding drift across 3 runs: 0.5844481796040305

Mean attention drift across 3 runs: 1.4041124999501313e-08
Standard deviation of attention dri

In [178]:
# Tiny ImageNet: baseline trained with Adam at lr=1e-5 (3 runs).
in_ex_7 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    optimiser_type='adam',
    lr=1e-5,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [193]:
# Report in_ex_7 (Tiny ImageNet baseline with Adam): per-run accuracy, then
# the across-run mean and standard deviation of each metric.
report_performance(in_ex_7)

Finished run 1.
Run 0 finished with accuracy 0.0642
Finished run 2.
Run 1 finished with accuracy 0.0691
Finished run 3.
Run 2 finished with accuracy 0.0673
Mean accuracy across 3 runs: 0.06686666666666666
Standard deviation of accuracy across 3 runs: 0.002478574859336175

Mean total confusion across 3 runs: 0.9847037777777777
Standard deviation of total confusion across 3 runs: 0.0003946498213371984

Mean intra-phase confusion across 3 runs: 0.7739423333333333
Standard deviation of intra-phase confusion across 3 runs: 0.0011745263347882576

Mean per-task confusion across 3 runs: 0.9565204444444445
Standard deviation of per-task confusion across 3 runs: 0.0006138018440047692

Mean output bias across 3 runs: 0.04523333333333334
Standard deviation of output bias across 3 runs: 0.0065248243909957655

Mean embedding drift across 3 runs: 166.10789998372397
Standard deviation of embedding drift across 3 runs: 13.98068507086361

Mean attention drift across 3 runs: 2.3109940638018264e-08
Standa

In [194]:
# Tiny ImageNet: all improvements combined -- dropout, KD loss, special CE
# (full_CE=False) and OOD (3 runs).
in_ex_8 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    with_OOD=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [225]:
# Report in_ex_8 (Tiny ImageNet, all improvements): per-run accuracy, then
# the across-run mean and standard deviation of each metric.
report_performance(in_ex_8)

Finished run 1.
Run 0 finished with accuracy 0.1852
Finished run 2.
Run 1 finished with accuracy 0.1793
Finished run 3.
Run 2 finished with accuracy 0.1664
Mean accuracy across 3 runs: 0.17696666666666666
Standard deviation of accuracy across 3 runs: 0.009614745619793245

Mean total confusion across 3 runs: 0.9836961111111111
Standard deviation of total confusion across 3 runs: 0.00013817232933040937

Mean intra-phase confusion across 3 runs: 0.7741814444444445
Standard deviation of intra-phase confusion across 3 runs: 0.0006395016578324263

Mean per-task confusion across 3 runs: 0.9559936666666667
Standard deviation of per-task confusion across 3 runs: 0.0002571076298622276

Mean output bias across 3 runs: 0.07853333333333333
Standard deviation of output bias across 3 runs: 0.006536308846232205

Mean embedding drift across 3 runs: 7.189832051595052
Standard deviation of embedding drift across 3 runs: 0.27527592087940633

Mean attention drift across 3 runs: 1.545933216156037e-08
Standa

In [195]:
# Tiny ImageNet ablation: all improvements except the special CE
# (full_CE=True) (3 runs).
in_ex_9 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=True,
    with_OOD=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [227]:
# Report in_ex_9 (Tiny ImageNet, all improvements without special CE): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(in_ex_9)

Finished run 1.
Run 0 finished with accuracy 0.0886
Finished run 2.
Run 1 finished with accuracy 0.0905
Finished run 3.
Run 2 finished with accuracy 0.0815
Mean accuracy across 3 runs: 0.08686666666666666
Standard deviation of accuracy across 3 runs: 0.004743767841424505

Mean total confusion across 3 runs: 0.9821005555555555
Standard deviation of total confusion across 3 runs: 0.0007761556497202294

Mean intra-phase confusion across 3 runs: 0.7707377777777779
Standard deviation of intra-phase confusion across 3 runs: 0.001749275289094585

Mean per-task confusion across 3 runs: 0.9528783333333333
Standard deviation of per-task confusion across 3 runs: 0.0010631142930090274

Mean output bias across 3 runs: 0.15813333333333332
Standard deviation of output bias across 3 runs: 0.00829236596716121

Mean embedding drift across 3 runs: 9.898040135701498
Standard deviation of embedding drift across 3 runs: 0.2241771029562697

Mean attention drift across 3 runs: 1.848128587738825e-08
Standard d

In [59]:
# Tiny ImageNet ablation: all improvements except OOD (with_OOD=False) (3 runs).
in_ex_10 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    with_OOD=False,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [118]:
# Report in_ex_10 (Tiny ImageNet, all improvements without OOD): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(in_ex_10)

Finished run 1.
Run 0 finished with accuracy 0.1719
Finished run 2.
Run 1 finished with accuracy 0.1806
Finished run 3.
Run 2 finished with accuracy 0.1757
Mean accuracy across 3 runs: 0.17606666666666668
Standard deviation of accuracy across 3 runs: 0.004361574639202383

Mean total confusion across 3 runs: 0.9841695555555555
Standard deviation of total confusion across 3 runs: 0.00030854215151133487

Mean intra-phase confusion across 3 runs: 0.7761391111111111
Standard deviation of intra-phase confusion across 3 runs: 0.0008763111023763833

Mean per-task confusion across 3 runs: 0.9566862222222222
Standard deviation of per-task confusion across 3 runs: 0.0006991716792298495

Mean output bias across 3 runs: 0.07166666666666667
Standard deviation of output bias across 3 runs: 0.01032537327815966

Mean embedding drift across 3 runs: 6.84524933497111
Standard deviation of embedding drift across 3 runs: 0.09448083905415895

Mean attention drift across 3 runs: 1.3168940280094965e-08
Standar

In [196]:
# Tiny ImageNet ablation: all improvements except dropout
# (with_dropout=False) (3 runs).
in_ex_11 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=False,
    kd_loss=1,
    full_CE=False,
    with_OOD=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [234]:
# Manual resubmission of the third run: replaces in_ex_11[2] with a freshly
# submitted job that uses the same settings as the in_ex_11 launch cell.
# NOTE(review): presumably the original third job failed or was cancelled -- confirm.
# NOTE(review): one-off patch cell that mutates in_ex_11 in place; it is likely
# unnecessary on a clean top-to-bottom re-run.
in_ex_11[2] = ex_parallel.submit(run_experiment,dataset='tiny_imagenet', with_dropout=False, kd_loss=1, full_CE=False, with_OOD=True)

In [245]:
# Display the job handles held in in_ex_11 to check their states before reporting.
in_ex_11

[SlurmJob<job_id=12304548, task_id=0, state="COMPLETED">,
 SlurmJob<job_id=12304549, task_id=0, state="COMPLETED">,
 SlurmJob<job_id=12304614, task_id=0, state="COMPLETED">]

In [246]:
# Report in_ex_11 (Tiny ImageNet, all improvements without dropout): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
# The third entry is the manually resubmitted job assigned to in_ex_11[2].
report_performance(in_ex_11)

Finished run 1.
Run 0 finished with accuracy 0.1409
Finished run 2.
Run 1 finished with accuracy 0.1459
Finished run 3.
Run 2 finished with accuracy 0.1458
Mean accuracy across 3 runs: 0.1442
Standard deviation of accuracy across 3 runs: 0.002858321185591296

Mean total confusion across 3 runs: 0.9885244444444444
Standard deviation of total confusion across 3 runs: 0.00030243021845881946

Mean intra-phase confusion across 3 runs: 0.7852413333333333
Standard deviation of intra-phase confusion across 3 runs: 0.000600662967062231

Mean per-task confusion across 3 runs: 0.9634887777777779
Standard deviation of per-task confusion across 3 runs: 0.0005432581065850459

Mean output bias across 3 runs: 0.022866666666666664
Standard deviation of output bias across 3 runs: 0.004488132499529549

Mean embedding drift across 3 runs: 10.45455805460612
Standard deviation of embedding drift across 3 runs: 0.210280833388122

Mean attention drift across 3 runs: 9.314878123511949e-09
Standard deviation of

In [197]:
# Tiny ImageNet ablation: all improvements except the KD loss (kd_loss=0) (3 runs).
in_ex_12 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=True,
    kd_loss=0,
    full_CE=False,
    with_OOD=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [241]:
# Report in_ex_12 (Tiny ImageNet, all improvements without KD loss): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(in_ex_12)

Finished run 1.
Run 0 finished with accuracy 0.1431
Finished run 2.
Run 1 finished with accuracy 0.1426
Finished run 3.
Run 2 finished with accuracy 0.1513
Mean accuracy across 3 runs: 0.14566666666666667
Standard deviation of accuracy across 3 runs: 0.004885011088353153

Mean total confusion across 3 runs: 0.983294
Standard deviation of total confusion across 3 runs: 0.0006924769550918922

Mean intra-phase confusion across 3 runs: 0.7717032222222222
Standard deviation of intra-phase confusion across 3 runs: 0.0013501539144085153

Mean per-task confusion across 3 runs: 0.9558331111111111
Standard deviation of per-task confusion across 3 runs: 0.0010580228065875225

Mean output bias across 3 runs: 0.10443333333333334
Standard deviation of output bias across 3 runs: 0.011602729563914404

Mean embedding drift across 3 runs: 8.168802579243978
Standard deviation of embedding drift across 3 runs: 0.25334120069607824

Mean attention drift across 3 runs: 2.0936146238910897e-08
Standard deviati

In [198]:
# Tiny ImageNet: all improvements, trained with Adam at lr=1e-5 (3 runs).
in_ex_13 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    with_OOD=True,
    optimiser_type='adam',
    lr=1e-5,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [243]:
# Report in_ex_13 (Tiny ImageNet, all improvements with Adam): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(in_ex_13)

Finished run 1.
Run 0 finished with accuracy 0.1281
Finished run 2.
Run 1 finished with accuracy 0.1291
Finished run 3.
Run 2 finished with accuracy 0.12
Mean accuracy across 3 runs: 0.12573333333333334
Standard deviation of accuracy across 3 runs: 0.004990323970779183

Mean total confusion across 3 runs: 0.9822547777777778
Standard deviation of total confusion across 3 runs: 0.0008190082304113691

Mean intra-phase confusion across 3 runs: 0.7696024444444445
Standard deviation of intra-phase confusion across 3 runs: 0.0031863844075366706

Mean per-task confusion across 3 runs: 0.9528646666666667
Standard deviation of per-task confusion across 3 runs: 0.0014558943795635107

Mean output bias across 3 runs: 0.07866666666666668
Standard deviation of output bias across 3 runs: 0.004168133075290815

Mean embedding drift across 3 runs: 19.478856404622395
Standard deviation of embedding drift across 3 runs: 1.8317047127459407

Mean attention drift across 3 runs: 9.389580646579355e-09
Standard 

In [63]:
# Tiny ImageNet: joint training with dropout (3 runs).
in_ex_14 = run_experiments(
    dataset='tiny_imagenet',
    n_runs=3,
    with_dropout=True,
    joint_training=True,
)

Starting run 1.
Starting run 2.
Starting run 3.


In [125]:
# Report in_ex_14 (Tiny ImageNet joint training, with dropout): per-run
# accuracy, then the across-run mean and standard deviation of each metric.
report_performance(in_ex_14)

Finished run 1.
Run 0 finished with accuracy 0.3287
Finished run 2.
Run 1 finished with accuracy 0.3324
Finished run 3.
Run 2 finished with accuracy 0.3397
Mean accuracy across 3 runs: 0.3336
Standard deviation of accuracy across 3 runs: 0.0055973207876626185

Mean total confusion across 3 runs: 0.9785554444444444
Standard deviation of total confusion across 3 runs: 0.0005464743903050235

Mean intra-phase confusion across 3 runs: 0.0
Standard deviation of intra-phase confusion across 3 runs: 0.0

Mean per-task confusion across 3 runs: 0.9785554444444444
Standard deviation of per-task confusion across 3 runs: 0.0005464743903050235

Mean output bias across 3 runs: 0
Standard deviation of output bias across 3 runs: 0.0

Mean embedding drift across 3 runs: 0.00012565715587697923
Standard deviation of embedding drift across 3 runs: 2.0208399336555662e-05

Mean attention drift across 3 runs: 0.0
Standard deviation of attention drift across 3 runs: 0.0

Mean attention spread across 3 runs: 52