In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

import sys
sys.path.append('./../src')
import globals
from model import Net
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [2]:
# Module-level aliases for values held by the project `globals` module
# (note: that module name shadows Python's built-in `globals()` function).
# NOTE(review): these are snapshots taken when this cell runs, i.e. before
# initialize_data() is ever called. initialize_data() later *rebinds*
# globals.full_trainset / trainset / testset / trainloaders / valloaders /
# testloaders to new objects, so the dataset and loader aliases below do not
# follow those updates — read them as `globals.<name>` where freshness matters.
SEED, DEVICE = globals.SEED, globals.DEVICE
full_trainset, trainset, testset = globals.full_trainset, globals.trainset, globals.testset
trainloaders, valloaders, testloaders = globals.trainloaders, globals.valloaders, globals.testloaders

In [3]:
# Data preparation for the class-incremental MNIST experiments.
#
# Images are converted from PIL format to float tensors in [0, 1] with
# ToTensor(). Normalization is currently disabled: the commented-out
# Normalize((0.5,), (0.5,)) would map MNIST's single channel from [0, 1] to
# [-1, 1], which can help optimization converge faster and reduce vanishing
# gradients with some activation functions.
def indices_of_classes(labels, classes):
    """Return, in ascending order, the positions ``p`` with ``labels[p]`` in ``classes``.

    Args:
        labels: 1-D numpy array of integer labels, aligned with a dataset's
            positions.
        classes: iterable of class ids to keep.

    Returns:
        List of Python ints, usable as ``Subset`` indices.
    """
    return np.flatnonzero(np.isin(labels, list(classes))).tolist()


def initialize_data(seed=None):
    """Build the MNIST splits and one set of data loaders per task.

    Side effects (all written to the project ``globals`` module):
      * ``full_trainset`` / ``testset``: MNIST train / test datasets,
        downloaded into ``./../data/`` if missing.
      * ``trainset``: the 99% part of a stratified split of ``full_trainset``;
        the remaining 1% becomes the validation subset.
      * ``trainloaders`` / ``valloaders`` / ``testloaders``: one DataLoader per
        task ``k in range(ITERATIONS)``, each restricted to classes
        ``k*CLASSES_PER_ITER .. (k+1)*CLASSES_PER_ITER - 1``.

    Reads ``globals.CLASSES_PER_ITER``, ``globals.ITERATIONS`` and
    ``globals.BATCH_SIZE`` (training batch size), so those must be set before
    this is called.

    Args:
        seed: ``random_state`` for the stratified train/validation split.
            ``None`` (default) draws a fresh random split on every call, as
            before; pass e.g. ``globals.SEED`` for a reproducible split.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5,), (0.5,)),  # would map [0, 1] -> [-1, 1]
    ])

    globals.full_trainset = torchvision.datasets.MNIST('./../data/', train=True, download=True,
                                transform=transform)
    targets = np.array(globals.full_trainset.targets)

    # Stratified 99% / 1% train/validation split, preserving class balance.
    train_indices, val_indices = train_test_split(
        np.arange(len(targets)),
        test_size=0.01,
        stratify=targets,
        random_state=seed,
    )

    valset = Subset(globals.full_trainset, val_indices)
    globals.trainset = Subset(globals.full_trainset, train_indices)

    globals.testset = torchvision.datasets.MNIST('./../data/', train=False, download=True,
                                transform=transform)

    # Labels aligned with positions in each (sub)set. Reading them from the
    # `targets` arrays avoids loading and transforming every image just to
    # learn its label, which would otherwise happen once per task.
    train_labels = targets[train_indices]
    val_labels = targets[val_indices]
    test_labels = np.array(globals.testset.targets)

    globals.trainloaders = []
    globals.testloaders = []
    globals.valloaders = []
    for k in range(globals.ITERATIONS):
        task_classes = range(k * globals.CLASSES_PER_ITER, (k + 1) * globals.CLASSES_PER_ITER)

        train_subset = Subset(globals.trainset, indices_of_classes(train_labels, task_classes))
        globals.trainloaders.append(DataLoader(train_subset, batch_size=globals.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers = 0))

        val_subset = Subset(valset, indices_of_classes(val_labels, task_classes))
        globals.valloaders.append(DataLoader(val_subset, batch_size=500, shuffle=False))

        test_subset = Subset(globals.testset, indices_of_classes(test_labels, task_classes))
        globals.testloaders.append(DataLoader(test_subset, batch_size=500, shuffle=False))


In [4]:
def _global_predictions(logits):
    """Task-agnostic class predictions over all tasks seen so far.

    When ``globals.OOD_CLASS == 1`` each task's block of outputs ends with an
    extra OOD column. Every ``(CLASSES_PER_ITER + 1)``-th column is dropped
    first, so the remaining column index equals the global class id.

    Args:
        logits: 2-D tensor, one row per sample, one column per model output.

    Returns:
        Tensor of predicted class ids, moved to ``DEVICE``.
    """
    if globals.OOD_CLASS == 1:
        stride = globals.CLASSES_PER_ITER + 1
        keep = [c for c in range(logits.size(1)) if (c + 1) % stride != 0]
        logits = logits[:, keep]
    # softmax does not change the argmax; it is kept so predictions match the
    # earlier per-row implementation exactly. argmax returns the first index
    # on ties, like the earlier strict-`>` scan.
    return torch.softmax(logits, dim=-1).argmax(dim=-1).to(DEVICE)


def _evaluate_seen_tasks(model, task_idx, _print):
    """Evaluate ``model`` on the test loaders of tasks ``0..task_idx``.

    Prints (through ``_print``) the task-aware accuracy of each task, i.e.
    argmax restricted to that task's output columns. It then prints the
    task-agnostic accuracy over all seen classes and the embedding
    drift/confusion metrics. The model is switched to eval mode for inference
    and back to train mode afterwards.

    Returns:
        ``(accuracy, total_confusion, intra_phase_confusion,
        per_task_confusion, embedding_drift, predicted, all_labels)``, where
        ``accuracy`` is the task-agnostic fraction of correct predictions.
    """
    classes_per_task = globals.CLASSES_PER_ITER
    outputs_per_task = classes_per_task + globals.OOD_CLASS
    _print("Starting evaluation")
    _print("ITERATION", task_idx + 1)
    _print("ACCURACIES PER TASK:")
    logits_per_task, labels_per_task, embeddings_per_task = [], [], []
    with torch.no_grad():
        for j in range(task_idx + 1):
            test_loader = globals.testloaders[j]
            labels = get_labels(test_loader).to(DEVICE)
            model.eval()
            logits, embeddings = model.get_pred_and_embeddings(get_features(test_loader).to(DEVICE))
            model.train()
            labels_per_task.append(labels)
            logits_per_task.append(logits)
            embeddings_per_task.append(embeddings)
            # Task-aware prediction: argmax over task j's columns only, then
            # shift the local index back to a global class id.
            task_logits = logits[:, j * outputs_per_task:(j + 1) * outputs_per_task]
            _, task_pred = torch.max(task_logits, 1)
            task_pred = task_pred + j * classes_per_task
            task_accuracy = (task_pred == labels).sum().item() / labels.size(0)  # fraction in [0, 1]
            _print(str(task_accuracy), end=' ')
    all_logits = torch.cat(logits_per_task)
    all_labels = torch.cat(labels_per_task)
    all_embeddings = torch.cat(embeddings_per_task)

    predicted = _global_predictions(all_logits)
    accuracy = (predicted == all_labels).sum().item() / all_labels.size(0)  # fraction in [0, 1]
    _print("Accuracy on tasks so far:", accuracy)

    embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
    _print("Average embedding drift based on centroids:", embedding_drift)
    total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 1000, task=task_idx+1)
    _print("Total confusion", total_confusion)
    _print("Intra-phase confusion", intra_phase_confusion)
    _print("Per task confusions", per_task_confusion)
    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, predicted, all_labels


def run_experiment(
        verbose = False,
        stopOnLoss = 0.03,
        full_CE = True,
        with_OOD = False,
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False,
        batch_size = 4,
        ):
    """Train one class-incremental sequence of models and evaluate it.

    A new ``Net`` is built for each of the ``globals.ITERATIONS`` tasks. The
    first task is trained with ``train_model``. Every later task first copies
    the previous model's weights (``copyPrev``) and then trains with
    ``train_model_CL``. Evaluation on all tasks seen so far runs after the
    last task, and after every task when ``verbose`` is set.

    Args:
        verbose: print per-iteration diagnostics and plot confusion matrices.
        stopOnLoss: forwarded to both training helpers.
        full_CE: forwarded to ``train_model_CL``.
        with_OOD: if True, each task gets one extra output column
            (``globals.OOD_CLASS = 1``), which is ignored at prediction time.
        kd_loss: forwarded to ``train_model_CL``.
        stopOnValAcc: forwarded to ``train_model_CL``.
        epochs: forwarded to both training helpers.
        with_dropout: stored in ``globals.WITH_DROPOUT`` before any model is
            built (presumably read by ``Net`` — confirm in model.py).
        batch_size: training batch size; stored in ``globals.BATCH_SIZE``
            before the data loaders are created.

    Returns:
        ``(accuracy, total_confusion, intra_phase_confusion,
        per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread)``
        measured after the final task. ``accuracy`` is a fraction in [0, 1].
    """
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    # Configure every global before initialize_data() runs: it reads
    # BATCH_SIZE when it builds the train loaders. Previously BATCH_SIZE was
    # assigned afterwards, so the first run used a stale batch size.
    globals.OOD_CLASS = 1 if with_OOD else 0
    globals.BATCH_SIZE = batch_size
    globals.WITH_DROPOUT = with_dropout
    initialize_data()

    outputs_per_task = globals.CLASSES_PER_ITER + globals.OOD_CLASS
    prevModel = None
    # Feature-attribution bookkeeping across tasks (Denis's addition).
    Feature_Importance_Eval = Feature_Importance_Evaluations(globals.testloaders, DEVICE)

    for i in tqdm(range(globals.ITERATIONS), desc="Experiment Progress"):
        model = Net((i + 1) * outputs_per_task)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        if prevModel is None:
            train_model(
                model,
                train_loader,
                val_loader,
                verbose,
                epochs=epochs,
                l1_loss=0,
                stopOnLoss=stopOnLoss,
                center_loss=0,
                )
        else:
            # Warm-start from the previous task's model, then train with the
            # continual-learning objective.
            with torch.no_grad():
                model.copyPrev(prevModel)
            _print("CL TRAIN!!")
            train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                freeze_nonzero_params=False,
                l1_loss=0,
                ewc_loss=0,
                kd_loss=kd_loss,
                distance_loss=0,
                center_loss=0,
                param_reuse_loss=0,
                stopOnLoss=stopOnLoss,
                stopOnValAcc=stopOnValAcc,
                full_CE=full_CE,
                )

        Feature_Importance_Eval.Task_Feature_Attribution(model, i)

        if verbose or i == globals.ITERATIONS - 1:
            (accuracy, total_confusion, intra_phase_confusion, per_task_confusion,
             embedding_drift, predicted, all_labels) = _evaluate_seen_tasks(model, i, _print)
            if verbose:
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(globals.CLASSES_PER_ITER*(i+1))))
        prevModel = copy.deepcopy(model)

    # After the loop, prevModel holds a copy of the final model.
    [avg_att_diff, att_diffs, _, _, avg_att_spread, att_spreads] = Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", att_diffs)
    _print("Averaged SHAPC value (the smaller the better):", avg_att_diff)
    _print("Average attention spread values (ordered as tasks):", att_spreads)
    _print("Averaged attention spread value (the bigger the better):", avg_att_spread)
    # Per the original author's note: shows saliency maps for one example per
    # class (row 1: image, row 2: saliency after all training, row 3: saliency
    # after the task that introduced the class).
    Feature_Importance_Eval.Save_Random_Picture_Salency()

    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread

In [5]:
def report_stats(data, name, n_runs=None):
    """Print the raw per-run values, their mean and sample standard deviation.

    ``statistics.stdev`` needs at least two values, so with fewer runs the
    standard deviation is reported as ``nan`` instead of raising. An empty
    ``data`` likewise reports a ``nan`` mean.

    Args:
        data: per-run values of one metric.
        name: metric name used in the printed lines.
        n_runs: run count shown in the text; defaults to ``len(data)``.
    """
    if n_runs is None:
        n_runs = len(data)
    print(name, data)
    mean = statistics.mean(data) if data else float("nan")
    std = statistics.stdev(data) if len(data) > 1 else float("nan")
    print(f"Mean {name} across {n_runs} runs: {mean}")
    print(f"Standard deviation of {name} across {n_runs} runs: {std}\n")


def run_experiments(n_runs=1, *args, **kwargs):
    """Repeat ``run_experiment`` ``n_runs`` times and print summary statistics.

    Extra positional and keyword arguments are forwarded unchanged to every
    ``run_experiment`` call. Works for any ``n_runs``; the standard deviation
    is only defined from two runs on (``nan`` otherwise).
    """
    # `verbose` is run_experiment's first positional parameter, so honour it
    # whether it was given by keyword or positionally.
    verbose = kwargs.get('verbose', args[0] if args else None)

    # Metric name -> per-run values. The order matches both the tuple returned
    # by run_experiment and the report order below.
    results = {
        "accuracy": [],
        "total confusion": [],
        "intra-phase confusion": [],
        "per-task confusion": [],
        "embedding drift": [],
        "attention drift": [],
        "attention spread": [],
    }
    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        run_metrics = run_experiment(*args, **kwargs)
        for values, metric in zip(results.values(), run_metrics):
            values.append(metric)
        if verbose:
            # 1-based, consistent with "Starting run ..." above.
            print(f"Run {r+1} finished with accuracy {run_metrics[0]}")

    for name, values in results.items():
        report_stats(values, name, n_runs)

In [6]:
# Baseline: run_experiment defaults (full_CE=True, kd_loss=0, no OOD, no dropout).
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:48<00:00, 33.69s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:27<00:00, 41.48s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:15<00:00, 27.04s/it]


accuracy [0.1964, 0.1975, 0.1977]
Mean accuracy across 3 runs: 0.1972
Standard deviation of accuracy across 3 runs: 0.0007000000000000023

total confusion [0.415069, 0.3985588, 0.40981270000000003]
Mean total confusion across 3 runs: 0.4078135
Standard deviation of total confusion across 3 runs: 0.008434706662949239

intra-phase confusion [0.409624, 0.3929266, 0.40238399999999996]
Mean intra-phase confusion across 3 runs: 0.40164486666666666
Standard deviation of intra-phase confusion across 3 runs: 0.008373203105940587

per-task confusion [0.07924270031978235, 0.07444721949064895, 0.08855066330220647]
Mean per-task confusion across 3 runs: 0.08074686103754593
Standard deviation of per-task confusion across 3 runs: 0.007171028966270273

embedding drift [9.406989097595215, 8.95884895324707, 9.783537864685059]
Mean embedding drift across 3 runs: 9.383125305175781
Standard deviation of embedding drift across 3 runs: 0.41286203579848785

attention drift [6.822499271845407e-06, 5.3854567178

In [7]:
# Baseline + dropout.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    with_dropout=True,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:11<00:00, 38.39s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:19<00:00, 39.99s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:30<00:00, 42.05s/it]


accuracy [0.1976, 0.198, 0.1977]
Mean accuracy across 3 runs: 0.19776666666666667
Standard deviation of accuracy across 3 runs: 0.00020816659994662146

total confusion [0.43540540000000005, 0.4399872, 0.4529856]
Mean total confusion across 3 runs: 0.44279273333333335
Standard deviation of total confusion across 3 runs: 0.009119710575085854

intra-phase confusion [0.42733049999999995, 0.42952840000000003, 0.44323429999999997]
Mean intra-phase confusion across 3 runs: 0.4333644
Standard deviation of intra-phase confusion across 3 runs: 0.008617939754372845

per-task confusion [0.10210612156045648, 0.10872054799485395, 0.11486059058351616]
Mean per-task confusion across 3 runs: 0.10856242004627553
Standard deviation of per-task confusion across 3 runs: 0.006378704676589007

embedding drift [7.711767673492432, 8.657382011413574, 8.532083511352539]
Mean embedding drift across 3 runs: 8.300411065419516
Standard deviation of embedding drift across 3 runs: 0.5136153333711453

attention drift [

In [8]:
# Baseline + knowledge-distillation term (kd_loss=1).
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    kd_loss=1,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:42<00:00, 32.43s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:43<00:00, 32.67s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:51<00:00, 34.39s/it]


accuracy [0.2011, 0.196, 0.2104]
Mean accuracy across 3 runs: 0.2025
Standard deviation of accuracy across 3 runs: 0.007301369734508722

total confusion [0.3813442, 0.38061730000000005, 0.37627330000000003]
Mean total confusion across 3 runs: 0.3794116
Standard deviation of total confusion across 3 runs: 0.0027420414055954728

intra-phase confusion [0.37786030000000004, 0.37774339999999995, 0.3734803]
Mean intra-phase confusion across 3 runs: 0.3763613333333333
Standard deviation of intra-phase confusion across 3 runs: 0.0024957325985235997

per-task confusion [0.06769524134907531, 0.060844546645660566, 0.06240981341149352]
Mean per-task confusion across 3 runs: 0.0636498671354098
Standard deviation of per-task confusion across 3 runs: 0.0035897499089069832

embedding drift [10.058082580566406, 10.744964599609375, 11.043731689453125]
Mean embedding drift across 3 runs: 10.615592956542969
Standard deviation of embedding drift across 3 runs: 0.5053996517455982

attention drift [6.3650068

In [9]:
# Baseline with full_CE disabled.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    full_CE=False,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [01:53<00:00, 22.63s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [01:59<00:00, 23.81s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:01<00:00, 24.32s/it]


accuracy [0.5644, 0.6899, 0.6116]
Mean accuracy across 3 runs: 0.6219666666666667
Standard deviation of accuracy across 3 runs: 0.0633889843216732

total confusion [0.44386689999999995, 0.4237191, 0.4426772]
Mean total confusion across 3 runs: 0.4367544
Standard deviation of total confusion across 3 runs: 0.011304562401526183

intra-phase confusion [0.44077370000000005, 0.42138790000000004, 0.43978740000000005]
Mean intra-phase confusion across 3 runs: 0.43398300000000006
Standard deviation of intra-phase confusion across 3 runs: 0.010918818843171642

per-task confusion [0.0753906675313886, 0.06373890813393235, 0.07395251927662698]
Mean per-task confusion across 3 runs: 0.07102736498064931
Standard deviation of per-task confusion across 3 runs: 0.006352815911508058

embedding drift [5.349673271179199, 5.414111137390137, 5.4888596534729]
Mean embedding drift across 3 runs: 5.417548020680745
Standard deviation of embedding drift across 3 runs: 0.0696568114333913

attention drift [3.43882

In [10]:
# Baseline + one extra OOD output column per task.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    with_OOD=True,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:03<00:00, 36.64s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:12<00:00, 38.51s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:01<00:00, 36.28s/it]


accuracy [0.1956, 0.1978, 0.1963]
Mean accuracy across 3 runs: 0.19656666666666667
Standard deviation of accuracy across 3 runs: 0.0011239810200058277

total confusion [0.31434470000000003, 0.3418599, 0.33706820000000004]
Mean total confusion across 3 runs: 0.33109093333333334
Standard deviation of total confusion across 3 runs: 0.014699229413929593

intra-phase confusion [0.31037950000000003, 0.3371006, 0.3304551]
Mean intra-phase confusion across 3 runs: 0.3259784
Standard deviation of intra-phase confusion across 3 runs: 0.01391168316452036

per-task confusion [0.05960315311095501, 0.0664080323482237, 0.07616511991531197]
Mean per-task confusion across 3 runs: 0.06739210179149689
Standard deviation of per-task confusion across 3 runs: 0.008324721052926847

embedding drift [9.44377326965332, 10.233915328979492, 10.4202299118042]
Mean embedding drift across 3 runs: 10.032639503479004
Standard deviation of embedding drift across 3 runs: 0.5184118651693069

attention drift [3.0905997055

In [11]:
# full_CE disabled + OOD column.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    full_CE=False,
    with_OOD=True,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:09<00:00, 37.97s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:03<00:00, 36.79s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:11<00:00, 38.21s/it]


accuracy [0.7654, 0.7314, 0.7328]
Mean accuracy across 3 runs: 0.7432
Standard deviation of accuracy across 3 runs: 0.019238503060269495

total confusion [0.35066620000000004, 0.3228105, 0.3514534]
Mean total confusion across 3 runs: 0.34164336666666667
Standard deviation of total confusion across 3 runs: 0.016314489606583896

intra-phase confusion [0.347985, 0.3202524, 0.3489135]
Mean intra-phase confusion across 3 runs: 0.33905029999999997
Standard deviation of intra-phase confusion across 3 runs: 0.016286077224734012

per-task confusion [0.059980212444068835, 0.05483852550189692, 0.059442913083810736]
Mean per-task confusion across 3 runs: 0.058087217009925494
Standard deviation of per-task confusion across 3 runs: 0.0028262466340335884

embedding drift [6.5190749168396, 5.821240425109863, 6.616893768310547]
Mean embedding drift across 3 runs: 6.319069703420003
Standard deviation of embedding drift across 3 runs: 0.4338981731730037

attention drift [1.9613252560486547e-06, 1.3824820

In [12]:
# full_CE disabled + KD (kd_loss=1) + OOD column.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_OOD=True,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:44<00:00, 44.95s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [04:01<00:00, 48.29s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:43<00:00, 44.75s/it]


accuracy [0.7769, 0.8705, 0.8392]
Mean accuracy across 3 runs: 0.8288666666666666
Standard deviation of accuracy across 3 runs: 0.047647910062597006

total confusion [0.35181189999999996, 0.33913230000000005, 0.35328760000000003]
Mean total confusion across 3 runs: 0.34807726666666666
Standard deviation of total confusion across 3 runs: 0.007781628629235198

intra-phase confusion [0.3490227, 0.3359761, 0.3500449]
Mean intra-phase confusion across 3 runs: 0.3450145666666667
Standard deviation of intra-phase confusion across 3 runs: 0.007844210156117273

per-task confusion [0.05687022024336545, 0.05976064777910328, 0.06189915734324638]
Mean per-task confusion across 3 runs: 0.05951000845523837
Standard deviation of per-task confusion across 3 runs: 0.002523819950321396

embedding drift [5.19499397277832, 5.30782413482666, 5.013857841491699]
Mean embedding drift across 3 runs: 5.17222531636556
Standard deviation of embedding drift across 3 runs: 0.1482998792663225

attention drift [9.8427

In [13]:
# full_CE disabled + KD (kd_loss=1), no OOD column.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:31<00:00, 30.25s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:25<00:00, 29.20s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:23<00:00, 28.79s/it]


accuracy [0.6707, 0.7146, 0.708]
Mean accuracy across 3 runs: 0.6977666666666666
Standard deviation of accuracy across 3 runs: 0.023671572261540513

total confusion [0.39216870000000004, 0.3781702, 0.3921686]
Mean total confusion across 3 runs: 0.3875025
Standard deviation of total confusion across 3 runs: 0.008082008875892178

intra-phase confusion [0.3900099, 0.3760787, 0.3900641]
Mean intra-phase confusion across 3 runs: 0.3853842333333333
Standard deviation of intra-phase confusion across 3 runs: 0.008058873827857938

per-task confusion [0.05703432950829794, 0.05830395668970132, 0.057738022293946825]
Mean per-task confusion across 3 runs: 0.0576921028306487
Standard deviation of per-task confusion across 3 runs: 0.0006360579712351418

embedding drift [5.050500869750977, 4.625858306884766, 4.97245454788208]
Mean embedding drift across 3 runs: 4.882937908172607
Standard deviation of embedding drift across 3 runs: 0.22603152023125586

attention drift [8.581989060222118e-07, 9.16373032

In [14]:
# full_CE disabled + KD (kd_loss=1) + dropout.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_dropout=True,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [04:06<00:00, 49.22s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:54<00:00, 46.80s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [04:16<00:00, 51.29s/it]


accuracy [0.7941, 0.7589, 0.7547]
Mean accuracy across 3 runs: 0.7692333333333333
Standard deviation of accuracy across 3 runs: 0.021637313449994974

total confusion [0.3988575, 0.3862567, 0.38862149999999995]
Mean total confusion across 3 runs: 0.39124523333333333
Standard deviation of total confusion across 3 runs: 0.00669761320272627

intra-phase confusion [0.39698809999999995, 0.38445490000000004, 0.38682479999999997]
Mean intra-phase confusion across 3 runs: 0.3894226
Standard deviation of intra-phase confusion across 3 runs: 0.006658205403109724

per-task confusion [0.06140189885493008, 0.056351439357886356, 0.06230954945676648]
Mean per-task confusion across 3 runs: 0.06002096255652764
Standard deviation of per-task confusion across 3 runs: 0.0032101413960725165

embedding drift [4.694825649261475, 4.6878814697265625, 4.639401435852051]
Mean embedding drift across 3 runs: 4.674036184946696
Standard deviation of embedding drift across 3 runs: 0.030194863652449933

attention drift

In [15]:
# full_CE disabled + KD (kd_loss=1) + OOD column + dropout.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_OOD=True,
    with_dropout=True,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [06:52<00:00, 82.60s/it] 


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [07:25<00:00, 89.03s/it] 


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [08:30<00:00, 102.05s/it]


accuracy [0.8426, 0.8516, 0.8419]
Mean accuracy across 3 runs: 0.8453666666666667
Standard deviation of accuracy across 3 runs: 0.005409559439855847

total confusion [0.3361244, 0.3372204, 0.32370750000000004]
Mean total confusion across 3 runs: 0.3323507666666667
Standard deviation of total confusion across 3 runs: 0.007505321312464434

intra-phase confusion [0.3335331, 0.3346409, 0.32130820000000004]
Mean intra-phase confusion across 3 runs: 0.32982740000000005
Standard deviation of intra-phase confusion across 3 runs: 0.0073986067397855265

per-task confusion [0.0553429920459035, 0.055829422516431484, 0.05180261498255838]
Mean per-task confusion across 3 runs: 0.05432500984829779
Standard deviation of per-task confusion across 3 runs: 0.0021979559924384094

embedding drift [5.326330184936523, 4.6283979415893555, 5.182539463043213]
Mean embedding drift across 3 runs: 5.045755863189697
Standard deviation of embedding drift across 3 runs: 0.36852363420487383

attention drift [1.6006146

In [16]:
# full_CE disabled + KD (kd_loss=1) + OOD column, without dropout.
# NOTE(review): same arguments as the earlier full_CE=False/kd_loss=1/with_OOD
# run; presumably a repeat to gauge run-to-run variance — confirm intent.
run_experiments(
    n_runs=3,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_OOD=True,
    verbose=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:40<00:00, 44.04s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:41<00:00, 44.21s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:31<00:00, 42.22s/it]


accuracy [0.8188, 0.857, 0.8669]
Mean accuracy across 3 runs: 0.8475666666666667
Standard deviation of accuracy across 3 runs: 0.025399671913891607

total confusion [0.3407717, 0.34503989999999995, 0.3564543]
Mean total confusion across 3 runs: 0.34742196666666664
Standard deviation of total confusion across 3 runs: 0.008108123512461646

intra-phase confusion [0.33786700000000003, 0.3416525, 0.35379879999999997]
Mean intra-phase confusion across 3 runs: 0.3444394333333333
Standard deviation of intra-phase confusion across 3 runs: 0.008323509528037605

per-task confusion [0.05994734114534672, 0.06377310507573661, 0.06036305765982066]
Mean per-task confusion across 3 runs: 0.06136116796030133
Standard deviation of per-task confusion across 3 runs: 0.002099115418863926

embedding drift [4.810020923614502, 5.231167316436768, 5.27020263671875]
Mean embedding drift across 3 runs: 5.10379695892334
Standard deviation of embedding drift across 3 runs: 0.25516506079146867

attention drift [7.727

In [17]:
# Upper-bound baseline: one iteration that trains on all 10 MNIST classes at
# once (no class-incremental split). The task-split globals are overridden only
# for this call and restored afterwards. Otherwise re-running any earlier
# experiment cell would silently run this single-task setup instead.
_saved_task_split = (globals.ITERATIONS, globals.CLASSES_PER_ITER)
globals.ITERATIONS = 1
globals.CLASSES_PER_ITER = 10
try:
    run_experiments(n_runs=3, verbose=False, stopOnLoss = 0.02)
finally:
    globals.ITERATIONS, globals.CLASSES_PER_ITER = _saved_task_split

Starting run 1.


Experiment Progress: 100%|██████████| 1/1 [03:27<00:00, 207.65s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 1/1 [03:28<00:00, 208.37s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 1/1 [03:33<00:00, 213.33s/it]


accuracy [0.9875, 0.9905, 0.9917]
Mean accuracy across 3 runs: 0.9899
Standard deviation of accuracy across 3 runs: 0.0021633307652783864

total confusion [0.20088799999999996, 0.18580180000000002, 0.19334340000000005]
Mean total confusion across 3 runs: 0.1933444
Standard deviation of total confusion across 3 runs: 0.0075431000497142775

intra-phase confusion [0.0, 0.0, 0.0]
Mean intra-phase confusion across 3 runs: 0.0
Standard deviation of intra-phase confusion across 3 runs: 0.0

per-task confusion [0.20088799999999996, 0.18580180000000002, 0.19334340000000005]
Mean per-task confusion across 3 runs: 0.1933444
Standard deviation of per-task confusion across 3 runs: 0.0075431000497142775

embedding drift [7.900726814114023e-06, 7.735938197583891e-06, 7.888202162575908e-06]
Mean embedding drift across 3 runs: 7.841622391424607e-06
Standard deviation of embedding drift across 3 runs: 9.173918652821886e-08

attention drift [1.453547069852216e-20, 1.6469810920513368e-20, 1.82196333463402

In [18]:
# Hand cached-but-unused GPU memory back to the driver after the experiments
# (a no-op when CUDA is not in use).
if torch.cuda.is_available():
    torch.cuda.empty_cache()