In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
torch.cuda.empty_cache()
import sys
sys.path.append('./../src')
import globals
from model import MnistCNN, TinyImageNetCNN, Cifar10CNN
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from data_utils import initialize_data
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [2]:
# Module-level aliases for values held by the project-local `globals` module.
# NOTE(review): these are captured once, when this cell runs. run_experiment()
# calls initialize_data() and then reads `globals.<name>` directly, so if
# initialize_data() rebinds these attributes the aliases below may be stale —
# prefer `globals.<name>` in new code. DEVICE is used by run_experiment().
SEED, DEVICE = globals.SEED, globals.DEVICE
full_trainset, trainset, testset = globals.full_trainset, globals.trainset, globals.testset
trainloaders, valloaders, testloaders = globals.trainloaders, globals.valloaders, globals.testloaders

In [3]:
def run_experiment(
        verbose = False,
        stopOnLoss = None,
        full_CE = True,
        with_OOD = False,
        ood_method = 'jigsaw',
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False,
        optimiser_type = 'sgd',
        dataset = 'mnist',
        joint_training = False,
        save_saliency=False,
        plotting=False
        ):
    """Run one incremental-training experiment on ``dataset`` and evaluate it.

    Configures the shared ``globals`` module (classes per iteration, number of
    iterations, OOD mode, dropout), loads the data via ``initialize_data``, then
    for each iteration ``i`` builds a model with
    ``(i+1) * (CLASSES_PER_ITER + OOD_CLASS)`` outputs, copies the previous
    model's weights into it (``copyPrev``) and trains it: ``train_model`` for the
    first iteration, ``train_model_CL`` (which also receives the previous model)
    afterwards. Feature attributions are recorded after every iteration. The
    model is evaluated on all test sets seen so far on the last iteration, or on
    every iteration when ``verbose`` is true.

    Args:
        verbose: print progress/metrics; for MNIST also plot a confusion matrix.
        stopOnLoss: loss threshold forwarded to the trainers. When ``None``
            (the default) a per-dataset value is used: mnist 0.03,
            cifar10 0.3, tiny_imagenet 1.3. An explicit value is honoured.
        full_CE, kd_loss, stopOnValAcc, optimiser_type, plotting: forwarded to
            the training helpers (see ``training.py``).
        with_OOD: enable the OOD output unit using ``ood_method``.
        ood_method: name passed to ``globals.toggle_OOD`` (e.g. 'jigsaw', 'fmix').
        epochs: epoch budget forwarded to the training helpers.
        with_dropout: stored in ``globals.WITH_DROPOUT``.
        dataset: one of 'mnist', 'cifar10', 'tiny_imagenet'.
        joint_training: learn all classes in a single iteration instead of
            splitting them into 5 iterations.
        save_saliency: save example saliency maps at the end.

    Returns:
        ``(accuracy, total_confusion, intra_phase_confusion, per_task_confusion,
        embedding_drift, avg_att_diff, avg_att_spread)``. ``accuracy`` is the
        fraction (0..1) of test samples over all iterations that the final model
        classifies correctly when taking the argmax over the non-OOD outputs.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    # Per-dataset settings:
    #   (default stopOnLoss, classes per iteration, number of iterations, total classes)
    # With joint_training all `total classes` are learned in one iteration.
    dataset_settings = {
        'mnist': (0.03, 2, 5, 10),
        'cifar10': (0.3, 2, 5, 10),
        'tiny_imagenet': (1.3, 40, 5, 200),
    }
    model_constructors = {
        'mnist': MnistCNN,
        'cifar10': Cifar10CNN,
        'tiny_imagenet': TinyImageNetCNN,
    }
    if dataset not in dataset_settings:
        # Fail fast: an unknown name used to fall through and crash later with a
        # NameError on `model` after the globals had already been mutated.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(dataset_settings)}"
        )
    default_stop_loss, split_classes, split_iterations, total_classes = dataset_settings[dataset]
    # Previously this default overwrote any caller-supplied value unconditionally.
    if stopOnLoss is None:
        stopOnLoss = default_stop_loss
    if joint_training:
        globals.CLASSES_PER_ITER = total_classes
        globals.ITERATIONS = 1
    else:
        globals.CLASSES_PER_ITER = split_classes
        globals.ITERATIONS = split_iterations

    if with_OOD:
        globals.toggle_OOD(ood_method)
    else:
        globals.disable_OOD()
    initialize_data(dataset)
    prevModel = None

    globals.WITH_DROPOUT = with_dropout

    #[Denis] added code:
    # Project helper built on the test loaders; it is given the model after every
    # iteration (Task_Feature_Attribution) and queried once at the end.
    Feature_Importance_Eval=Feature_Importance_Evaluations(globals.testloaders, DEVICE)

    for i in tqdm(range(globals.ITERATIONS), desc="Experiment Progress"):
        # Output units contributed by each iteration: its classes plus the
        # optional OOD unit (OOD_CLASS is presumably 0 or 1 — see L166 usage).
        outputs_per_task = globals.CLASSES_PER_ITER + globals.OOD_CLASS
        model = model_constructors[dataset]((i+1)*outputs_per_task)
        if prevModel is not None:
            with torch.no_grad():
                model.copyPrev(prevModel)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        if prevModel is not None:
            _print("CL TRAIN!!")
            train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                kd_loss=kd_loss,
                stopOnLoss=stopOnLoss,
                stopOnValAcc = stopOnValAcc,
                full_CE=full_CE,
                optimiser_type=optimiser_type,
                plotting=plotting
                )
        else:
            _print("TRAINING!")
            train_model(
                model,
                train_loader,
                val_loader,
                verbose,
                epochs=epochs,
                stopOnLoss=stopOnLoss,
                optimiser_type=optimiser_type,
                plotting=plotting
                )

        #[Denis] added code:
        Feature_Importance_Eval.Task_Feature_Attribution(model, i)

        if verbose or i == globals.ITERATIONS-1:
            _print("Starting evaluation")
            _print("ITERATION", i+1)
            _print("ACCURACIES PER TASK:")
            accumPred = []
            all_labels = []
            all_embeddings = []
            with torch.no_grad():
                for j in range(i+1):
                    test_loader = globals.testloaders[j]
                    test_labels = get_labels(test_loader).to(DEVICE)
                    all_labels.append(test_labels)
                    model.eval()
                    pred, embeddings = model.get_pred_and_embeddings(get_features(test_loader).to(DEVICE))
                    model.train()
                    accumPred.append(pred)
                    all_embeddings.append(embeddings)
                    # Task-aware accuracy: argmax only within task j's output block,
                    # then shift to global label ids (assumes task j's labels are
                    # j*CLASSES_PER_ITER .. (j+1)*CLASSES_PER_ITER-1). A prediction of
                    # the OOD unit lands outside that range and counts as wrong.
                    task_pred = pred[:, j*outputs_per_task:(j+1)*outputs_per_task]
                    _, predicted = torch.max(task_pred, 1)
                    predicted += j*globals.CLASSES_PER_ITER
                    correct = (predicted == test_labels).sum().item()
                    accuracy = correct / test_labels.size(0)  # fraction in [0, 1]
                    _print(str(accuracy), end=' ')
            accumPred = torch.cat(accumPred)
            all_labels = torch.cat(all_labels)
            all_embeddings = torch.cat(all_embeddings)
            # Task-agnostic accuracy: when OOD is enabled, drop the last unit of
            # every task block, then argmax over the remaining class outputs.
            # torch.argmax returns the first maximal index, matching the previous
            # per-sample strict '>' scan, but without a Python loop per element.
            if globals.OOD_CLASS == 1:
                class_columns = [c for c in range(accumPred.size(1))
                                 if (c + 1) % (globals.CLASSES_PER_ITER+1) != 0]
                class_logits = accumPred[:, class_columns]
            else:
                class_logits = accumPred
            predicted = torch.softmax(class_logits, dim=-1).argmax(dim=-1).to(DEVICE)
            correct = (predicted == all_labels).sum().item()
            accuracy = correct / all_labels.size(0)  # fraction in [0, 1]
            _print("Accuracy on tasks so far:", accuracy)
            embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
            _print("Average embedding drift based on centroids:", embedding_drift)
            total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 300, task=i+1)
            _print("Total confusion", total_confusion)
            _print("Intra-phase confusion", intra_phase_confusion)
            _print("Per task confusions", per_task_confusion)
            if verbose and dataset == 'mnist':
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(globals.CLASSES_PER_ITER*(i+1))))
        prevModel = copy.deepcopy(model)

    #[Denis] added code:
    [avg_att_diff,att_diffs,_,_,avg_att_spread,att_spreads]=Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", att_diffs)
    _print("Averaged SHAPC value (the smaller the better):", avg_att_diff)
    _print("Average attention spread values (ordered as tasks):", att_spreads)
    _print("Averaged attention spread value (the bigger the better):", avg_att_spread)
    if save_saliency:
        # Saves saliency maps for one example per class (per the helper's
        # original note: image row, map after final training, map after the
        # iteration that introduced the class).
        Feature_Importance_Eval.Save_Random_Picture_Salency()

    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread

In [4]:
def run_experiments(n_runs=globals.EXPERIMENT_N_RUNS, *args, **kwargs):
    """Repeat ``run_experiment`` ``n_runs`` times and print mean/std per metric.

    Args:
        n_runs: number of independent runs. The default is read from
            ``globals.EXPERIMENT_N_RUNS`` once, when this cell is executed.
        *args, **kwargs: forwarded unchanged to ``run_experiment`` on every run.

    The spread reported is the sample standard deviation (``statistics.stdev``),
    which is undefined for a single run; ``nan`` is printed in that case instead
    of raising.

    Raises:
        ValueError: if ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    def report_stats(data, name):
        mean = statistics.mean(data)
        std = statistics.stdev(data) if len(data) > 1 else float('nan')
        print(f"Mean {name} across {n_runs} runs: {mean}")
        print(f"Standard deviation of {name} across {n_runs} runs: {std}\n")

    # Labels in the same order as run_experiment's return tuple.
    metric_names = (
        "accuracy",
        "total confusion",
        "intra-phase confusion",
        "per-task confusion",
        "embedding drift",
        "attention drift",    # avg_att_diff
        "attention spread",   # avg_att_spread
    )
    per_run_results = []
    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        results = run_experiment(*args, **kwargs)
        per_run_results.append(results)
        # 1-based, matching the "Starting run" message above.
        print(f"Run {r+1} finished with accuracy {results[0]}")

    # zip(*...) transposes per-run tuples into per-metric columns.
    for name, values in zip(metric_names, zip(*per_run_results)):
        report_stats(list(values), name)

In [None]:
# MNIST with the OOD unit enabled ('jigsaw'), dropout, kd_loss=1, full_CE=False.
run_experiments(
    dataset='mnist',
    n_runs=2,
    with_OOD=True,
    ood_method='jigsaw',
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    optimiser_type='sgd',
)

Starting run 1.


Experiment Progress:  80%|████████  | 4/5 [02:51<00:35, 35.26s/it]

In [None]:
# CIFAR-10 with the OOD unit enabled ('fmix'); verbose=True prints per-epoch logs.
run_experiments(
    dataset='cifar10',
    n_runs=2,
    verbose=True,
    with_OOD=True,
    ood_method='fmix',
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    optimiser_type='sgd',
)

Starting run 1.
Files already downloaded and verified
Files already downloaded and verified


Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!
Epoch 0, CE Loss: 1.0987, CE Loss (no OOD): 1.0902
Validation loss 0.6930926382541657 validation accuracy 0.5 

Epoch 1, CE Loss: 1.0983, CE Loss (no OOD): 1.0950
Validation loss 0.6929800391197205 validation accuracy 0.5 

Epoch 2, CE Loss: 1.0982, CE Loss (no OOD): 1.0981
Validation loss 0.6927933931350708 validation accuracy 0.512 

Epoch 3, CE Loss: 1.0978, CE Loss (no OOD): 1.0960
Validation loss 0.6925217807292938 validation accuracy 0.545 

Epoch 4, CE Loss: 1.0974, CE Loss (no OOD): 1.0983
Validation loss 0.6920787930488587 validation accuracy 0.552 

Epoch 5, CE Loss: 1.0967, CE Loss (no OOD): 1.0982
Validation loss 0.6913796424865722 validation accuracy 0.541 

Epoch 6, CE Loss: 1.0958, CE Loss (no OOD): 1.0934
Validation loss 0.6901637554168701 validation accuracy 0.555 

Epoch 7, CE Loss: 1.0941, CE Loss (no OOD): 1.0948
Validation loss 0.6879110872745514 validation accuracy 0.584 

Epoch 8, CE Loss: 1.0910, CE Loss (no OOD): 1.0919
Validation loss 0.6833425104618

Experiment Progress:  20%|██        | 1/5 [03:04<12:18, 184.64s/it]

Total confusion 0.12621166666666672
Intra-phase confusion 0.0
Per task confusions 0.12621166666666672
CL TRAIN!!
Epoch 0  CELoss: 0.6142, KLLoss: 0.0211, CELoss (no OOD): 0.7998
Fraction of nonzero parameters 0.999945231325961
Validation losses: 0.6642903208732605 0.025236868858337404
Validation accuracy (for last task) 0.0
Total validation accuracy 0.431


Epoch 1  CELoss: 0.4995, KLLoss: 0.0252, CELoss (no OOD): 0.6703
Fraction of nonzero parameters 0.9999426506031006
Validation losses: 0.620998814702034 0.021495026350021363
Validation accuracy (for last task) 0.001
Total validation accuracy 0.44


Epoch 2  CELoss: 0.4689, KLLoss: 0.0214, CELoss (no OOD): 0.6406
Fraction of nonzero parameters 0.9999456614464378
Validation losses: 0.643565633893013 0.04081907868385315
Validation accuracy (for last task) 0.012
Total validation accuracy 0.426


Epoch 3  CELoss: 0.4623, KLLoss: 0.0279, CELoss (no OOD): 0.6274
Fraction of nonzero parameters 0.9999435108440541
Validation losses: 0.61790878

In [None]:
# CIFAR-10 without the OOD unit; dropout, kd_loss=1, full_CE=False.
run_experiments(
    dataset='cifar10',
    n_runs=2,
    with_OOD=False,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    optimiser_type='sgd',
)

Starting run 1.
Files already downloaded and verified
Files already downloaded and verified


Experiment Progress:  20%|██        | 1/5 [00:10<00:43, 10.89s/it]

In [None]:
# CIFAR-10 with the OOD unit enabled ('jigsaw'); dropout, kd_loss=1, full_CE=False.
run_experiments(
    dataset='cifar10',
    n_runs=2,
    with_OOD=True,
    ood_method='jigsaw',
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    optimiser_type='sgd',
)

Starting run 1.
Files already downloaded and verified
Files already downloaded and verified


Experiment Progress:  20%|██        | 1/5 [01:39<06:39, 99.78s/it]

In [None]:
# CIFAR-10 baseline using run_experiment's defaults
# (no OOD, no dropout, kd_loss=0, full_CE=True, SGD).
run_experiments(
    dataset='cifar10',
    n_runs=2,
)

Starting run 1.
Files already downloaded and verified
Files already downloaded and verified


Experiment Progress: 100%|██████████| 5/5 [01:21<00:00, 16.21s/it]


Run 0 finished with accuracy 0.1294
Starting run 2.
Files already downloaded and verified
Files already downloaded and verified


Experiment Progress: 100%|██████████| 5/5 [01:07<00:00, 13.59s/it]


Run 1 finished with accuracy 0.1
Mean accuracy across 2 runs: 0.1147
Standard deviation of accuracy across 2 runs: 0.020788939366884484

Mean total confusion across 2 runs: 0.8719761666666667
Standard deviation of total confusion across 2 runs: 0.013871313726536499

Mean intra-phase confusion across 2 runs: 0.7688921666666667
Standard deviation of intra-phase confusion across 2 runs: 0.014828264903742283

Mean per-task confusion across 2 runs: 0.466561
Standard deviation of per-task confusion across 2 runs: 0.025700974473527104

Mean embedding drift across 2 runs: 4.2100794315338135
Standard deviation of embedding drift across 2 runs: 0.22579752621226368

Mean attention drift across 2 runs: 4.828478976477869e-07
Standard deviation of attention drift across 2 runs: 7.833737674377161e-08

Mean attention spread across 2 runs: 132.76986112772622
Standard deviation of attention spread across 2 runs: 15.01144139842277



: 

In [7]:
# CIFAR-10 with the OOD unit ('jigsaw'), other options at their defaults; verbose logs.
run_experiments(
    dataset='cifar10',
    n_runs=2,
    verbose=True,
    with_OOD=True,
    ood_method='jigsaw',
)

Starting run 1.
Files already downloaded and verified
Files already downloaded and verified


Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

TRAINING!
Epoch 0, CE Loss: 1.0926, CE Loss (no OOD): 1.0836
Validation loss 0.6420930862426758 validation accuracy 0.686 

Epoch 1, CE Loss: 0.7330, CE Loss (no OOD): 0.7753
Validation loss 0.36105107963085176 validation accuracy 0.848 

Epoch 2, CE Loss: 0.2887, CE Loss (no OOD): 0.3815
Validation loss 0.25312662944197656 validation accuracy 0.898 

Epoch 3, CE Loss: 0.1963, CE Loss (no OOD): 0.2682
Validation loss 0.23949620202183725 validation accuracy 0.901 

Starting evaluation
ITERATION 1
ACCURACIES PER TASK:
0.91 Accuracy on tasks so far: 0.91
Average embedding drift based on centroids: 3.475126868579537e-05


Experiment Progress:  20%|██        | 1/5 [00:59<03:56, 59.09s/it]

Total confusion 0.13116499999999998
Intra-phase confusion 0.0
Per task confusions 0.13116499999999998
CL TRAIN!!
Epoch 0  CELoss: 0.5461, KLLoss: 0.0000, CELoss (no OOD): 7.6303
Fraction of nonzero parameters 0.9999437975910386
Validation losses: 0.5388980209827423 0.3359523892402649
Validation accuracy (for last task) 0.737
Total validation accuracy 0.3685


Epoch 1  CELoss: 0.3567, KLLoss: 0.0000, CELoss (no OOD): 9.5378
Fraction of nonzero parameters 0.9999433674705618
Validation losses: 0.5199232846498489 0.33753766119480133
Validation accuracy (for last task) 0.745
Total validation accuracy 0.3725


Epoch 2  CELoss: 0.3198, KLLoss: 0.0000, CELoss (no OOD): 10.3282
Fraction of nonzero parameters 0.9999432240970696
Validation losses: 0.4569292485713959 0.45044625997543336
Validation accuracy (for last task) 0.799
Total validation accuracy 0.3995


Epoch 3  CELoss: 0.2938, KLLoss: 0.0000, CELoss (no OOD): 10.9145
Fraction of nonzero parameters 0.9999478120488215
Validation losses: 0.

: 

In [5]:
# CIFAR-10 with the OOD unit ('fmix'), other options at their defaults.
run_experiments(
    dataset='cifar10',
    n_runs=2,
    with_OOD=True,
    ood_method='fmix',
)

Starting run 1.
Files already downloaded and verified
Files already downloaded and verified


Experiment Progress:   0%|          | 0/5 [00:00<?, ?it/s]

: 

In [None]:
# MNIST without OOD, epochs=1 — presumably a quick sanity/smoke run; confirm intent.
run_experiments(
    dataset='mnist',
    n_runs=2,
    epochs=1,
    with_OOD=False,
    with_dropout=True,
    kd_loss=1,
    full_CE=False,
    optimiser_type='sgd',
)