In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
torch.cuda.empty_cache()
import sys
sys.path.append('./../src')
import globals
from model import MnistCNN, TinyImageNetCNN, Cifar10CNN
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from data_utils import initialize_data
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [2]:
# Notebook-level aliases of the shared configuration held by the project's
# `globals` module.  They are snapshots taken when this cell runs:
# run_experiment() calls initialize_data() and afterwards reads globals.*
# directly, so the loader/dataset aliases below may not track the dataset of
# the most recent experiment — TODO confirm they are still needed.
SEED, DEVICE = globals.SEED, globals.DEVICE
full_trainset, trainset, testset = (
    globals.full_trainset,
    globals.trainset,
    globals.testset,
)
trainloaders, valloaders, testloaders = (
    globals.trainloaders,
    globals.valloaders,
    globals.testloaders,
)

In [3]:
def evalModel(model, attr_device, i, only_pred = False):
    """Predict on the test sets of tasks ``0..i``, concatenated in task order.

    Args:
        model: project model exposing ``get_pred_and_embeddings(features)``.
            It is put in eval mode for inference and is always left in train
            mode afterwards (also if an exception is raised), which is what the
            training loop expects.
        attr_device: device the test features / labels are moved to.
        i: index of the last task whose test set is included.
        only_pred: if True, only predictions are computed; labels are not
            loaded and embeddings are discarded.

    Returns:
        ``(predicted, all_labels, all_embeddings)``: ``predicted`` holds the
        argmax class index per test sample; ``all_labels`` / ``all_embeddings``
        are the concatenated labels / embeddings, or empty lists when
        ``only_pred`` is True.
    """
    pred_chunks, label_chunks, embedding_chunks = [], [], []
    model.eval()
    try:
        with torch.no_grad():
            for task in range(i + 1):
                test_loader = globals.testloaders[task]
                if not only_pred:
                    # Labels are only needed by callers that score the predictions.
                    label_chunks.append(get_labels(test_loader).to(attr_device))
                features = get_features(test_loader).to(attr_device)
                pred, embeddings = model.get_pred_and_embeddings(features)
                pred_chunks.append(pred)
                if not only_pred:
                    embedding_chunks.append(embeddings)
    finally:
        model.train()

    accum_pred = torch.cat(pred_chunks)
    if globals.OOD_CLASS == 1:
        # With OOD enabled each task's logit block is CLASSES_PER_ITER real
        # classes followed by one extra (OOD) unit; drop every
        # (CLASSES_PER_ITER + 1)-th column so argmax indices line up with the
        # real class labels.
        block_size = globals.CLASSES_PER_ITER + 1
        keep = [col for col in range(accum_pred.size(1)) if (col + 1) % block_size != 0]
        accum_pred = accum_pred[:, keep]
    predicted = torch.argmax(accum_pred, dim=1)

    if only_pred:
        return predicted, [], []
    return predicted, torch.cat(label_chunks), torch.cat(embedding_chunks)

In [4]:
def run_experiment(
        verbose = False,
        full_CE = True,
        with_OOD = False,
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False,
        optimiser_type = 'sgd',
        dataset = 'mnist',
        joint_training = False,
        save_saliency=False,
        plotting=False
        ):
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)
    if dataset == 'tiny_imagenet':
        ogd_basis_size=None
        patience = 6
        ood_method = 'fmix'
        if joint_training:
            globals.CLASSES_PER_ITER = 200
            globals.ITERATIONS = 1
        else:
            globals.CLASSES_PER_ITER = 40
            globals.ITERATIONS = 5
    elif dataset == 'mnist':
        ogd_basis_size=200
        patience = 2
        ood_method = 'jigsaw'
        if joint_training:
            globals.CLASSES_PER_ITER = 10
            globals.ITERATIONS = 1
        else:
            globals.CLASSES_PER_ITER = 2
            globals.ITERATIONS = 5
    elif dataset == 'cifar10':
        ogd_basis_size=50
        patience = 4
        ood_method = 'fmix'
        if joint_training:
            globals.CLASSES_PER_ITER = 10
            globals.ITERATIONS = 1
        else:
            globals.CLASSES_PER_ITER = 2
            globals.ITERATIONS = 5
    if with_OOD:
        globals.toggle_OOD(ood_method)
    else:
        globals.disable_OOD()
    initialize_data(dataset)
    prevModel = None
    
    globals.WITH_DROPOUT = with_dropout

    #[Denis] added code:
    if dataset == 'tiny_imagenet':
        attr_device = 'cpu' # CPU has high memory in the cluster
    else:
        attr_device = globals.DEVICE
    Feature_Importance_Eval=Feature_Importance_Evaluations(globals.testloaders, attr_device)

    for i in tqdm(range(globals.ITERATIONS), desc="Experiment Progress"):
        if dataset == 'mnist':
            model = MnistCNN((i+1)*(globals.CLASSES_PER_ITER+globals.OOD_CLASS))
        elif dataset == 'tiny_imagenet':
            model = TinyImageNetCNN((i+1)*(globals.CLASSES_PER_ITER+globals.OOD_CLASS))
        elif dataset == 'cifar10':
            model = Cifar10CNN((i+1)*(globals.CLASSES_PER_ITER+globals.OOD_CLASS))
        if prevModel is not None:
            with torch.no_grad():
                model.copyPrev(prevModel)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        if prevModel:
            _print("CL TRAIN!!")
            model = train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                kd_loss=kd_loss,
                stopOnLoss=None,
                stopOnValAcc = stopOnValAcc,
                full_CE=full_CE,
                optimiser_type=optimiser_type,
                plotting=plotting,
                patience=patience,
                ogd_basis_size=ogd_basis_size
                )
        else:
            _print("TRAINING!")
            model = train_model(
                model, 
                train_loader,
                val_loader, 
                verbose, 
                epochs=epochs, 
                stopOnLoss=None,
                optimiser_type=optimiser_type,
                plotting=plotting,
                patience=patience,
                ogd_basis_size=ogd_basis_size
                )

        #[Denis] added code:
        Feature_Importance_Eval.Task_Feature_Attribution(model, i)
        
        if verbose or i == globals.ITERATIONS-1:
            if dataset == 'tiny_imagenet':
                model.to('cpu')
            _print("Starting evaluation")
            _print("ITERATION", i+1)
            predicted, all_labels, all_embeddings = evalModel(model, attr_device, i)
            if dataset == 'tiny_imagenet':
                model.to(DEVICE)
            predicted = predicted.to(attr_device)
            correct = (predicted == all_labels).sum().item()  # Count how many were correct
            accuracy = correct / all_labels.size(0)  # Accuracy as a percentage
            _print("Accuracy on tasks so far:", accuracy)
            embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
            _print("Average embedding drift based on centroids:", embedding_drift)
            total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 300, task=i+1)
            _print("Total confusion", total_confusion)
            _print("Intra-phase confusion", intra_phase_confusion)
            _print("Per task confusions", per_task_confusion)
            if verbose and dataset == 'mnist':
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(globals.CLASSES_PER_ITER*(i+1))))
        prevModel = copy.deepcopy(model)
        
    #[Denis] added code:
    [avg_att_diff,att_diffs,_,_,avg_att_spread,att_spreads]=Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", att_diffs)
    _print("Averaged SHAPC value (the smaller the better):", avg_att_diff)
    _print("Average attention spread values (ordered as tasks):", att_spreads)
    _print("Averaged attention spread value (the bigger the better):", avg_att_spread)
    if save_saliency:
        Feature_Importance_Eval.Save_Random_Picture_Salency() #prints the salcency maps for 1 example by class (first row: image, second row: salency map after training, third row: salency map after training task where class is included)
    if not joint_training:
        _print("Training output layer to check for bias...")
        model.output_layer.reset_parameters()
        globals.ITERATIONS = 1
        if dataset == 'mnist' or dataset == 'cifar10':
            globals.CLASSES_PER_ITER = 10
        else:
            globals.CLASSES_PER_ITER = 200
        model = train_model(
                    model, 
                    DataLoader(globals.trainset, batch_size=globals.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers = 0), 
                    DataLoader(globals.valset, batch_size=100, shuffle=False), 
                    verbose, 
                    epochs=epochs, 
                    stopOnLoss=None,
                    optimiser_type='sgd',
                    plotting=plotting,
                    patience=patience,
                    ogd_basis_size=ogd_basis_size,
                    only_output_layer=True
                    )
        predicted, _, _ = evalModel(model, attr_device, i, True)
        predicted = predicted.to(attr_device)
        correct = (predicted == all_labels).sum().item()  # Count how many were correct
        unbiased_output_accuracy = correct / all_labels.size(0)  # Accuracy as a percentage
        output_bias = unbiased_output_accuracy-accuracy
    else:
        output_bias = 0
    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_att_diff, avg_att_spread, output_bias

In [5]:
def run_experiments(n_runs=globals.EXPERIMENT_N_RUNS, *args, **kwargs):
    """Repeat ``run_experiment`` ``n_runs`` times and print mean/std of each metric.

    Args:
        n_runs: number of independent runs (must be >= 1).
        *args, **kwargs: forwarded unchanged to ``run_experiment`` on every run.

    Raises:
        ValueError: if ``n_runs`` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    # Field order of the tuple returned by run_experiment.
    result_fields = (
        "accuracy", "total_confusion", "intra_phase_confusion", "per_task_confusion",
        "embedding_drift", "avg_att_diff", "avg_att_spread", "output_bias",
    )
    # (result field, label used in the printed summary), in reporting order.
    report_order = (
        ("accuracy", "accuracy"),
        ("total_confusion", "total confusion"),
        ("intra_phase_confusion", "intra-phase confusion"),
        ("per_task_confusion", "per-task confusion"),
        ("output_bias", "output bias"),
        ("embedding_drift", "embedding drift"),
        ("avg_att_diff", "attention drift"),
        ("avg_att_spread", "attention spread"),
    )

    def report_stats(data, name):
        mean = statistics.mean(data)
        # statistics.stdev raises for fewer than two samples; report NaN instead.
        std = statistics.stdev(data) if len(data) > 1 else float("nan")
        print(f"Mean {name} across {n_runs} runs: {mean}")
        print(f"Standard deviation of {name} across {n_runs} runs: {std}\n")

    runs = []
    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        result = dict(zip(result_fields, run_experiment(*args, **kwargs)))
        runs.append(result)
        print(f"Run {r+1} finished with accuracy {result['accuracy']}")

    for field, label in report_order:
        # NOTE(review): statistics.mean needs numeric values; confirm that
        # per_task_confusion is a scalar and not a per-task list.
        report_stats([run[field] for run in runs], label)

## Experiments on MNIST (5 runs per configuration)

In [None]:
# Ablation grid: each cell is one configuration and can be run on its own.
# "baseline" = run_experiment defaults (full_CE=True, with_OOD=False,
# with_dropout=False, kd_loss=0, optimiser_type='sgd'); "special CE" = full_CE=False.
run_experiments(dataset='mnist', n_runs=5) # baseline

In [None]:
run_experiments(dataset='mnist', n_runs=5, full_CE=False) # baseline with special CE

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_OOD=True) # baseline with OOD

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_dropout=True) # baseline with dropout

In [None]:
run_experiments(dataset='mnist', n_runs=5, kd_loss=1) # baseline with KD loss

In [None]:
run_experiments(dataset='mnist', n_runs=5, optimiser_type='ogd') # baseline with OGD

In [None]:
run_experiments(dataset='mnist', n_runs=5, optimiser_type='adam') # baseline with Adam

In [None]:
# "everything" = dropout + KD loss + special CE + OOD + OGD; the next cells
# each remove or swap exactly one of these components.
run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='ogd') # everything

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=True, with_OOD=True, optimiser_type='ogd') # everything without special CE

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=False, optimiser_type='ogd') # everything without OOD

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_dropout=False, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='ogd') # everything without dropouts

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=0, full_CE=False, with_OOD=True, optimiser_type='ogd') # everything without kd loss

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='sgd') # everything with sgd

In [None]:
run_experiments(dataset='mnist', n_runs=5, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='adam') # everything with Adam

In [None]:
# Joint training: a single task with all 10 classes; output bias is reported as 0.
run_experiments(dataset='mnist', n_runs=5, joint_training=True) # joint training

## Experiments on CIFAR-10 (3 runs per configuration)

In [None]:
# Same ablation grid as for MNIST, with 3 runs per configuration.
# "baseline" = run_experiment defaults; "special CE" = full_CE=False.
run_experiments(dataset='cifar10', n_runs=3) # baseline

In [None]:
run_experiments(dataset='cifar10', n_runs=3, full_CE=False) # baseline with special CE

In [None]:
run_experiments(dataset='cifar10', n_runs=3, with_OOD=True) # baseline with OOD

In [None]:
run_experiments(dataset='cifar10', n_runs=3, with_dropout=True) # baseline with dropout

In [None]:
run_experiments(dataset='cifar10', n_runs=3, kd_loss=1) # baseline with KD loss

In [None]:
run_experiments(dataset='cifar10', n_runs=3, optimiser_type='ogd') # baseline with OGD

In [None]:
run_experiments(dataset='cifar10', n_runs=3, optimiser_type='adam') # baseline with Adam

In [None]:
# "everything" = dropout + KD loss + special CE + OOD + OGD; the next cells
# each remove or swap exactly one of these components.
run_experiments(dataset='cifar10', n_runs=3, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='ogd') # everything

In [None]:
run_experiments(dataset='cifar10', n_runs=3, with_dropout=True, kd_loss=1, full_CE=True, with_OOD=True, optimiser_type='ogd') # everything without special CE

In [6]:
# verbose=True also evaluates and prints metrics after every task, not only the last.
run_experiments(dataset='cifar10', verbose=True, n_runs=3, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=False, optimiser_type='ogd') # everything without OOD

In [None]:
run_experiments(dataset='cifar10', n_runs=3, with_dropout=False, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='ogd') # everything without dropouts

In [None]:
run_experiments(dataset='cifar10', n_runs=3, with_dropout=True, kd_loss=0, full_CE=False, with_OOD=True, optimiser_type='ogd') # everything without kd loss

In [None]:
run_experiments(dataset='cifar10', verbose=True, n_runs=3, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='sgd') # everything with sgd

In [None]:
run_experiments(dataset='cifar10', n_runs=3, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='adam') # everything with Adam

In [None]:
# Joint training: a single task with all 10 classes; output bias is reported as 0.
run_experiments(dataset='cifar10', n_runs=3, joint_training=True) # joint training

## Experiments on Tiny ImageNet (3 runs per configuration)

In [None]:
# Same ablation grid with 3 runs per configuration, except that the "everything"
# variants use SGD instead of OGD. run_experiment sets ogd_basis_size=None for
# tiny_imagenet, which is presumably why OGD is left out — TODO confirm.
run_experiments(dataset='tiny_imagenet', n_runs=3) # baseline

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, full_CE=False) # baseline with special CE

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, with_OOD=True) # baseline with OOD

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, with_dropout=True) # baseline with dropout

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, kd_loss=1) # baseline with KD loss

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, optimiser_type='ogd') # baseline with OGD

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, optimiser_type='adam') # baseline with Adam

In [None]:
# "everything (no ogd)" = dropout + KD loss + special CE + OOD with SGD; the next
# cells each remove or swap exactly one of these components.
run_experiments(dataset='tiny_imagenet', verbose=True, n_runs=3, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='sgd') # everything (no ogd)

In [None]:
run_experiments(dataset='tiny_imagenet', verbose=True, n_runs=3, with_dropout=True, kd_loss=1, full_CE=True, with_OOD=True, optimiser_type='sgd') # everything (no ogd) without special CE

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=False, optimiser_type='sgd') # everything (no ogd) without OOD

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, with_dropout=False, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='sgd') # everything (no ogd) without dropouts

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, with_dropout=True, kd_loss=0, full_CE=False, with_OOD=True, optimiser_type='sgd') # everything (no ogd) without kd loss

In [None]:
run_experiments(dataset='tiny_imagenet', n_runs=3, with_dropout=True, kd_loss=1, full_CE=False, with_OOD=True, optimiser_type='adam') # everything with Adam

In [None]:
# Joint training: a single task with all 200 classes; output bias is reported as 0.
run_experiments(dataset='tiny_imagenet', n_runs=3, joint_training=True) # joint training