In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

import sys
sys.path.append('./../src')
import globals
from model import Net
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [17]:
# Notebook-level shortcuts to the shared configuration held in the `globals` module.
# These are snapshots taken when this cell runs: code that later rebinds
# `globals.trainset`, `globals.trainloaders`, etc. (as initialize_data() does)
# does NOT update the names below, so re-run this cell if fresh values are needed.
SEED, DEVICE = globals.SEED, globals.DEVICE

# Datasets.
full_trainset, trainset, testset = (
    globals.full_trainset,
    globals.trainset,
    globals.testset,
)

# Per-task data loaders.
trainloaders, valloaders, testloaders = (
    globals.trainloaders,
    globals.valloaders,
    globals.testloaders,
)

In [18]:
# Data preparation for the convolutional neural network is designed
# as a two-step process.

# The first step converts Python Image Library (PIL) images
# to PyTorch tensors with values in [0,1].

# The second step would normalize the data by specifying a
# mean and standard deviation for the single MNIST channel,
# mapping the data from [0,1] to [-1,1]. It is currently
# disabled (the Normalize line below is commented out), so
# the network receives inputs in [0,1].

# Normalization of data should help speed up convergence and
# reduce the chance of vanishing gradients with certain
# activation functions.
def initialize_data():
    """Build the MNIST datasets and the per-task data loaders, storing them on ``globals``.

    Side effects (all on the ``globals`` module):
      * ``full_trainset`` - the full MNIST training set (downloaded to ``./../data/`` if missing).
      * ``trainset``      - a stratified 99% subset of ``full_trainset``; the remaining
                            1% is used as the validation set.
      * ``testset``       - the MNIST test set.
      * ``trainloaders`` / ``valloaders`` / ``testloaders`` - lists with one DataLoader
        per task, where task ``i`` covers the class labels
        ``i*CLASSES_PER_ITER .. (i+1)*CLASSES_PER_ITER - 1``.

    Reads ``globals.CLASSES_PER_ITER``, ``globals.ITERATIONS`` and ``globals.BATCH_SIZE``;
    the train loaders are built with whatever ``globals.BATCH_SIZE`` is at call time.
    """
    transform = transforms.Compose([
        transforms.ToTensor()
        #transforms.Normalize((0.5,), (0.5,))  # Normalizes to mean 0.5 and std 0.5 for the single channel
    ])

    globals.full_trainset = torchvision.datasets.MNIST('./../data/', train=True, download=True,
                                transform=transform)
    targets = np.array(globals.full_trainset.targets)

    # Perform stratified split.
    # NOTE(review): no random_state is passed, so the train/validation split differs
    # on every call - confirm whether it should be tied to globals.SEED.
    train_indices, val_indices = train_test_split(
        np.arange(len(targets)),
        test_size=0.01,
        stratify=targets
    )

    # Create subsets
    valset = Subset(globals.full_trainset, val_indices)
    globals.trainset = Subset(globals.full_trainset, train_indices)

    globals.testset = torchvision.datasets.MNIST('./../data/', train=False, download=True,
                                transform=transform)

    # Label of every sample, in the order each dataset yields them. Working on these
    # arrays avoids iterating the datasets themselves, which would decode and
    # transform every image once per task just to read its label.
    train_labels = targets[train_indices]
    val_labels = targets[val_indices]
    test_labels = np.array(globals.testset.targets)

    def _positions_of(labels, classes):
        """Return the ascending positions in `labels` whose value is one of `classes`."""
        return np.flatnonzero(np.isin(labels, classes)).tolist()

    # Define the group of class labels for each task
    class_groups = [tuple(range(i*globals.CLASSES_PER_ITER,(i+1)*globals.CLASSES_PER_ITER)) for i in range(globals.ITERATIONS)]

    # Lists holding one data loader per task
    globals.trainloaders = []
    globals.testloaders = []
    globals.valloaders = []
    for group in class_groups:
        classes = list(group)

        # Indices are positions inside the respective (sub)set, not inside full_trainset.
        train_subset = Subset(globals.trainset, _positions_of(train_labels, classes))
        globals.trainloaders.append(DataLoader(train_subset, batch_size=globals.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers = 0))

        val_subset = Subset(valset, _positions_of(val_labels, classes))
        globals.valloaders.append(DataLoader(val_subset, batch_size=500, shuffle=False))

        test_subset = Subset(globals.testset, _positions_of(test_labels, classes))
        globals.testloaders.append(DataLoader(test_subset, batch_size=500, shuffle=False))


In [19]:
def run_experiment(
        verbose = False,
        stopOnLoss = 0.03,
        full_CE = True,
        with_OOD = False,
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False
        ):
    """Run one full class-incremental experiment and return its final metrics.

    For each of ``globals.ITERATIONS`` tasks a new ``Net`` is created whose output
    size grows with the task index, the previous model's state is copied into it via
    ``model.copyPrev``, and it is trained with ``train_model`` (first task) or
    ``train_model_CL`` (later tasks). After the last task (or after every task when
    ``verbose``) the model is evaluated on the test loaders of all tasks seen so far.

    Args:
        verbose: print progress/metrics, evaluate after every task and plot the
            confusion matrix.
        stopOnLoss: forwarded to ``train_model`` / ``train_model_CL``.
        full_CE: forwarded to ``train_model_CL``.
        with_OOD: when True, ``globals.OOD_CLASS`` is set to 1, which adds one extra
            output per task to the network; those outputs are dropped again before
            the overall prediction is computed.
        kd_loss: forwarded to ``train_model_CL``.
        stopOnValAcc: forwarded to ``train_model_CL``.
        epochs: forwarded to both training functions.
        with_dropout: stored in ``globals.WITH_DROPOUT`` before the models are built.

    Returns:
        Tuple ``(accuracy, total_confusion, intra_phase_confusion,
        per_task_confusion, embedding_drift, avg_shap_val)`` from the evaluation
        after the final task. ``accuracy`` is a fraction in [0, 1].

    Side effects: overwrites ``globals.OOD_CLASS``, ``globals.BATCH_SIZE``,
    ``globals.WITH_DROPOUT`` and everything ``initialize_data()`` sets.
    """
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    globals.OOD_CLASS = 1 if with_OOD else 0
    # The batch size must be set BEFORE initialize_data(): the train loaders read
    # globals.BATCH_SIZE when they are constructed, so assigning it afterwards
    # (as was done previously) had no effect on the current experiment.
    globals.BATCH_SIZE = 4
    globals.WITH_DROPOUT = with_dropout
    initialize_data()
    prevModel = None

    # Number of network outputs contributed by each task (classes + optional OOD output).
    outputs_per_task = globals.CLASSES_PER_ITER + globals.OOD_CLASS

    # Feature-attribution bookkeeping across tasks.
    Feature_Importance_Eval=Feature_Importance_Evaluations(globals.valloaders, DEVICE)

    for i in tqdm(range(globals.ITERATIONS), desc="Experiment Progress"):
        model = Net((i+1)*outputs_per_task)
        if prevModel is not None:
            with torch.no_grad():
                model.copyPrev(prevModel)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        if prevModel is not None:
            _print("CL TRAIN!!")
            train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                freeze_nonzero_params=False,
                l1_loss=0,
                ewc_loss=0,
                kd_loss=kd_loss,
                distance_loss=0,
                center_loss=0,
                param_reuse_loss=0,
                stopOnLoss=stopOnLoss,
                stopOnValAcc = stopOnValAcc,
                full_CE=full_CE
                )
        else:
            train_model(
                model,
                train_loader,
                val_loader,
                verbose,
                epochs=epochs,
                l1_loss=0,
                stopOnLoss=stopOnLoss,
                center_loss =0,
                )

        # Record feature attributions of the freshly trained model for task i.
        Feature_Importance_Eval.Task_Feature_Attribution(model, i)

        # Always evaluate after the final task; intermediate tasks only when verbose.
        if verbose or i == globals.ITERATIONS-1:
            _print("Starting evaluation")
            _print("ITERATION", i+1)
            _print("ACCURACIES PER TASK:")
            accumPred = []
            all_labels = []
            all_embeddings = []
            with torch.no_grad():
                model.eval()
                for j in range(i+1):
                    test_loader = globals.testloaders[j]
                    test_labels = get_labels(test_loader).to(DEVICE)
                    all_labels.append(test_labels)
                    pred, embeddings = model.get_pred_and_embeddings((get_features(test_loader).to(DEVICE)))
                    accumPred.append(pred)
                    all_embeddings.append(embeddings)
                    # Per-task accuracy: only the outputs belonging to task j compete.
                    sliced_pred = pred[:, j*outputs_per_task:(j+1)*outputs_per_task]
                    _, task_pred = torch.max(sliced_pred, 1)  # Position of the winning output within the slice
                    task_pred += j*globals.CLASSES_PER_ITER   # Shift to the task's first class label
                    task_correct = (task_pred == test_labels).sum().item()
                    task_accuracy = task_correct / test_labels.size(0)  # Fraction in [0, 1]
                    _print(str(task_accuracy), end=' ')
                model.train()
            accumPred = torch.cat(accumPred)
            all_labels = torch.cat(all_labels)
            all_embeddings = torch.cat(all_embeddings)

            # Overall prediction across all outputs seen so far. With an OOD output
            # per task, every (CLASSES_PER_ITER+1)-th column is an OOD output and is
            # dropped, so the remaining column positions coincide with class labels.
            if globals.OOD_CLASS == 1:
                class_columns = [c for c in range(accumPred.size(1)) if (c + 1) % (globals.CLASSES_PER_ITER+1) != 0]
                class_logits = accumPred[:, class_columns]
            else:
                class_logits = accumPred
            # argmax over logits equals argmax over their softmax, so the softmax is
            # skipped; ties resolve to the first maximal index.
            predicted = torch.argmax(class_logits, dim=1).to(DEVICE)
            correct = (predicted == all_labels).sum().item()  # Count how many were correct
            accuracy = correct / all_labels.size(0)  # Fraction in [0, 1]
            _print("Accuracy on tasks so far:", accuracy)

            embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
            _print("Average embedding drift based on centroids:", embedding_drift)
            total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 1000, task=i+1)
            _print("Total confusion", total_confusion)
            _print("Intra-phase confusion", intra_phase_confusion)
            _print("Per task confusions", per_task_confusion)
            if verbose:
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(globals.CLASSES_PER_ITER*(i+1))))
        # Keep an independent copy so the next task's training cannot modify it.
        prevModel = copy.deepcopy(model)

    # Feature-change scores are computed against the final model.
    avg_shap_val,shap_vals,avg_entropy_val,entropy_vals=Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", shap_vals)
    _print("Averaged SHAPC value (the smaller the better):", avg_shap_val)
    _print("Average entropy values (ordered as tasks):", entropy_vals)
    _print("Averaged entropy value (the bigger the better):", avg_entropy_val)
    # Saves saliency maps for one example per class (first row: image, second row:
    # saliency map after training, third row: saliency map after training the task
    # in which the class is included).
    Feature_Importance_Eval.Save_Random_Picture_Salency()

    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_shap_val

In [20]:
def run_experiments(n_runs=1, *args, **kwargs):
    """Repeat ``run_experiment`` ``n_runs`` times and print mean/std of each metric.

    Args:
        n_runs: number of independent repetitions.
        *args, **kwargs: forwarded unchanged to ``run_experiment``. A ``verbose``
            keyword additionally enables the per-run "finished" message here.

    Returns:
        None. For every metric returned by ``run_experiment`` (accuracy, total
        confusion, intra-phase confusion, per-task confusion, embedding drift,
        SHAP value) the mean and the sample standard deviation across runs are
        printed. With a single run the sample standard deviation is undefined and
        is reported as ``nan`` instead of raising ``statistics.StatisticsError``.
    """
    verbose = kwargs.get('verbose', None)
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    # Display labels, in the same order as the tuple returned by run_experiment.
    metric_labels = (
        "accuracy",
        "total confusion",
        "intra-phase confusion",
        "per-task confusion",
        "embedding drift",
        "SHAP values",
    )
    samples = {label: [] for label in metric_labels}

    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        results = run_experiment(*args, **kwargs)
        for label, value in zip(metric_labels, results):
            samples[label].append(value)
        # 1-based, matching the "Starting run" message above.
        _print(f"Run {r+1} finished with accuracy {results[0]}")

    # Compute every statistic first so a failure prints nothing half-finished.
    summary = []
    for label in metric_labels:
        values = samples[label]
        mean_value = statistics.mean(values)
        # statistics.stdev needs at least two data points.
        std_value = statistics.stdev(values) if len(values) > 1 else float("nan")
        summary.append((label, mean_value, std_value))

    # Print all results
    for label, mean_value, std_value in summary:
        print(f"Mean {label} across {n_runs} runs: {mean_value}")
        print(f"Standard deviation of {label} across {n_runs} runs: {std_value}\n")

In [6]:
# Baseline: every run_experiment option at its default (full_CE=True, with_OOD=False,
# kd_loss=0, with_dropout=False); only stopOnLoss is lowered from 0.03 to 0.02.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:03<00:00, 24.66s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:03<00:00, 24.68s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.86s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [01:59<00:00, 23.91s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [01:55<00:00, 23.05s/it]


Mean accuracy across 5 runs: 0.19712
Standard deviation of accuracy across 5 runs: 0.00017888543819998597

Mean total confusion across 5 runs: 0.41027318
Standard deviation of total confusion across 5 runs: 0.01668874668562622

Mean intra-phase confusion across 5 runs: 0.40326086
Standard deviation of intra-phase confusion across 5 runs: 0.015929467483346702

Mean per-task confusion across 5 runs: 0.07731184684362342
Standard deviation of per-task confusion across 5 runs: 0.007665591813330528

Mean embedding drift across 5 runs: 9.901034545898437
Standard deviation of embedding drift across 5 runs: 0.5234171887956153

Mean SHAP values across 5 runs: 0.0008863337079308142
Standard deviation of SHAP values across 5 runs: 7.095100737754766e-05



In [7]:
# Baseline settings plus with_dropout=True (sets globals.WITH_DROPOUT before the models are built).
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, with_dropout=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:11<00:00, 38.23s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:01<00:00, 36.37s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:06<00:00, 37.22s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [03:15<00:00, 39.11s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:19<00:00, 39.89s/it]


Mean accuracy across 5 runs: 0.1977
Standard deviation of accuracy across 5 runs: 0.0003162277660168448

Mean total confusion across 5 runs: 0.4342616
Standard deviation of total confusion across 5 runs: 0.006836191673366111

Mean intra-phase confusion across 5 runs: 0.42464177999999997
Standard deviation of intra-phase confusion across 5 runs: 0.008322003001201067

Mean per-task confusion across 5 runs: 0.11041337647890548
Standard deviation of per-task confusion across 5 runs: 0.006085887620899638

Mean embedding drift across 5 runs: 8.4266357421875
Standard deviation of embedding drift across 5 runs: 0.32925923782169186

Mean SHAP values across 5 runs: 0.0015679794807187378
Standard deviation of SHAP values across 5 runs: 0.00014219603657457793



In [8]:
# Baseline settings plus kd_loss=1, which is forwarded to train_model_CL for every task after the first.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, kd_loss=1)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:42<00:00, 32.45s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:46<00:00, 33.28s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:25<00:00, 29.18s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:35<00:00, 31.01s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:37<00:00, 31.42s/it]


Mean accuracy across 5 runs: 0.20788
Standard deviation of accuracy across 5 runs: 0.009331237860005502

Mean total confusion across 5 runs: 0.38565748
Standard deviation of total confusion across 5 runs: 0.013362716836107832

Mean intra-phase confusion across 5 runs: 0.38222578
Standard deviation of intra-phase confusion across 5 runs: 0.012737244903706618

Mean per-task confusion across 5 runs: 0.06670221160245901
Standard deviation of per-task confusion across 5 runs: 0.004874366234574275

Mean embedding drift across 5 runs: 10.33856201171875
Standard deviation of embedding drift across 5 runs: 0.32364676735838166

Mean SHAP values across 5 runs: 0.0009319879324541062
Standard deviation of SHAP values across 5 runs: 7.25542854255561e-05



In [9]:
# Baseline settings but with full_CE=False (forwarded to train_model_CL).
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE=False)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.99s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [01:56<00:00, 23.33s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:00<00:00, 24.06s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.82s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [01:56<00:00, 23.20s/it]


Mean accuracy across 5 runs: 0.65862
Standard deviation of accuracy across 5 runs: 0.038600867865891335

Mean total confusion across 5 runs: 0.41970142
Standard deviation of total confusion across 5 runs: 0.011164848655803633

Mean intra-phase confusion across 5 runs: 0.41735474
Standard deviation of intra-phase confusion across 5 runs: 0.011159116510414256

Mean per-task confusion across 5 runs: 0.06622207210816197
Standard deviation of per-task confusion across 5 runs: 0.004421540729752876

Mean embedding drift across 5 runs: 5.348196601867675
Standard deviation of embedding drift across 5 runs: 0.13468180027021937

Mean SHAP values across 5 runs: 0.0005514491549105287
Standard deviation of SHAP values across 5 runs: 1.129900453609687e-05



In [10]:
# Baseline settings plus with_OOD=True: globals.OOD_CLASS becomes 1, adding one extra network output per task.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, with_OOD=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:10<00:00, 38.14s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:00<00:00, 36.12s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:00<00:00, 36.16s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:59<00:00, 35.87s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:52<00:00, 34.44s/it]


Mean accuracy across 5 runs: 0.20118
Standard deviation of accuracy across 5 runs: 0.0050967636790418295

Mean total confusion across 5 runs: 0.3520012
Standard deviation of total confusion across 5 runs: 0.01887235752734672

Mean intra-phase confusion across 5 runs: 0.34562892
Standard deviation of intra-phase confusion across 5 runs: 0.018263727836260575

Mean per-task confusion across 5 runs: 0.07485453104021554
Standard deviation of per-task confusion across 5 runs: 0.00806487088908623

Mean embedding drift across 5 runs: 10.17542839050293
Standard deviation of embedding drift across 5 runs: 0.5715770296552717

Mean SHAP values across 5 runs: 0.00046121513314196236
Standard deviation of SHAP values across 5 runs: 3.328521156119207e-05



In [11]:
# Combination: full_CE=False together with the extra OOD output per task (with_OOD=True).
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, with_OOD=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:56<00:00, 35.21s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:57<00:00, 35.42s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:03<00:00, 36.78s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [03:05<00:00, 37.09s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:09<00:00, 37.83s/it]


Mean accuracy across 5 runs: 0.76278
Standard deviation of accuracy across 5 runs: 0.04507030064244082

Mean total confusion across 5 runs: 0.3381334
Standard deviation of total confusion across 5 runs: 0.011750315672142611

Mean intra-phase confusion across 5 runs: 0.3351177
Standard deviation of intra-phase confusion across 5 runs: 0.011500446784364496

Mean per-task confusion across 5 runs: 0.05976915174421091
Standard deviation of per-task confusion across 5 runs: 0.006660504678722782

Mean embedding drift across 5 runs: 6.203188705444336
Standard deviation of embedding drift across 5 runs: 0.38552032661363195

Mean SHAP values across 5 runs: 0.0009006374135017964
Standard deviation of SHAP values across 5 runs: 6.116743572649732e-05



In [12]:
# Combination: full_CE=False, kd_loss=1 and the extra OOD output per task (with_OOD=True).
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1, with_OOD=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:23<00:00, 40.64s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:35<00:00, 43.08s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:36<00:00, 43.28s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [03:28<00:00, 41.69s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:45<00:00, 45.12s/it]


Mean accuracy across 5 runs: 0.84322
Standard deviation of accuracy across 5 runs: 0.016848204652128385

Mean total confusion across 5 runs: 0.33744808000000004
Standard deviation of total confusion across 5 runs: 0.013354935730919852

Mean intra-phase confusion across 5 runs: 0.33412914
Standard deviation of intra-phase confusion across 5 runs: 0.013293854543096234

Mean per-task confusion across 5 runs: 0.06099390556090557
Standard deviation of per-task confusion across 5 runs: 0.00274550748986781

Mean embedding drift across 5 runs: 5.186639595031738
Standard deviation of embedding drift across 5 runs: 0.259908944499797

Mean SHAP values across 5 runs: 0.0004815936748197613
Standard deviation of SHAP values across 5 runs: 4.3243486941720384e-05



In [13]:
# Combination: full_CE=False and kd_loss=1, without OOD outputs or dropout.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:19<00:00, 27.92s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:19<00:00, 27.93s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:21<00:00, 28.30s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:17<00:00, 27.59s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:41<00:00, 32.37s/it]


Mean accuracy across 5 runs: 0.7285200000000001
Standard deviation of accuracy across 5 runs: 0.041093576140316605

Mean total confusion across 5 runs: 0.39564756
Standard deviation of total confusion across 5 runs: 0.008445913656200883

Mean intra-phase confusion across 5 runs: 0.39362472
Standard deviation of intra-phase confusion across 5 runs: 0.008372517833184951

Mean per-task confusion across 5 runs: 0.05749871795516995
Standard deviation of per-task confusion across 5 runs: 0.0029189804114646426

Mean embedding drift across 5 runs: 4.801405525207519
Standard deviation of embedding drift across 5 runs: 0.11469332839415294

Mean SHAP values across 5 runs: 0.0004116192882103642
Standard deviation of SHAP values across 5 runs: 3.6677832461491065e-05



In [14]:
# Combination: full_CE=False, kd_loss=1 and with_dropout=True (no OOD outputs).
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1, with_dropout=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:49<00:00, 45.95s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [04:03<00:00, 48.67s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [04:02<00:00, 48.47s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [04:01<00:00, 48.31s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [04:03<00:00, 48.67s/it]


Mean accuracy across 5 runs: 0.77592
Standard deviation of accuracy across 5 runs: 0.019157557255558455

Mean total confusion across 5 runs: 0.39213982
Standard deviation of total confusion across 5 runs: 0.005603796637726973

Mean intra-phase confusion across 5 runs: 0.39050850000000004
Standard deviation of intra-phase confusion across 5 runs: 0.005608870130873021

Mean per-task confusion across 5 runs: 0.056725291065110366
Standard deviation of per-task confusion across 5 runs: 0.0015721690216526512

Mean embedding drift across 5 runs: 4.493971157073974
Standard deviation of embedding drift across 5 runs: 0.24495078220048352

Mean SHAP values across 5 runs: 0.0005218379968121305
Standard deviation of SHAP values across 5 runs: 5.0939184468199486e-05



In [15]:
# All options combined: full_CE=False, kd_loss=1, with_OOD=True and with_dropout=True.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1, with_OOD=True, with_dropout=True)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [08:29<00:00, 101.99s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [08:07<00:00, 97.56s/it] 


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [06:18<00:00, 75.67s/it] 


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [07:24<00:00, 88.86s/it] 


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [07:53<00:00, 94.65s/it] 


Mean accuracy across 5 runs: 0.8506
Standard deviation of accuracy across 5 runs: 0.03785941098326809

Mean total confusion across 5 runs: 0.32519064
Standard deviation of total confusion across 5 runs: 0.01114584482320656

Mean intra-phase confusion across 5 runs: 0.32300544
Standard deviation of intra-phase confusion across 5 runs: 0.0111097349371171

Mean per-task confusion across 5 runs: 0.050529086045436974
Standard deviation of per-task confusion across 5 runs: 0.002420955818193815

Mean embedding drift across 5 runs: 4.840276145935059
Standard deviation of embedding drift across 5 runs: 0.27403502367552246

Mean SHAP values across 5 runs: 0.0005207426627989993
Standard deviation of SHAP values across 5 runs: 4.3099602605594366e-05



In [21]:
# Training on the entire dataset in a single task.
# With ITERATIONS = 1 and CLASSES_PER_ITER = 10, initialize_data() builds one class
# group covering labels 0-9, so only train_model runs (prevModel stays None and
# train_model_CL is never called).
# NOTE: these assignments mutate the shared `globals` module and are not restored
# afterwards; any experiment cell executed after this one sees the single-task
# configuration until the values are reset.
globals.ITERATIONS = 1
globals.CLASSES_PER_ITER = 10
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02)

Starting run 1.


Experiment Progress: 100%|██████████| 1/1 [04:32<00:00, 272.77s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 1/1 [03:24<00:00, 204.35s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 1/1 [03:19<00:00, 199.88s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 1/1 [03:09<00:00, 189.94s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 1/1 [03:17<00:00, 197.87s/it]


Mean accuracy across 5 runs: 0.99084
Standard deviation of accuracy across 5 runs: 0.000684105255059517

Mean total confusion across 5 runs: 0.19484374
Standard deviation of total confusion across 5 runs: 0.009626986084595703

Mean intra-phase confusion across 5 runs: 0.0
Standard deviation of intra-phase confusion across 5 runs: 0.0

Mean per-task confusion across 5 runs: 0.19484374
Standard deviation of per-task confusion across 5 runs: 0.009626986084595703

Mean embedding drift across 5 runs: 7.962159907037858e-06
Standard deviation of embedding drift across 5 runs: 5.002824389924782e-07

Mean SHAP values across 5 runs: 0.0031779530191561206
Standard deviation of SHAP values across 5 runs: 0.00011392721770175777

