In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

import sys
sys.path.append('./../src')
import globals
from model import Net
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [2]:
# Notebook-level aliases for configuration and state held on the `globals` module.
# NOTE: these names are bound once, when this cell runs. initialize_data()
# later rebinds the dataset/loader attributes on `globals`, and these aliases
# do not follow that rebinding -- read `globals.<name>` for the current value.
SEED, DEVICE = globals.SEED, globals.DEVICE
full_trainset, trainset, testset = (
    globals.full_trainset,
    globals.trainset,
    globals.testset,
)
trainloaders, valloaders, testloaders = (
    globals.trainloaders,
    globals.valloaders,
    globals.testloaders,
)

In [3]:
# This is the two-step process used to prepare the
# data for use with the convolutional neural network.

# First step is to convert Python Imaging Library (PIL) images
# to PyTorch tensors.

# Second step (currently disabled: the Normalize transform below is
# commented out) normalizes the data by specifying a mean and standard
# deviation for the single grayscale MNIST channel.
# With mean 0.5 and std 0.5 this would convert the data from [0,1] to [-1,1].

# Normalization of data should help speed up convergence and
# reduce the chance of vanishing gradients with certain
# activation functions.
def initialize_data():
    """Load MNIST and (re)build the per-task data loaders stored on ``globals``.

    Reads ``globals.CLASSES_PER_ITER``, ``globals.ITERATIONS`` and
    ``globals.BATCH_SIZE``; the latter must therefore be set *before* calling
    this function for it to take effect on the training loaders.

    Side effects (all on the ``globals`` module):
      * ``full_trainset`` / ``testset``: the MNIST train and test datasets.
      * ``trainset``: 99% of the training data from a stratified split; the
        remaining 1% is used as validation data.
      * ``trainloaders`` / ``valloaders`` / ``testloaders``: one DataLoader per
        task, where task ``i`` contains the classes
        ``i*CLASSES_PER_ITER .. (i+1)*CLASSES_PER_ITER - 1``.

    The stratified split is not seeded (no ``random_state`` is passed), so each
    call draws a new train/validation partition.
    """
    transform = transforms.Compose([
        transforms.ToTensor()
        #transforms.Normalize((0.5,), (0.5,))  # Normalizes to mean 0.5 and std 0.5 for the single channel
    ])

    globals.full_trainset = torchvision.datasets.MNIST('./../data/', train=True, download=True,
                                transform=transform)
    targets = np.asarray(globals.full_trainset.targets)

    # Perform stratified split
    train_indices, val_indices = train_test_split(
        np.arange(len(targets)),
        test_size=0.01,
        stratify=targets
    )

    # Create subsets
    valset = Subset(globals.full_trainset, val_indices)
    globals.trainset = Subset(globals.full_trainset, train_indices)

    globals.testset = torchvision.datasets.MNIST('./../data/', train=False, download=True,
                                transform=transform)

    # Labels in subset order. Reading them from the `targets` arrays avoids
    # iterating the datasets themselves, which would load and transform every
    # image once per task just to look at its label.
    train_labels = targets[train_indices]
    val_labels = targets[val_indices]
    test_labels = np.asarray(globals.testset.targets)

    # Classes belonging to each task
    class_groups = [
        list(range(task * globals.CLASSES_PER_ITER, (task + 1) * globals.CLASSES_PER_ITER))
        for task in range(globals.ITERATIONS)
    ]

    # One loader per task for each of the three splits
    globals.trainloaders = []
    globals.testloaders = []
    globals.valloaders = []

    for group in class_groups:
        # Positions (within each split) of the samples whose label is in this task
        train_positions = np.flatnonzero(np.isin(train_labels, group)).tolist()
        val_positions = np.flatnonzero(np.isin(val_labels, group)).tolist()
        test_positions = np.flatnonzero(np.isin(test_labels, group)).tolist()

        train_subset = Subset(globals.trainset, train_positions)
        globals.trainloaders.append(DataLoader(train_subset, batch_size=globals.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers = 0))

        val_subset = Subset(valset, val_positions)
        globals.valloaders.append(DataLoader(val_subset, batch_size=500, shuffle=False))

        test_subset = Subset(globals.testset, test_positions)
        globals.testloaders.append(DataLoader(test_subset, batch_size=500, shuffle=False))


In [4]:
def _predict_known_classes(logits, classes_per_iter, has_ood_slot):
    """Return the arg-max class index for every row of ``logits``.

    When ``has_ood_slot`` is true, every ``(classes_per_iter + 1)``-th output
    column is dropped before the arg-max, so the returned indices count only
    the remaining columns. Ties resolve to the first maximal column.
    """
    if has_ood_slot:
        keep = [col for col in range(logits.size(1)) if (col + 1) % (classes_per_iter + 1) != 0]
        logits = logits[:, keep]
    return torch.softmax(logits, dim=-1).argmax(dim=-1)


def run_experiment(
        verbose = False,
        stopOnLoss = 0.03,
        full_CE = True,
        with_OOD = False,
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False,
        ogd = False
        ):
    """Train one model per task in sequence and evaluate the final model.

    For each of ``globals.ITERATIONS`` tasks a new ``Net`` is created with
    ``(task + 1) * (CLASSES_PER_ITER + OOD_CLASS)`` outputs. The first task is
    trained with ``train_model``; later tasks copy the previous model via
    ``model.copyPrev`` and are trained with ``train_model_CL``.

    Args:
        verbose: print progress/metrics and evaluate (and plot a confusion
            matrix) after every task instead of only after the last one.
        stopOnLoss, stopOnValAcc, epochs, full_CE, kd_loss, ogd: forwarded
            unchanged to the training functions.
        with_OOD: sets ``globals.OOD_CLASS`` to 1 (one extra output per task)
            or 0.
        with_dropout: stored in ``globals.WITH_DROPOUT``.

    Returns:
        ``(accuracy, total_confusion, intra_phase_confusion,
        per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread)``
        measured after the last task. These are only assigned during
        evaluation, so ``globals.ITERATIONS`` must be at least 1.
    """
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    # Configure the shared state *before* building the loaders:
    # initialize_data() reads globals.BATCH_SIZE for the training loaders.
    globals.OOD_CLASS = 1 if with_OOD else 0
    globals.BATCH_SIZE = 4
    globals.WITH_DROPOUT = with_dropout
    initialize_data()

    prevModel = None

    #[Denis] added code:
    Feature_Importance_Eval = Feature_Importance_Evaluations(globals.testloaders, DEVICE)

    for i in tqdm(range(globals.ITERATIONS), desc="Experiment Progress"):
        outputs_per_task = globals.CLASSES_PER_ITER + globals.OOD_CLASS
        model = Net((i+1)*outputs_per_task)
        if prevModel is not None:
            with torch.no_grad():
                model.copyPrev(prevModel)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        if prevModel is not None:
            _print("CL TRAIN!!")
            train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                freeze_nonzero_params=False,
                l1_loss=0,
                ewc_loss=0,
                kd_loss=kd_loss,
                distance_loss=0,
                center_loss=0,
                param_reuse_loss=0,
                stopOnLoss=stopOnLoss,
                stopOnValAcc = stopOnValAcc,
                full_CE=full_CE,
                ogd=ogd,
                )
        else:
            train_model(
                model,
                train_loader,
                val_loader,
                verbose,
                epochs=epochs,
                l1_loss=0,
                stopOnLoss=stopOnLoss,
                center_loss =0,
                ogd=ogd
                )

        #[Denis] added code:
        Feature_Importance_Eval.Task_Feature_Attribution(model, i)

        # Evaluate after every task when verbose, and always after the last
        # task (its metrics are the function's return value).
        if verbose or i == globals.ITERATIONS-1:
            _print("Starting evaluation")
            _print("ITERATION", i+1)
            _print("ACCURACIES PER TASK:")
            accumPred = []
            all_labels = []
            all_embeddings = []
            model.eval()
            with torch.no_grad():
                for j in range(i+1):
                    test_loader = globals.testloaders[j]
                    task_labels = get_labels(test_loader).to(DEVICE)
                    all_labels.append(task_labels)
                    pred, embeddings = model.get_pred_and_embeddings((get_features(test_loader).to(DEVICE)))
                    accumPred.append(pred)
                    all_embeddings.append(embeddings)
                    # Per-task accuracy: only look at the outputs that belong to task j
                    sliced_pred = pred[:, j*outputs_per_task:(j+1)*outputs_per_task]
                    task_predicted = sliced_pred.argmax(dim=1) + j*globals.CLASSES_PER_ITER
                    task_correct = (task_predicted == task_labels).sum().item()
                    task_accuracy = task_correct / task_labels.size(0)
                    _print(str(task_accuracy), end=' ')
            model.train()
            accumPred = torch.cat(accumPred)
            all_labels = torch.cat(all_labels)
            all_embeddings = torch.cat(all_embeddings)

            # Overall accuracy: arg-max over all class outputs seen so far
            predicted = _predict_known_classes(
                accumPred, globals.CLASSES_PER_ITER, globals.OOD_CLASS == 1
            ).to(DEVICE)
            correct = (predicted == all_labels).sum().item()  # Count how many were correct
            accuracy = correct / all_labels.size(0)  # Fraction of correct predictions
            _print("Accuracy on tasks so far:", accuracy)

            embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
            _print("Average embedding drift based on centroids:", embedding_drift)
            total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 1000, task=i+1)
            _print("Total confusion", total_confusion)
            _print("Intra-phase confusion", intra_phase_confusion)
            _print("Per task confusions", per_task_confusion)
            if verbose:
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(globals.CLASSES_PER_ITER*(i+1))))
        prevModel = copy.deepcopy(model)

    #[Denis] added code:
    [avg_att_diff,att_diffs,_,_,avg_att_spread,att_spreads]=Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", att_diffs)
    _print("Averaged SHAPC value (the smaller the better):", avg_att_diff)
    _print("Average attention spread values (ordered as tasks):", att_spreads)
    _print("Averaged attention spread value (the bigger the better):", avg_att_spread)
    # Saliency maps for one example per class (first row: image, second row:
    # saliency map after training, third row: saliency map after training the
    # task that includes the class).
    Feature_Importance_Eval.Save_Random_Picture_Salency()

    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread

In [5]:
def run_experiments(n_runs=1, *args, **kwargs):
    """Run ``run_experiment`` ``n_runs`` times and print summary statistics.

    Args:
        n_runs: number of independent repetitions.
        *args, **kwargs: forwarded unchanged to ``run_experiment``. A
            ``verbose`` keyword additionally enables the per-run summary line.

    For every metric returned by ``run_experiment`` the per-run values, their
    mean and their sample standard deviation are printed. With a single run
    the standard deviation is undefined and reported as ``nan``; with zero
    runs only the empty value lists are printed. Returns ``None``.
    """
    verbose = kwargs.get('verbose', None)

    def _print(*print_args, **print_kwargs):
        if verbose:
            print(*print_args, **print_kwargs)

    def report_stats(data, name):
        print(name, data)
        if not data:
            # statistics.mean() raises on empty input; nothing to summarize.
            print(f"No runs to summarize for {name}.\n")
            return
        mean = statistics.mean(data)
        # statistics.stdev() needs at least two data points.
        std = statistics.stdev(data) if len(data) > 1 else float("nan")
        print(f"Mean {name} across {n_runs} runs: {mean}")
        print(f"Standard deviation of {name} across {n_runs} runs: {std}\n")

    accuracies = []
    total_confusions = []
    intra_phase_confusions = []
    per_task_confusions = []
    att_diffs = []
    embedding_drifts = []
    att_spreads = []
    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        (accuracy, total_confusion, intra_phase_confusion, per_task_confusion,
         embedding_drift, avg_att_diff, avg_att_spread) = run_experiment(*args, **kwargs)
        accuracies.append(accuracy)
        total_confusions.append(total_confusion)
        intra_phase_confusions.append(intra_phase_confusion)
        per_task_confusions.append(per_task_confusion)
        att_diffs.append(avg_att_diff)
        embedding_drifts.append(embedding_drift)
        att_spreads.append(avg_att_spread)
        # 1-based, matching the "Starting run" message above
        _print(f"Run {r+1} finished with accuracy {accuracy}")

    report_stats(accuracies, "accuracy")
    report_stats(total_confusions, "total confusion")
    report_stats(intra_phase_confusions, "intra-phase confusion")
    report_stats(per_task_confusions, "per-task confusion")
    report_stats(embedding_drifts, "embedding drift")
    report_stats(att_diffs, "attention drift")
    report_stats(att_spreads, "attention spread")

In [6]:
# Baseline: all run_experiment defaults (full_CE=True, no KD, no OOD output,
# no dropout, no OGD), stopping at a loss of 0.02.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:16<00:00, 27.28s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.99s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:09<00:00, 25.84s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:01<00:00, 24.21s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:03<00:00, 24.74s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [02:06<00:00, 25.31s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [02:00<00:00, 24.12s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [02:09<00:00, 25.85s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [02:05<00:00, 25.15s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [01:59<00:00, 23.99s/it]


accuracy [0.1977, 0.1965, 0.1975, 0.1964, 0.197, 0.1976, 0.1967, 0.1974, 0.1974, 0.1974]
Mean accuracy across 10 runs: 0.19716
Standard deviation of accuracy across 10 runs: 0.0004742245131674291

total confusion [0.44943299999999997, 0.4149712, 0.3875702, 0.42745409999999995, 0.38913489999999995, 0.41174409999999995, 0.40658150000000004, 0.39565300000000003, 0.41054749999999995, 0.4281956]
Mean total confusion across 10 runs: 0.41212851
Standard deviation of total confusion across 10 runs: 0.019233670576965334

intra-phase confusion [0.44132229999999995, 0.4064647, 0.3812308, 0.4186225, 0.38446210000000003, 0.40449710000000005, 0.40128909999999995, 0.38833660000000003, 0.40281809999999996, 0.42079659999999997]
Mean intra-phase confusion across 10 runs: 0.40498399
Standard deviation of intra-phase confusion across 10 runs: 0.01837809933779574

per-task confusion [0.09558462092663422, 0.08496075659449494, 0.08095636096199192, 0.08522670286473828, 0.07100789869842554, 0.0841048491297574,

In [7]:
# Baseline settings plus dropout.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    with_dropout=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:12<00:00, 38.54s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:06<00:00, 37.39s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:20<00:00, 40.06s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:58<00:00, 35.73s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:10<00:00, 38.04s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [03:11<00:00, 38.27s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [03:14<00:00, 38.84s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [03:29<00:00, 41.81s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [03:23<00:00, 40.77s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [03:23<00:00, 40.74s/it]


accuracy [0.1975, 0.1975, 0.1975, 0.1979, 0.1978, 0.198, 0.1978, 0.1979, 0.1978, 0.1977]
Mean accuracy across 10 runs: 0.19774
Standard deviation of accuracy across 10 runs: 0.00018378731669453315

total confusion [0.43969480000000005, 0.44410689999999997, 0.4666378, 0.44546759999999996, 0.4278931, 0.41982030000000004, 0.43395819999999996, 0.443643, 0.46588229999999997, 0.46773560000000003]
Mean total confusion across 10 runs: 0.44548396
Standard deviation of total confusion across 10 runs: 0.016661391555275986

intra-phase confusion [0.43142040000000004, 0.4318543, 0.45850729999999995, 0.4379881, 0.4203555, 0.41033739999999996, 0.4232504, 0.43658889999999995, 0.45161530000000005, 0.45781629999999995]
Mean intra-phase confusion across 10 runs: 0.43597339
Standard deviation of intra-phase confusion across 10 runs: 0.01610621219277759

per-task confusion [0.114121261966319, 0.10917685907723804, 0.12291185966004921, 0.10234253975880528, 0.10336357900976638, 0.10573744004378918, 0.11714660

In [8]:
# Baseline settings plus the KD loss term (kd_loss=1).
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    kd_loss=1,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:39<00:00, 31.99s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:53<00:00, 34.64s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:31<00:00, 30.39s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:58<00:00, 35.76s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:26<00:00, 29.29s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [03:03<00:00, 36.65s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [02:51<00:00, 34.37s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [02:30<00:00, 30.03s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [02:26<00:00, 29.22s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [03:00<00:00, 36.19s/it]


accuracy [0.1976, 0.2012, 0.1985, 0.2189, 0.1976, 0.2014, 0.2005, 0.197, 0.2086, 0.1979]
Mean accuracy across 10 runs: 0.20192
Standard deviation of accuracy across 10 runs: 0.006872942116250753

total confusion [0.3634056, 0.38402119999999995, 0.36573409999999995, 0.3987897, 0.3696667, 0.3781171, 0.3943352, 0.39169149999999997, 0.3904193, 0.38495650000000003]
Mean total confusion across 10 runs: 0.38211369
Standard deviation of total confusion across 10 runs: 0.01242343785337037

intra-phase confusion [0.36075310000000005, 0.3808104, 0.3631723, 0.39430960000000004, 0.3670685, 0.37405140000000003, 0.3907419, 0.387694, 0.3869815, 0.378637]
Mean intra-phase confusion across 10 runs: 0.37842197
Standard deviation of intra-phase confusion across 10 runs: 0.011841213139243788

per-task confusion [0.056471947001980195, 0.06295926726266514, 0.05949299955661831, 0.07777199874932487, 0.056508568737939614, 0.06522173312298656, 0.06821445862777578, 0.07023182762181075, 0.064475190532302, 0.082027

In [9]:
# Baseline settings with full_CE disabled.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:03<00:00, 24.62s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [01:52<00:00, 22.58s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [01:59<00:00, 23.90s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [01:55<00:00, 23.17s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [01:51<00:00, 22.32s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [01:54<00:00, 22.92s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [01:57<00:00, 23.46s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [01:58<00:00, 23.63s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [01:53<00:00, 22.80s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [02:03<00:00, 24.75s/it]


accuracy [0.679, 0.6623, 0.6559, 0.6825, 0.6188, 0.6378, 0.6662, 0.6405, 0.6442, 0.6607]
Mean accuracy across 10 runs: 0.65479
Standard deviation of accuracy across 10 runs: 0.019655389422072855

total confusion [0.4103538, 0.42270490000000005, 0.418976, 0.42667330000000003, 0.41352500000000003, 0.43503480000000005, 0.4172456, 0.41843490000000005, 0.41846300000000003, 0.43340599999999996]
Mean total confusion across 10 runs: 0.42148173
Standard deviation of total confusion across 10 runs: 0.00805593106434287

intra-phase confusion [0.4086381, 0.42050770000000004, 0.41685720000000004, 0.42387129999999995, 0.4114116, 0.4329024, 0.4147005, 0.41602340000000004, 0.41618239999999995, 0.4305721]
Mean intra-phase confusion across 10 runs: 0.41916667
Standard deviation of intra-phase confusion across 10 runs: 0.007872203234172193

per-task confusion [0.060793760203501315, 0.06946660769905703, 0.0645575213698262, 0.07054972383357314, 0.06364486251623458, 0.06869555575174766, 0.06626540071799616,

In [10]:
# Baseline settings plus one extra OOD output per task.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    with_OOD=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:20<00:00, 40.06s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:00<00:00, 36.04s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:05<00:00, 37.04s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [03:09<00:00, 37.94s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:07<00:00, 37.49s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [02:57<00:00, 35.59s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [03:07<00:00, 37.47s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [03:07<00:00, 37.42s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [03:08<00:00, 37.78s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [03:01<00:00, 36.29s/it]


accuracy [0.1972, 0.1947, 0.1973, 0.1976, 0.2026, 0.1975, 0.2019, 0.1977, 0.1994, 0.2066]
Mean accuracy across 10 runs: 0.19925
Standard deviation of accuracy across 10 runs: 0.0034830542152924737

total confusion [0.35353219999999996, 0.3224897, 0.3741, 0.32867349999999995, 0.36605010000000004, 0.36746369999999995, 0.34240879999999996, 0.35646029999999995, 0.38574260000000005, 0.3594914]
Mean total confusion across 10 runs: 0.35564123
Standard deviation of total confusion across 10 runs: 0.01976600992726039

intra-phase confusion [0.34341750000000004, 0.31511560000000005, 0.3674995, 0.3227232, 0.35441469999999997, 0.3632388, 0.3365226, 0.350475, 0.3792729, 0.34924429999999995]
Mean intra-phase confusion across 10 runs: 0.34819241
Standard deviation of intra-phase confusion across 10 runs: 0.019753841444159655

per-task confusion [0.08573580984689053, 0.07231327535044012, 0.08093096909107253, 0.0733482354552929, 0.09511231453418094, 0.07380446264268503, 0.069508196567894, 0.07106233735

In [11]:
# full_CE disabled, with the extra OOD output per task.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
    with_OOD=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:40<00:00, 32.11s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:03<00:00, 36.65s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:52<00:00, 34.55s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [03:04<00:00, 36.98s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:15<00:00, 39.18s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [03:13<00:00, 38.71s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [03:18<00:00, 39.68s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [02:36<00:00, 31.22s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [02:53<00:00, 34.70s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [02:53<00:00, 34.66s/it]


accuracy [0.7555, 0.7746, 0.7027, 0.8095, 0.7049, 0.7875, 0.7755, 0.7456, 0.707, 0.7533]
Mean accuracy across 10 runs: 0.75161
Standard deviation of accuracy across 10 runs: 0.03705369827096285

total confusion [0.3288173, 0.34396879999999996, 0.3717703, 0.34110949999999995, 0.3262279, 0.35404119999999994, 0.34030099999999996, 0.34495469999999995, 0.34428250000000005, 0.3580573]
Mean total confusion across 10 runs: 0.34535305
Standard deviation of total confusion across 10 runs: 0.01341933613412212

intra-phase confusion [0.3260809, 0.3413393, 0.36811249999999995, 0.33831259999999996, 0.32324470000000005, 0.3510626, 0.33763449999999995, 0.34234, 0.34028480000000005, 0.35493929999999996]
Mean intra-phase confusion across 10 runs: 0.34233512
Standard deviation of intra-phase confusion across 10 runs: 0.01321456844545273

per-task confusion [0.05520785792034462, 0.056060945091840876, 0.06932549534193291, 0.06135647294225206, 0.056822218565981866, 0.05927385438189723, 0.0592484491646351, 0

In [12]:
# full_CE disabled, KD loss term enabled, with the extra OOD output per task.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_OOD=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:35<00:00, 43.10s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [04:01<00:00, 48.32s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:23<00:00, 40.62s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [04:06<00:00, 49.22s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [03:48<00:00, 45.74s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [03:47<00:00, 45.46s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [04:03<00:00, 48.65s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [03:44<00:00, 44.80s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [03:41<00:00, 44.33s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [03:42<00:00, 44.53s/it]


accuracy [0.8616, 0.8825, 0.8152, 0.8575, 0.8198, 0.7945, 0.8421, 0.847, 0.8627, 0.8562]
Mean accuracy across 10 runs: 0.8439099999999999
Standard deviation of accuracy across 10 runs: 0.02656365311222586

total confusion [0.33459799999999995, 0.3288046, 0.33512810000000004, 0.34465789999999996, 0.34999020000000003, 0.332723, 0.35960289999999995, 0.33585790000000004, 0.33089250000000003, 0.3471951]
Mean total confusion across 10 runs: 0.33994502
Standard deviation of total confusion across 10 runs: 0.009937724562359989

intra-phase confusion [0.331445, 0.32531410000000005, 0.33142119999999997, 0.3417169, 0.34568489999999996, 0.3288552, 0.356282, 0.33280279999999995, 0.3267947, 0.34407960000000004]
Mean intra-phase confusion across 10 runs: 0.33643964
Standard deviation of intra-phase confusion across 10 runs: 0.010017188320881257

per-task confusion [0.062081280457460974, 0.058755388312693155, 0.06096391725267125, 0.056012167859016485, 0.06531364839805301, 0.06233225717074546, 0.068215

In [13]:
# full_CE disabled, KD loss term enabled (no OOD output, no dropout).
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [02:20<00:00, 28.01s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [02:35<00:00, 31.07s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [02:33<00:00, 30.63s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [02:29<00:00, 29.87s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [02:30<00:00, 30.19s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [02:20<00:00, 28.07s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [02:30<00:00, 30.11s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [02:22<00:00, 28.45s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [02:50<00:00, 34.15s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [02:20<00:00, 28.04s/it]


accuracy [0.7817, 0.7097, 0.7277, 0.7791, 0.7482, 0.7326, 0.726, 0.75, 0.6578, 0.6404]
Mean accuracy across 10 runs: 0.72532
Standard deviation of accuracy across 10 runs: 0.04628061269161322

total confusion [0.39739349999999996, 0.408949, 0.39550929999999995, 0.3805647, 0.37810889999999997, 0.38847319999999996, 0.40739309999999995, 0.39386449999999995, 0.41056630000000005, 0.39680249999999995]
Mean total confusion across 10 runs: 0.39576249999999996
Standard deviation of total confusion across 10 runs: 0.011212453495902584

intra-phase confusion [0.39505199999999996, 0.4069349, 0.39340699999999995, 0.3786178, 0.37583599999999995, 0.3863639, 0.405405, 0.391957, 0.40823909999999997, 0.3945554]
Mean intra-phase confusion across 10 runs: 0.39363681
Standard deviation of intra-phase confusion across 10 runs: 0.011203234797092216

per-task confusion [0.06045580769479264, 0.0606780366347411, 0.05672472204941688, 0.054323932602047154, 0.05816458776587892, 0.05673229051835473, 0.0573873624645

In [14]:
# full_CE disabled, KD loss term enabled, with dropout (no OOD output).
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_dropout=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:56<00:00, 47.28s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:49<00:00, 45.98s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:53<00:00, 46.66s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [04:00<00:00, 48.04s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [04:02<00:00, 48.59s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [03:48<00:00, 45.76s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [03:46<00:00, 45.33s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [04:04<00:00, 48.89s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [04:14<00:00, 50.91s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [03:36<00:00, 43.32s/it]


accuracy [0.7515, 0.7588, 0.7735, 0.7629, 0.7843, 0.712, 0.7861, 0.7856, 0.7738, 0.7883]
Mean accuracy across 10 runs: 0.76768
Standard deviation of accuracy across 10 runs: 0.02332598932045066

total confusion [0.3798897, 0.38332160000000004, 0.3825027, 0.390088, 0.3888752, 0.38298069999999995, 0.38966480000000003, 0.4048952, 0.38741309999999995, 0.39956899999999995]
Mean total confusion across 10 runs: 0.38892
Standard deviation of total confusion across 10 runs: 0.00790320150431316

intra-phase confusion [0.3781044, 0.3817285, 0.3809078, 0.388309, 0.3872504, 0.38133340000000004, 0.38798239999999995, 0.4027269, 0.38566, 0.3981123]
Mean intra-phase confusion across 10 runs: 0.38721151
Standard deviation of intra-phase confusion across 10 runs: 0.007824844968496168

per-task confusion [0.05644981332012735, 0.05477967314406325, 0.05374147101838771, 0.0586109084369467, 0.057987883239672675, 0.055799522425375024, 0.05901799695205616, 0.06528988685019613, 0.05869066885242378, 0.05613500059

In [15]:
# Every option enabled: full_CE disabled, KD loss term, OOD output, dropout and OGD.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_OOD=True,
    with_dropout=True,
    ogd=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [08:14<00:00, 98.88s/it] 


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [07:51<00:00, 94.35s/it] 


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [07:47<00:00, 93.54s/it] 


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [08:39<00:00, 103.99s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [08:20<00:00, 100.06s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [08:14<00:00, 98.99s/it] 


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [08:15<00:00, 99.05s/it] 


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [08:18<00:00, 99.68s/it] 


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [07:46<00:00, 93.27s/it] 


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [08:38<00:00, 103.72s/it]


accuracy [0.8614, 0.8782, 0.8883, 0.8802, 0.8813, 0.8668, 0.8707, 0.8941, 0.8851, 0.8959]
Mean accuracy across 10 runs: 0.8802
Standard deviation of accuracy across 10 runs: 0.011342055467252045

total confusion [0.3218822, 0.327094, 0.30800970000000005, 0.33423899999999995, 0.32600989999999996, 0.32439969999999996, 0.31565200000000004, 0.3159969, 0.32359039999999994, 0.3082323]
Mean total confusion across 10 runs: 0.32051061
Standard deviation of total confusion across 10 runs: 0.008429792467255613

intra-phase confusion [0.31949720000000004, 0.32471839999999996, 0.3057324, 0.33198720000000004, 0.32359689999999997, 0.3223967, 0.31331739999999997, 0.3135527, 0.3210545, 0.3059898]
Mean intra-phase confusion across 10 runs: 0.31818432
Standard deviation of intra-phase confusion across 10 runs: 0.008426880905873646

per-task confusion [0.050377915978045903, 0.05457299622576237, 0.04991771277570356, 0.05269847078014307, 0.05230042596717337, 0.047196952958522685, 0.04979997720523124, 0.0522

In [16]:
# full_CE disabled, KD loss term enabled, with the extra OOD output per task.
# NOTE(review): these arguments are identical to an earlier cell in this
# notebook (full_CE=False, kd_loss=1, with_OOD=True) -- confirm the repeat is intended.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_OOD=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [03:25<00:00, 41.16s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [03:40<00:00, 44.08s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [03:43<00:00, 44.66s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [03:42<00:00, 44.46s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [04:04<00:00, 48.85s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [04:09<00:00, 49.81s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [03:43<00:00, 44.78s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [04:01<00:00, 48.25s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [03:51<00:00, 46.37s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [03:47<00:00, 45.54s/it]


accuracy [0.8258, 0.8121, 0.8424, 0.713, 0.8106, 0.7814, 0.8615, 0.8555, 0.8617, 0.8547]
Mean accuracy across 10 runs: 0.82187
Standard deviation of accuracy across 10 runs: 0.04656613099190824

total confusion [0.3424703, 0.35254830000000004, 0.3197354, 0.33199120000000004, 0.3422233, 0.3393106, 0.35108779999999995, 0.33810569999999995, 0.3475425, 0.3531221]
Mean total confusion across 10 runs: 0.34181372
Standard deviation of total confusion across 10 runs: 0.010362786703906751

intra-phase confusion [0.3389839, 0.34956580000000004, 0.3161345, 0.3285097, 0.3388044, 0.3361543, 0.347278, 0.33438120000000005, 0.34424639999999995, 0.34957760000000004]
Mean intra-phase confusion across 10 runs: 0.33836358
Standard deviation of intra-phase confusion across 10 runs: 0.010423631106694276

per-task confusion [0.05989238798090919, 0.05911372236187444, 0.05963243054433602, 0.05723384226471466, 0.06783373042319449, 0.060629871745132013, 0.07000971916482764, 0.06214458304169166, 0.062361839787624

In [6]:
# Same settings as the kd_loss=1 / with_OOD=True experiment, with the
# with_dropout option additionally switched on.
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
    full_CE=False,
    kd_loss=1,
    with_OOD=True,
    with_dropout=True,
)

Starting run 1.


Experiment Progress: 100%|██████████| 5/5 [07:07<00:00, 85.52s/it] 


Starting run 2.


Experiment Progress: 100%|██████████| 5/5 [07:26<00:00, 89.28s/it] 


Starting run 3.


Experiment Progress: 100%|██████████| 5/5 [07:47<00:00, 93.42s/it] 


Starting run 4.


Experiment Progress: 100%|██████████| 5/5 [07:16<00:00, 87.37s/it] 


Starting run 5.


Experiment Progress: 100%|██████████| 5/5 [07:52<00:00, 94.50s/it] 


Starting run 6.


Experiment Progress: 100%|██████████| 5/5 [07:43<00:00, 92.61s/it] 


Starting run 7.


Experiment Progress: 100%|██████████| 5/5 [06:55<00:00, 83.04s/it] 


Starting run 8.


Experiment Progress: 100%|██████████| 5/5 [07:18<00:00, 87.78s/it] 


Starting run 9.


Experiment Progress: 100%|██████████| 5/5 [06:59<00:00, 83.95s/it] 


Starting run 10.


Experiment Progress: 100%|██████████| 5/5 [06:21<00:00, 76.37s/it] 


accuracy [0.8832, 0.8664, 0.8257, 0.8562, 0.8638, 0.8462, 0.8747, 0.8772, 0.8371, 0.8833]
Mean accuracy across 10 runs: 0.86138
Standard deviation of accuracy across 10 runs: 0.019819171414455144

total confusion [0.3128014, 0.3207443, 0.32837019999999995, 0.33641909999999997, 0.3331537, 0.33293269999999997, 0.3228172, 0.31472710000000004, 0.3328434, 0.3220687]
Mean total confusion across 10 runs: 0.32568778
Standard deviation of total confusion across 10 runs: 0.008258057280929388

intra-phase confusion [0.3101087, 0.3185837, 0.32618309999999995, 0.3340997, 0.33052970000000004, 0.33069479999999996, 0.3203538, 0.31198289999999995, 0.33045250000000004, 0.3195703]
Mean intra-phase confusion across 10 runs: 0.32325592
Standard deviation of intra-phase confusion across 10 runs: 0.008366318923291313

per-task confusion [0.05351556350007423, 0.05314371663644586, 0.05197445406789425, 0.05270911752140741, 0.05826406745431898, 0.05447583172700352, 0.05579256025899848, 0.05263171447505943, 0.055

In [18]:
# Baseline: train on the entire dataset in one go instead of incrementally.
#
# NOTE: the two assignments below modify attributes of the shared `globals`
# module and are not reverted afterwards, so they stay in effect for every
# cell executed later in this kernel session. Restore the previous values
# before re-running any of the experiment cells above.
globals.ITERATIONS = 1          # a single training phase ...
globals.CLASSES_PER_ITER = 10   # ... with 10 classes per iteration
run_experiments(
    n_runs=10,
    verbose=False,
    stopOnLoss=0.02,
)

Starting run 1.


Experiment Progress: 100%|██████████| 1/1 [04:20<00:00, 260.81s/it]


Starting run 2.


Experiment Progress: 100%|██████████| 1/1 [03:42<00:00, 222.60s/it]


Starting run 3.


Experiment Progress: 100%|██████████| 1/1 [03:41<00:00, 221.24s/it]


Starting run 4.


Experiment Progress: 100%|██████████| 1/1 [04:09<00:00, 249.30s/it]


Starting run 5.


Experiment Progress: 100%|██████████| 1/1 [03:40<00:00, 220.34s/it]


Starting run 6.


Experiment Progress: 100%|██████████| 1/1 [04:03<00:00, 243.67s/it]


Starting run 7.


Experiment Progress: 100%|██████████| 1/1 [04:03<00:00, 243.27s/it]


Starting run 8.


Experiment Progress: 100%|██████████| 1/1 [03:36<00:00, 216.45s/it]


Starting run 9.


Experiment Progress: 100%|██████████| 1/1 [03:32<00:00, 212.56s/it]


Starting run 10.


Experiment Progress: 100%|██████████| 1/1 [03:40<00:00, 220.24s/it]


accuracy [0.9894, 0.9883, 0.9882, 0.9914, 0.9879, 0.989, 0.9905, 0.9916, 0.989, 0.9903]
Mean accuracy across 10 runs: 0.98956
Standard deviation of accuracy across 10 runs: 0.0013259797216482034

total confusion [0.2295988, 0.2016749, 0.1983973, 0.21697999999999995, 0.21374439999999995, 0.20362369999999996, 0.2011659, 0.18750109999999998, 0.2111398, 0.20915170000000005]
Mean total confusion across 10 runs: 0.20729776
Standard deviation of total confusion across 10 runs: 0.011569503026702954

intra-phase confusion [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Mean intra-phase confusion across 10 runs: 0.0
Standard deviation of intra-phase confusion across 10 runs: 0.0

per-task confusion [0.2295988, 0.2016749, 0.1983973, 0.21697999999999995, 0.21374439999999995, 0.20362369999999996, 0.2011659, 0.18750109999999998, 0.2111398, 0.20915170000000005]
Mean per-task confusion across 10 runs: 0.20729776
Standard deviation of per-task confusion across 10 runs: 0.011569503026702954

embeddin

In [19]:
# Ask PyTorch's caching allocator to release the GPU memory it is holding
# but no longer using, now that the experiment runs above have finished.
torch.cuda.empty_cache()