In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
import statistics
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

import sys
sys.path.append('./../src')
import globals
from model import Net
from training import train_model, train_model_CL
from visualizations import plot_embeddings, plot_confusion_matrix
from feature_attribution import Feature_Importance_Evaluations
from pytorch_utils import get_features, get_labels
from embedding_measurements import measure_embedding_confusion_knn, measure_embedding_drift

In [None]:
# Module-level aliases for values defined in the shared `globals` module.
# Each name is bound once, at the moment this cell is executed.
ITERATIONS = globals.ITERATIONS
CLASSES_PER_ITER = globals.CLASSES_PER_ITER
SEED = globals.SEED
DEVICE = globals.DEVICE
# NOTE: initialize_data() rebinds the attributes below on `globals`
# (e.g. `globals.trainloaders = []`), so these aliases keep referring to
# whatever objects existed when this cell ran. After calling
# initialize_data(), read the data through `globals.<name>` instead.
full_trainset = globals.full_trainset
trainset = globals.trainset
testset = globals.testset
trainloaders = globals.trainloaders
valloaders = globals.valloaders
testloaders = globals.testloaders

In [None]:
# Data preparation for use with the convolutional neural network.

# The only active step converts images from Python Imaging Library (PIL)
# format to PyTorch tensors, which scales pixel values to [0,1].

# An optional second step (currently commented out below) would normalize
# the data by specifying a mean and standard deviation for MNIST's single
# grayscale channel. This would convert the data from [0,1] to [-1,1].

# Normalization of data should help speed up convergence and
# reduce the chance of vanishing gradients with certain
# activation functions.
def initialize_data():
    """Load MNIST and build one train/validation/test DataLoader per task.

    The MNIST training set is split (stratified by label, 85% / 15%) into a
    training and a validation part. Each of the ITERATIONS tasks covers
    CLASSES_PER_ITER consecutive class labels; for every task the samples
    with those labels are selected from the training, validation and test
    sets and wrapped in a DataLoader.

    Side effects (all on the shared `globals` module):
        full_trainset, trainset, testset  -- the datasets
        trainloaders, valloaders, testloaders -- lists with one loader per task

    The training loaders use `globals.BATCH_SIZE` as it is at call time;
    validation and test loaders use a fixed batch size of 500.
    """
    transform = transforms.Compose([
        transforms.ToTensor()
        #transforms.Normalize((0.5,), (0.5,))  # Normalizes to mean 0.5 and std 0.5 for the single channel
    ])

    def positions_with_labels(labels, classes):
        # Positions (0..len(labels)-1) whose label belongs to `classes`.
        # Working on the label array avoids loading and transforming every
        # image just to read its label.
        return np.flatnonzero(np.isin(labels, list(classes))).tolist()

    globals.full_trainset = torchvision.datasets.MNIST('./../data/', train=True, download=True,
                                transform=transform)
    targets = np.array(globals.full_trainset.targets)

    # Perform stratified split.
    # NOTE(review): no random_state is passed, so every call produces a
    # different split; pass random_state=SEED if reproducible splits are wanted.
    train_indices, val_indices = train_test_split(
        np.arange(len(targets)),
        test_size=0.15,
        stratify=targets
    )

    # Create subsets
    valset = Subset(globals.full_trainset, val_indices)
    globals.trainset = Subset(globals.full_trainset, train_indices)

    globals.testset = torchvision.datasets.MNIST('./../data/', train=False, download=True,
                                transform=transform)

    # Label arrays aligned position-by-position with each dataset above:
    # item k of a Subset is item indices[k] of the underlying dataset.
    train_labels = targets[train_indices]
    val_labels = targets[val_indices]
    test_labels = np.array(globals.testset.targets)

    # Consecutive groups of CLASSES_PER_ITER class labels, one group per task
    class_groups = [tuple(range(i*CLASSES_PER_ITER,(i+1)*CLASSES_PER_ITER)) for i in range(ITERATIONS)]

    # One loader per task, in task order
    globals.trainloaders = []
    globals.testloaders = []
    globals.valloaders = []
    for task_classes in class_groups:
        train_subset = Subset(globals.trainset, positions_with_labels(train_labels, task_classes))
        globals.trainloaders.append(DataLoader(train_subset, batch_size=globals.BATCH_SIZE, shuffle=True, pin_memory=True, num_workers = 0))

        val_subset = Subset(valset, positions_with_labels(val_labels, task_classes))
        globals.valloaders.append(DataLoader(val_subset, batch_size=500, shuffle=False))

        test_subset = Subset(globals.testset, positions_with_labels(test_labels, task_classes))
        globals.testloaders.append(DataLoader(test_subset, batch_size=500, shuffle=False))


In [None]:
def run_experiment(
        verbose = False,
        stopOnLoss = 0.03,
        full_CE = True,
        with_OOD = False,
        kd_loss = 0,
        stopOnValAcc = None,
        epochs = 1000000,
        with_dropout = False
        ):
    """Train one model per task in sequence and evaluate the result.

    For each of the ITERATIONS tasks a new ``Net`` is built, initialised from
    the previous task's model through ``copyPrev`` (when there is one) and
    trained with ``train_model`` (first task) or ``train_model_CL`` (later
    tasks). Evaluation on the test loaders of all tasks seen so far runs
    after every task when ``verbose`` is true, and always after the last
    task; the metrics of that last evaluation are returned.

    Side effects: sets ``globals.OOD_CLASS``, ``globals.BATCH_SIZE`` and
    ``globals.WITH_DROPOUT`` and rebuilds the data loaders through
    ``initialize_data()``.

    Args:
        verbose: Print progress and metrics and plot a confusion matrix at
            every evaluation; also forwarded to the training functions.
        stopOnLoss: Forwarded to the training functions (presumably a loss
            threshold for stopping -- confirm in training.py).
        full_CE: Forwarded to ``train_model_CL``.
        with_OOD: When true ``globals.OOD_CLASS`` is 1, giving every task one
            extra model output that is dropped again before the overall
            prediction is taken; otherwise ``globals.OOD_CLASS`` is 0.
        kd_loss: Forwarded to ``train_model_CL`` as ``kd_loss``.
        stopOnValAcc: Forwarded to ``train_model_CL``.
        epochs: Forwarded to the training functions.
        with_dropout: Stored in ``globals.WITH_DROPOUT``.

    Returns:
        Tuple ``(accuracy, total_confusion, intra_phase_confusion,
        per_task_confusion, embedding_drift, avg_shap_val)`` from the final
        evaluation.
    """
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    # Configure the shared settings BEFORE the data loaders are built:
    # initialize_data() reads globals.BATCH_SIZE, and with_OOD must not be
    # overridden afterwards or the flag has no effect.
    globals.OOD_CLASS = 1 if with_OOD else 0
    globals.BATCH_SIZE=4
    globals.WITH_DROPOUT = with_dropout
    initialize_data()
    prevModel = None

    #[Denis] added code:
    Feature_Importance_Eval=Feature_Importance_Evaluations(globals.valloaders, DEVICE)

    for i in tqdm(range(ITERATIONS), desc="Experiment Progress"):
        model = Net((i+1)*(CLASSES_PER_ITER+globals.OOD_CLASS))
        if prevModel is not None:
            with torch.no_grad():
                model.copyPrev(prevModel)
        train_loader = globals.trainloaders[i]
        val_loader = globals.valloaders[i]
        if prevModel is not None:
            _print("CL TRAIN!!")
            train_model_CL(
                model,
                prevModel,
                train_loader,
                val_loader,
                i,
                verbose,
                epochs,
                True,
                freeze_nonzero_params=False,
                l1_loss=0,
                ewc_loss=0,
                kd_loss=kd_loss,
                distance_loss=0,
                center_loss=0,
                param_reuse_loss=0,
                stopOnLoss=stopOnLoss,
                stopOnValAcc = stopOnValAcc,
                full_CE=full_CE
                )
        else:
            train_model(
                model,
                train_loader,
                val_loader,
                verbose,
                epochs=epochs,
                l1_loss=0,
                stopOnLoss=stopOnLoss,
                center_loss =0,
                )

        #[Denis] added code:
        Feature_Importance_Eval.Task_Feature_Attribution(model, i)

        if verbose or i == ITERATIONS-1:
            _print("Starting evaluation")
            _print("ITERATION", i+1)
            _print("ACCURACIES PER TASK:")
            # Width of the block of model outputs that belongs to one task
            outputs_per_task = CLASSES_PER_ITER+globals.OOD_CLASS
            accumPred = []
            all_labels = []
            all_embeddings = []
            model.eval()
            with torch.no_grad():
                for j in range(i+1):
                    test_loader = globals.testloaders[j]
                    task_labels = get_labels(test_loader).to(DEVICE)
                    all_labels.append(task_labels)
                    pred, embeddings = model.get_pred_and_embeddings((get_features(test_loader).to(DEVICE)))
                    accumPred.append(pred)
                    all_embeddings.append(embeddings)
                    # Task-aware accuracy: only task j's own outputs compete
                    sliced_pred = pred[:, j*outputs_per_task:(j+1)*outputs_per_task]
                    # Shift the within-task index to the global class label
                    task_predicted = torch.argmax(sliced_pred, dim=1) + j*CLASSES_PER_ITER
                    task_correct = (task_predicted == task_labels).sum().item()
                    task_accuracy = task_correct / task_labels.size(0)
                    _print(str(task_accuracy), end=' ')
            model.train()
            accumPred = torch.cat(accumPred)
            all_labels = torch.cat(all_labels)
            all_embeddings = torch.cat(all_embeddings)

            # Task-agnostic prediction over all class outputs seen so far.
            if globals.OOD_CLASS == 1:
                # Every (CLASSES_PER_ITER+1)-th output is dropped so that the
                # remaining column index equals the class label.
                class_columns = [c for c in range(accumPred.size(1)) if (c + 1) % (CLASSES_PER_ITER+1) != 0]
                class_scores = accumPred[:, class_columns]
            else:
                class_scores = accumPred
            # Softmax is monotonic, so the argmax of the raw scores is the
            # argmax of the softmax probabilities.
            predicted = torch.argmax(class_scores, dim=1).to(DEVICE)
            correct = (predicted == all_labels).sum().item()  # Count how many were correct
            accuracy = correct / all_labels.size(0)  # Fraction of correct predictions
            _print("Accuracy on tasks so far:", accuracy)

            embedding_drift = measure_embedding_drift(all_embeddings, all_labels, model.prev_test_embedding_centers)
            _print("Average embedding drift based on centroids:", embedding_drift)
            total_confusion, intra_phase_confusion, per_task_confusion = measure_embedding_confusion_knn(all_embeddings, all_labels, k = 500, task=i+1)
            _print("Total confusion", total_confusion)
            _print("Intra-phase confusion", intra_phase_confusion)
            _print("Per task confusions", per_task_confusion)
            if verbose:
                plot_confusion_matrix(predicted.cpu(), all_labels.cpu(), list(range(CLASSES_PER_ITER*(i+1))))
        # Independent snapshot used to initialise and regularise the next task
        prevModel = copy.deepcopy(model)

    #[Denis] added code:
    avg_shap_val,shap_vals=Feature_Importance_Eval.Get_Feature_Change_Score(prevModel)
    _print("Average SHAPC values (ordered as tasks):", shap_vals)
    _print("Averaged SHAPC value:", avg_shap_val)

    return accuracy, total_confusion, intra_phase_confusion, per_task_confusion, embedding_drift, avg_shap_val

In [None]:
def run_experiments(n_runs=1, *args, **kwargs):
    """Run ``run_experiment`` several times and print summary statistics.

    Every run is started with the same ``*args`` / ``**kwargs``. For each of
    the six values returned by ``run_experiment`` the mean and the sample
    standard deviation across runs are printed. With a single run the
    standard deviation is undefined and is reported as ``nan``.

    Args:
        n_runs: Number of repetitions; must be at least 1.
        *args, **kwargs: Forwarded unchanged to ``run_experiment``. A
            ``verbose`` keyword additionally enables the per-run summary line
            printed here (a positionally passed ``verbose`` is not detected).

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    verbose = kwargs.get('verbose', None)
    def _print(*print_args, **print_kwargs):
        if verbose:
            print(*print_args, **print_kwargs)

    def _mean_and_stdev(values):
        # statistics.stdev needs at least two samples; report nan otherwise.
        spread = statistics.stdev(values) if len(values) > 1 else float("nan")
        return statistics.mean(values), spread

    # Display names, in the order run_experiment returns the metrics.
    metric_names = (
        "accuracy",
        "total confusion",
        "intra-phase confusion",
        "per-task confusion",
        "embedding drift",
        "SHAP values",
    )
    collected = {name: [] for name in metric_names}

    for r in range(n_runs):
        print(f"Starting run {r+1}.")
        outcome = run_experiment(*args, **kwargs)
        for name, value in zip(metric_names, outcome):
            collected[name].append(value)
        _print(f"Run {r+1} finished with accuracy {outcome[0]}")

    # Compute every statistic first so a failure prints nothing partial.
    summaries = [(name,) + _mean_and_stdev(values) for name, values in collected.items()]

    # Print all results
    for name, mean_value, std_value in summaries:
        print(f"Mean {name} across {n_runs} runs: {mean_value}")
        print(f"Standard deviation of {name} across {n_runs} runs: {std_value}\n")

In [None]:
# Baseline: 5 runs with every run_experiment option at its default except stopOnLoss.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02)

In [None]:
# Baseline settings plus with_dropout=True.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, with_dropout=True)

In [None]:
# Baseline settings plus kd_loss=1.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, kd_loss=1)

In [None]:
# Baseline settings with full_CE=False.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE=False)

In [None]:
# Baseline settings plus with_OOD=True.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, with_OOD=True)

In [None]:
# full_CE=False combined with with_OOD=True.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, with_OOD=True)

In [None]:
# full_CE=False combined with kd_loss=1 and with_OOD=True.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1, with_OOD=True)

In [None]:
# full_CE=False combined with kd_loss=1.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1)

In [None]:
# full_CE=False combined with kd_loss=1 and with_dropout=True.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1, with_dropout=True)

In [None]:
# full_CE=False combined with kd_loss=1, with_OOD=True and with_dropout=True.
run_experiments(n_runs=5, verbose=False, stopOnLoss = 0.02, full_CE = False, kd_loss = 1, with_OOD=True, with_dropout=True)