In [None]:
#Training pipeline for ShapleyNAS implementation by Tomas Slaven of University Of Cape Town
#ShapleyNAS architecture search derived from https://github.com/euphoria16/shapley-nas

# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import pandas as pd
import numpy as np
from PIL import Image
import ast
import argparse
import os

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import sys
# Append a custom path to the system's search path
sys.path.append('/content/drive/My Drive/ShapleyNAS/')

# Import custom modules and functions
import utils
from model_search import Network
import genotypes
from utils import drop_path

# Best validation scores seen so far. validate_model() rebinds these through
# its own `global` declarations whenever it saves a new best checkpoint.
# (A `global` statement at module level is a no-op, so none is needed here.)
highest_mean_accuracy = 0.0
highest_f1_score = 0.0



# Define a function to change the alpha parameter in the model
def change_alpha(model, shap_values, accu_shap_values, momentum=0.8, step_size=0.1):
    """Nudge the architecture parameters along momentum-smoothed Shapley values.

    Each Shapley array is standardised (mean removed, divided by its sample
    standard deviation), blended with the running estimate as
    ``accu * momentum + shap * (1 - momentum)``, and ``step_size`` times the
    blend is added in place to the matching entry of
    ``model.arch_parameters()``.

    Args:
        model: Object whose ``arch_parameters()`` returns the tensors to
            update, one per entry of ``shap_values``.
        shap_values: Sequence of NumPy arrays, one per architecture
            parameter. The arrays are not modified.
        accu_shap_values: Sequence of tensors holding the running blend
            returned by the previous call.
        momentum: Weight given to the running blend.
        step_size: Scale applied to the blend before it is added.

    Returns:
        List of blended tensors (one per architecture parameter), living on
        the device of the corresponding parameter.
    """
    arch_params = list(model.arch_parameters())
    assert len(shap_values) == len(arch_params)

    updated_shap = []
    for i, param in enumerate(arch_params):
        # Follow the parameter's device instead of assuming CUDA is present.
        shap = torch.from_numpy(shap_values[i]).to(param.device)
        mean = shap.mean()
        std = shap.std()
        # Out-of-place arithmetic: on CPU the tensor shares memory with the
        # caller's NumPy array, which must stay untouched.
        normalised = shap - mean
        # A constant (or single-element) array has zero/undefined spread;
        # dividing would inject NaN/inf into the architecture parameters.
        if torch.isfinite(std) and std > 0:
            normalised = normalised / std

        running = torch.as_tensor(accu_shap_values[i], device=param.device)
        blended = running * momentum + normalised * (1. - momentum)
        param.data.add_((step_size * blended).to(param.device))
        updated_shap.append(blended)

    return updated_shap


# Define a function to remove players (weights) from the model
def remove_players(normal_weights, reduce_weights, op):
    """Zero one operation's weight on one edge, in place.

    Args:
        normal_weights: Per-edge operation weights of the normal cell,
            indexable as ``weights[edge][op]``.
        reduce_weights: Same structure for the reduce cell.
        op: Player id formatted as ``"<cell>_<edge>_<op>"``; the first field
            selects the cell, the second the edge, the last the operation.

    Returns:
        None. The selected weight container is modified in place.
    """
    fields = op.split('_')
    selected_cell = str(fields[0])
    selected_eid = int(fields[1])
    opid = int(fields[-1])

    # Any cell name other than 'normal' is routed to the reduce cell.
    weights = normal_weights if selected_cell == 'normal' else reduce_weights

    # Build the mask from the row actually being edited so the two cells are
    # not required to share a shape.
    proj_mask = torch.ones_like(weights[selected_eid])
    proj_mask[opid] = 0
    weights[selected_eid] = weights[selected_eid] * proj_mask


# Define a function for Shapley value estimation
def shap_estimation(valid_queue, model, criterion, players, num_samples, threshold=0.5, batch_size=64):
    """Monte-Carlo estimate of each operation's Shapley value.

    For every sample a random permutation of ``players`` is drawn and one
    validation batch is fetched. Players are removed one at a time in
    permutation order (via ``remove_players``); the drop in exact-match
    accuracy caused by each removal is recorded as that player's marginal
    contribution. A permutation is abandoned early once accuracy falls below
    ``threshold`` times the unmasked accuracy; players not reached keep a
    contribution of 0 for that sample.

    Args:
        valid_queue: Iterable of ``(inputs, targets)`` batches. It is
            re-iterated from the start when exhausted.
        model: Network exposing ``get_projected_weights(cell)``,
            ``num_edges``, ``num_ops`` and ``parameters()``, and callable as
            ``model(x, weights_dict=...)``.
        criterion: Unused; kept for call compatibility.
        players: Sequence of ``"<cell>_<edge>_<op>"`` ids. The final reshape
            assumes all normal-cell players come first, then all reduce-cell
            players, ``num_edges * num_ops`` of each.
        num_samples: Number of permutations to sample.
        threshold: Early-stop ratio relative to the unmasked accuracy.
        batch_size: Ignored; accuracy is normalised by the real size of each
            fetched batch. Kept so existing callers do not break.

    Returns:
        Tuple ``(shap_normal, shap_reduce)`` of NumPy arrays shaped
        ``(num_edges, num_ops)`` holding mean minus standard deviation of the
        sampled contributions.
    """
    def exact_match_accuracy(logits, targets):
        # A row counts as correct only when every label matches.
        predicted = (torch.sigmoid(logits) > 0.5).float()
        return (predicted == targets).all(dim=1).sum().item() / targets.size(0)

    n = len(players)
    sv_acc = np.zeros((n, num_samples))

    # Run on whatever device the model lives on rather than assuming CUDA.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        permutations = [np.random.permutation(n) for _ in range(num_samples)]

        # One persistent iterator: calling iter() per sample would restart the
        # loader and, without shuffling, always return the first batch.
        batches = iter(valid_queue)

        for j in range(num_samples):
            try:
                x, y = next(batches)
            except StopIteration:
                batches = iter(valid_queue)
                x, y = next(batches)
            x = x.to(device)
            y = y.to(device, non_blocking=True)

            old_acc = exact_match_accuracy(model(x, weights_dict=None), y)
            acc = old_acc
            normal_weights = model.get_projected_weights('normal')
            reduce_weights = model.get_projected_weights('reduce')

            print('MC sampling %d times' % (j+1))
            for i in permutations[j]:
                remove_players(normal_weights, reduce_weights, players[i])

                logits = model(x, weights_dict={'normal': normal_weights, 'reduce': reduce_weights})
                new_acc = exact_match_accuracy(logits, y)

                # Marginal contribution of player i in this permutation.
                sv_acc[i][j] = acc - new_acc
                acc = new_acc

                if acc < threshold * old_acc:
                    break

        result = np.mean(sv_acc, axis=-1) - np.std(sv_acc, axis=-1)
        shap_acc = np.reshape(result, (2, model.num_edges, model.num_ops))
        shap_normal, shap_reduce = shap_acc[0], shap_acc[1]
        return shap_normal, shap_reduce


# Custom dataset class
class CustomDataset(Dataset):
    """Image dataset indexed by a CSV file.

    For row ``idx`` the value in CSV column 1 is the image id, resolved to
    ``<image_dir>/<id>.jpeg``; columns 2 onward are returned as a float32
    label tensor.
    """

    def __init__(self, image_dir, csv_file, transform=None):
        # The CSV is read once up front; images are opened per item.
        self.image_dir = image_dir
        self.data = pd.read_csv(csv_file, encoding="utf-8")
        self.transform = transform

    def __len__(self):
        # One sample per CSV row.
        table = self.data
        return len(table)

    def __getitem__(self, idx):
        table = self.data
        jpeg_path = f"{self.image_dir}/{table.iloc[idx, 1]}.jpeg"
        picture = Image.open(jpeg_path).convert("RGB")
        target = torch.tensor(table.iloc[idx, 2:], dtype=torch.float32)

        # The optional transform is applied to the image only.
        if not self.transform:
            return picture, target
        return self.transform(picture), target


# Define a function to load data and create data loaders
def load_data(image_dir, train_csv, validation_csv, test_csv, batch_size, resolution):
    """Build the train, validation and test data loaders.

    All three splits share one preprocessing pipeline: resize to
    ``resolution`` x ``resolution``, convert to a tensor, then normalise each
    channel with fixed mean/std constants. Only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, validation_loader, test_loader)``.
    """
    channel_mean = [0.3721, 0.3721, 0.3721]
    channel_std = [0.1801, 0.1801, 0.1801]
    preprocess = transforms.Compose(
        [
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ])

    # Datasets are created in train / validation / test order.
    train_set, validation_set, test_set = [
        CustomDataset(image_dir, csv_path, preprocess)
        for csv_path in (train_csv, validation_csv, test_csv)
    ]

    common = {"batch_size": batch_size, "pin_memory": True}
    loaders = (
        DataLoader(train_set, shuffle=True, **common),
        DataLoader(validation_set, **common),
        DataLoader(test_set, **common),
    )
    return loaders

# Define a function for training one epoch of the model
def train_epoch(model, criterion, optimizer, train_loader,
    device, epoch):
    """Train ``model`` for one pass over ``train_loader`` and tabulate metrics.

    Predictions are ``sigmoid(logits) > 0.5`` per label. Reads the
    module-level ``args.num_classes`` and ``label_names``.

    Args:
        model: Network to train; moved to ``device`` and put in train mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        train_loader: Loader yielding ``(images, labels)`` batches.
        device: Device the model and batches are moved to.
        epoch: Epoch number written to the "Epoch" column and printed.

    Returns:
        pandas DataFrame with one row per entry of ``label_names``: row 0
        holds the exact-match accuracy and the macro averages, the remaining
        rows the per-label values. All metrics are percentages (F1 is
        computed from percentage precision and sensitivity).
    """
    model.to(device)
    model.train()
    train_loss = 0.0
    total_correct = 0
    label_correct = [0] * args.num_classes
    label_total = [0] * args.num_classes
    TP = [0] * args.num_classes
    FP = [0] * args.num_classes
    TN = [0] * args.num_classes
    FN = [0] * args.num_classes
    num_batch = 0
    # len(loader) includes the final partial batch; floor-dividing the dataset
    # size by the batch size under-reports it by one.
    batches = len(train_loader)
    print("Epoch: ", epoch)
    print("Number of batches:", batches )
    precision, specificity, sensi, FNR, F1, label_accuracy =  [], [], [], [], [], []

    for images, labels in train_loader:
        num_batch += 1
        images = images.to(device)
        labels = labels.to(device)
        print("Batch Number:", num_batch)

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per sample.
        train_loss += loss.item() * images.size(0)
        predicted_labels = (torch.sigmoid(outputs) > 0.5).float()
        # A sample is "correct" only when every label matches.
        total_correct += (predicted_labels == labels).all(dim=1).sum().item()
        correct_per_label = (predicted_labels == labels).sum(dim=0).tolist()

        for i in range(args.num_classes):
            label_correct[i] += correct_per_label[i]
            label_total[i] += len(images)

            # Confusion-matrix counts for label i on this batch.
            true_positive = ((predicted_labels[:, i] == 1) & (labels[:, i] == 1)).sum().item()
            false_positive = ((predicted_labels[:, i] == 1) & (labels[:, i] == 0)).sum().item()
            true_negative = ((predicted_labels[:, i] == 0) & (labels[:, i] == 0)).sum().item()
            false_negative = ((predicted_labels[:, i] == 0) & (labels[:, i] == 1)).sum().item()

            TP[i] += true_positive
            FP[i] += false_positive
            TN[i] += true_negative
            FN[i] += false_negative

    # Per-label metrics; an empty denominator yields 0 instead of an error.
    for i in range(args.num_classes):
        label_accuracy.append((label_correct[i] / label_total[i] if label_total[i] > 0 else 0)*100)
        precision.append((TP[i] / (TP[i] + FP[i]) if TP[i] + FP[i] > 0 else 0)*100)
        sensi.append((TP[i] / (TP[i] + FN[i]) if TP[i] + FN[i] > 0 else 0)*100)
        specificity.append((TN[i] / (TN[i] + FP[i]) if TN[i] + FP[i] > 0 else 0)*100)
        FNR.append((FN[i] / (TP[i] + FN[i]) if (TP[i] + FN[i]) > 0 else 0)*100)
        F1.append( 2 * (precision[i] * sensi[i]) / (precision[i] + sensi[i]) if precision[i] + sensi[i] > 0 else 0)

    # Macro averages go in front so row 0 lines up with the first label name.
    train_loss /= len(train_loader.dataset)
    overall_accuracy = total_correct / len(train_loader.dataset)
    train_accuracy = overall_accuracy * 100  # percentage
    label_accuracy.insert(0, train_accuracy )
    precision.insert(0, (sum(precision) / args.num_classes))
    sensi.insert(0, (sum(sensi) / args.num_classes))
    specificity.insert(0, (sum(specificity) / args.num_classes))
    FNR.insert(0, (sum(FNR) / args.num_classes))
    F1_score = (sum(F1) / args.num_classes)
    F1.insert(0, F1_score)

    train_metrics_dict = {
        "Epoch": [epoch] * len(label_names),
        "Type (TRAIN)": label_names,
        "Mean Accuracy (TRAIN)": label_accuracy,
        "Precision (TRAIN)": precision,
        "Sensitivity (TRAIN)": sensi,
        "Specificity (TRAIN)": specificity,
        "FNR (TRAIN)": FNR,
        "F1 Score (TRAIN)": F1,
    }
    print(f"Overall Accuracy: {train_accuracy:.4f}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"F1 Score: {F1_score:.4f}")
    print("")

    return pd.DataFrame(train_metrics_dict)

# Define a function to validate the model
def validate_model(model, criterion, validation_loader, device, output_folder_path, epoch):
    """Evaluate ``model`` on the validation set and checkpoint new bests.

    Predictions are ``sigmoid(logits) > 0.5`` per label. Reads the
    module-level ``args.num_classes`` and ``label_names`` and updates the
    module-level ``highest_mean_accuracy`` / ``highest_f1_score``. Whenever
    one of those is beaten (or is still 0.0) the model's ``state_dict`` is
    saved into ``output_folder_path`` as ``model_highest_mean_accuracy.pth``
    or ``model_highest_f1_score.pth`` respectively.

    Args:
        model: Network to evaluate; moved to ``device`` and put in eval mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        validation_loader: Loader yielding ``(images, labels)`` batches.
        device: Device the model and batches are moved to.
        output_folder_path: Directory that receives the checkpoints.
        epoch: Epoch number written to the "Epoch" column.

    Returns:
        Dict of equal-length lists (one entry per ``label_names`` item) with
        percentage metrics; entry 0 holds the exact-match accuracy and the
        macro averages.
    """
    # Honour the caller-supplied device instead of the module-level DEVICE.
    model.to(device)
    model.eval()
    global highest_mean_accuracy
    global highest_f1_score


    with torch.no_grad():
        validation_loss = 0.0
        total_correct = 0
        label_correct = [0] * args.num_classes
        label_total = [0] * args.num_classes
        TP = [0] * args.num_classes
        FP = [0] * args.num_classes
        TN = [0] * args.num_classes
        FN = [0] * args.num_classes
        precision, specificity, sensi, FNR, F1, label_accuracy =  [], [], [], [], [], []

        for images, labels in validation_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            # Weight by batch size so the final mean is per sample.
            validation_loss += loss.item() * images.size(0)

            predicted_labels = (torch.sigmoid(outputs) > 0.5).float()
            # A sample is "correct" only when every label matches.
            total_correct += (predicted_labels == labels).all(dim=1).sum().item()
            correct_per_label = (predicted_labels == labels).sum(dim=0).tolist()

            for i in range(args.num_classes):
                label_correct[i] += correct_per_label[i]
                label_total[i] += len(images)

                # Confusion-matrix counts for label i on this batch.
                true_positive = ((predicted_labels[:, i] == 1) & (labels[:, i] == 1)).sum().item()
                false_positive = ((predicted_labels[:, i] == 1) & (labels[:, i] == 0)).sum().item()
                true_negative = ((predicted_labels[:, i] == 0) & (labels[:, i] == 0)).sum().item()
                false_negative = ((predicted_labels[:, i] == 0) & (labels[:, i] == 1)).sum().item()

                TP[i] += true_positive
                FP[i] += false_positive
                TN[i] += true_negative
                FN[i] += false_negative

        # Per-label metrics; an empty denominator yields 0 instead of an error.
        for i in range(args.num_classes):
            label_accuracy.append((label_correct[i] / label_total[i] if label_total[i] > 0 else 0)*100)
            precision.append((TP[i] / (TP[i] + FP[i]) if TP[i] + FP[i] > 0 else 0)*100)
            sensi.append((TP[i] / (TP[i] + FN[i]) if TP[i] + FN[i] > 0 else 0)*100)
            specificity.append((TN[i] / (TN[i] + FP[i]) if TN[i] + FP[i] > 0 else 0)*100)
            FNR.append((FN[i] / (TP[i] + FN[i]) if (TP[i] + FN[i]) > 0 else 0)*100)
            F1.append( 2 * (precision[i] * sensi[i]) / (precision[i] + sensi[i]) if precision[i] + sensi[i] > 0 else 0)

        # Macro averages go in front so entry 0 lines up with the first label name.
        validation_loss /= len(validation_loader.dataset)
        overall_accuracy = total_correct / len(validation_loader.dataset)
        validation_accuracy = (overall_accuracy * 100)
        label_accuracy.insert(0, validation_accuracy )
        precision.insert(0, (sum(precision) / args.num_classes))
        sensi.insert(0, (sum(sensi) / args.num_classes))
        specificity.insert(0, (sum(specificity) / args.num_classes))
        FNR.insert(0, (sum(FNR) / args.num_classes))
        F1.insert(0, (sum(F1) / args.num_classes))

        val_metrics_dict = {
            "Epoch": [epoch] * len(label_names),
            "Type (VAL)": label_names,
            "Mean Accuracy (VAL)": label_accuracy,
            "Precision (VAL)": precision,
            "Sensitivity (VAL)": sensi,
            "Specificity (VAL)": specificity,
            "FNR (VAL)": FNR,
            "F1 Score (VAL)": F1,
        }

        mean_accuracy = validation_accuracy
        f1_score_mean = F1[0]

        # The `== 0.0` arm guarantees a checkpoint exists after the first call
        # even when the score itself is 0.
        if ((mean_accuracy > highest_mean_accuracy) or (highest_mean_accuracy == 0.0)):
            highest_mean_accuracy = mean_accuracy
            model_save_path = os.path.join(output_folder_path, 'model_highest_mean_accuracy.pth')
            torch.save(model.state_dict(), model_save_path)

        if ((f1_score_mean > highest_f1_score) or (highest_f1_score == 0.0)):
            print("New F1 model")
            highest_f1_score = f1_score_mean
            model_save_path = os.path.join(output_folder_path, 'model_highest_f1_score.pth')
            torch.save(model.state_dict(), model_save_path)

        print(f"Validation Overall Accuracy: {validation_accuracy:.4f}")
        print(f"Validation Loss: {validation_loss:.4f}")
        print("-------------------------------------------------")
        print("")

    return val_metrics_dict


# Define a function to test the model
def test_model(model, criterion, test_loader, device):
    """Evaluate ``model`` on the test set and collect per-label metrics.

    Predictions are ``sigmoid(logits) > 0.5`` per label. Reads the
    module-level ``args.num_classes`` and ``label_names``.

    Returns:
        Dict of equal-length lists (one entry per ``label_names`` item) with
        percentage metrics; entry 0 holds the exact-match accuracy and the
        macro averages. The "Epoch" column is always 1.
    """
    def as_percent(numerator, denominator):
        # An empty denominator yields 0 rather than raising.
        return (numerator / denominator if denominator > 0 else 0)*100

    model.eval()
    test_loss = 0.0

    with torch.no_grad():
        n_labels = args.num_classes
        exact_hits = 0
        hits_per_label = [0] * n_labels
        seen_per_label = [0] * n_labels
        TP, FP, TN, FN = ([0] * n_labels for _ in range(4))

        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, labels)
            # Weight by batch size so the final mean is per sample.
            test_loss += batch_loss.item() * images.size(0)

            predicted = (torch.sigmoid(outputs) > 0.5).float()
            agreement = predicted == labels
            # A sample is "correct" only when every label matches.
            exact_hits += agreement.all(dim=1).sum().item()
            hits_this_batch = agreement.sum(dim=0).tolist()

            for k in range(n_labels):
                hits_per_label[k] += hits_this_batch[k]
                seen_per_label[k] += len(images)

                # Confusion-matrix counts for label k on this batch.
                said_yes = predicted[:, k] == 1
                said_no = predicted[:, k] == 0
                is_yes = labels[:, k] == 1
                is_no = labels[:, k] == 0

                TP[k] += (said_yes & is_yes).sum().item()
                FP[k] += (said_yes & is_no).sum().item()
                TN[k] += (said_no & is_no).sum().item()
                FN[k] += (said_no & is_yes).sum().item()

        # Per-label metrics.
        label_accuracy = [as_percent(hits_per_label[k], seen_per_label[k]) for k in range(n_labels)]
        precision = [as_percent(TP[k], TP[k] + FP[k]) for k in range(n_labels)]
        sensi = [as_percent(TP[k], TP[k] + FN[k]) for k in range(n_labels)]
        specificity = [as_percent(TN[k], TN[k] + FP[k]) for k in range(n_labels)]
        FNR = [as_percent(FN[k], TP[k] + FN[k]) for k in range(n_labels)]
        F1 = [
            2 * (precision[k] * sensi[k]) / (precision[k] + sensi[k]) if precision[k] + sensi[k] > 0 else 0
            for k in range(n_labels)
        ]

        # Overall accuracy and macro averages go in front of the per-label values.
        overall_accuracy = exact_hits / len(test_loader.dataset)
        test_loss /= len(test_loader.dataset)
        test_accuracy = (overall_accuracy * 100)
        label_accuracy.insert(0, test_accuracy)
        for series in (precision, sensi, specificity, FNR):
            series.insert(0, (sum(series) / n_labels))
        F1_score = (sum(F1) / n_labels)
        F1.insert(0, F1_score)
        epoch = 1

        test_metrics_dict = {
            "Epoch": [epoch] * len(label_names),
            "Type (TEST)": label_names,
            "Mean Accuracy (TEST)": label_accuracy,
            "Precision (TEST)": precision,
            "Sensitivity (TEST)": sensi,
            "Specificity (TEST)": specificity,
            "FNR (TEST)": FNR,
            "F1 Score (TEST)": F1,
        }

        print(f"Test Overall Accuracy: {test_accuracy:.4f}")
        print(f"Test Loss: {test_loss:.4f}")
        print(f"F1 Score: {F1_score:.4f}")
        print("-------------------------------------------------")

    return test_metrics_dict

# Function to calculate class weights based on a CSV file
def calcWeights(csv_path):
    """Compute one weight per label column of a CSV file.

    Every column other than "Num" and "ID" is treated as a label column
    (assumed to hold 0/1 values). Its weight is
    ``(rows - column_sum) / rows``, where ``rows`` is the number of rows in
    the CSV rather than a hard-coded dataset size.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        List of weights in CSV column order.

    Raises:
        ValueError: If the CSV contains no rows.
    """
    train_df = pd.read_csv(csv_path)
    total = len(train_df)
    if total == 0:
        raise ValueError(f"{csv_path} contains no rows; cannot compute weights")

    label_columns = [column for column in train_df.columns if column not in ("Num", "ID")]

    # NOTE(review): this is the share of rows where the label is unset, not the
    # negatives/positives ratio often used for BCEWithLogitsLoss pos_weight --
    # confirm this weighting is intended.
    return [(total - train_df[column].sum()) / total for column in label_columns]


if __name__ == "__main__":
    # Grid search driver: for every hyperparameter combination run the
    # ShapleyNAS train/search loop, test the two best checkpoints, and write
    # per-configuration and summary CSV files under args.output_dir.

    torch.cuda.empty_cache()
    args = argparse.Namespace(
        image_dir="/content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/Images",
        train_csv="/content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/train.csv",
        validation_csv="/content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/val.csv",
        test_csv="/content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/test.csv",
        num_classes=8,
        output_dir="/content/drive/My Drive/ShapleyNeckResults",
        output_file_prefix="shap",
        output_file_extension=".csv",
        )

    print("=======================================================")
    print("Arguments:")
    for arg, value in vars(args).items():
        print(f"{arg}: {value}")
    print("=======================================================")
    print("")

    #Initializations
    results_df = pd.DataFrame()
    # First entry labels the macro-average row; the rest name the 8 classes.
    label_names = [
        "Overall (Macro-Average)",
        "Alignment",
        "Soft_tissue_Swelling",
        "Listhesis",
        "Fracture",
        "Dislocation",
        "Spinous",
        "Other_Pathogens",
        "normal",
    ]
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pos_weights = torch.tensor(calcWeights(args.train_csv), device=DEVICE)

    # Create the output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    #Initial Hyperparameter Grid
    hyperparameter_grid = [
        # {"bs": 32, "lr": 0.001, "res": 64, "epochs": 2},
        # {"bs": 32, "lr": 0.001, "res": 128, "epochs": 2},
        # {"bs": 32, "lr": 0.01, "res": 64, "epochs": 10},
        # {"bs": 32, "lr": 0.01, "res": 128, "epochs": 10},
        # {"bs": 64, "lr": 0.01, "res": 64, "epochs": 10},
        # {"bs": 64, "lr": 0.01, "res": 128, "epochs": 15},
        # {"bs": 32, "lr": 0.02, "res": 64, "epochs": 15},
        {"bs": 32, "lr": 0.0002, "res": 64, "epochs": 8},
        {"bs": 32, "lr": 0.0002, "res": 128, "epochs": 8},
        # {"bs": 64, "lr": 0.02, "res": 64, "epochs": 15},
        # {"bs": 64, "lr": 0.02, "res": 128, "epochs": 15},
        # {"bs": 32, "lr": 0.015, "res": 96, "epochs": 15},
        # Add more combinations if desired
    ]

    for hyperparams in hyperparameter_grid:
        #Initialize hyperparameters
        batch_size = hyperparams["bs"]
        learning_rate = hyperparams["lr"]
        resolution = hyperparams["res"]
        epochs = hyperparams["epochs"]

        # Best-score trackers are module globals updated by validate_model().
        # Reset them per configuration: otherwise a later configuration that
        # never beats an earlier one saves no checkpoint in its own folder and
        # the torch.load calls below fail (or pick up a stale file).
        highest_mean_accuracy = 0.0
        highest_f1_score = 0.0

        #initialize output directory
        output_df = pd.DataFrame()
        folder_name = f"bs{batch_size}_lr{learning_rate:.5f}_res{resolution}_ep{epochs}"
        output_folder_path = os.path.join(args.output_dir, folder_name)
        os.makedirs(output_folder_path, exist_ok=True)
        output_file_path = os.path.join(output_folder_path, "output.csv")

        #Print Current Hyperparameters
        print("Batch size:", batch_size)
        print("Learning Rate:", learning_rate)
        print("Resolution:", resolution)
        print("")

        # Load data with the current hyperparameter settings
        train_loader, validation_loader, test_loader = load_data(
            args.image_dir,
            args.train_csv,
            args.validation_csv,
            args.test_csv,
            batch_size,
            resolution,
        )

        #START OF SHAPLEY NAS train and search procedure
        np.random.seed(2)
        torch.manual_seed(2)
        torch.cuda.manual_seed(2)

        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
        criterion = criterion.to(DEVICE)
        model = Network(16, 8, 50, criterion)
        # Only the network weights go to SGD; architecture parameters are
        # updated separately by change_alpha().
        arch_params = list(map(id, model.arch_parameters()))
        weight_params = filter(lambda p: id(p) not in arch_params, model.parameters())
        optimizer = torch.optim.SGD(
            weight_params,
            learning_rate,
            momentum=0.9,
            weight_decay=3e-4)

        # Phase 0 trains weights only; the last phase also runs the Shapley
        # search before every epoch.
        pretrain_epochs=25
        train_epochs = [5,  epochs]

        # Anneal over every epoch that actually runs. With T_max covering only
        # the final phase the rate reaches 0 mid-run and then climbs again.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(sum(train_epochs)), eta_min=0.0)

        # One player id per (cell, edge, operation), normal cell first.
        ops = []
        for cell_type in ['normal','reduce']:
            for edge in range(model.num_edges):
                ops.append(['{}_{}_{}'.format(cell_type, edge, i) for i in
                                range(0, model.num_ops)])
        ops = np.concatenate(ops)

        epoch = 0
        # Sample on CPU and then move, so the seeded values do not depend on
        # which device is in use.
        accum_shaps = [1e-3 * torch.randn(model.num_edges, model.num_ops).to(DEVICE),1e-3 * torch.randn(model.num_edges, model.num_ops).to(DEVICE)]

        train_df = pd.DataFrame()
        val_df = pd.DataFrame()

        model = model.to(DEVICE)
        for i, current_epochs in enumerate(train_epochs):
            print("i:",i)
            print("current_epochs:", current_epochs)

            for e in range(current_epochs):
                print("e:",e)
                lr = scheduler.get_last_lr()[0]

                model.show_arch_parameters()
                genotype = model.genotype()

                if i == len(train_epochs)-1:
                    shap_normal, shap_reduce = shap_estimation(validation_loader, model, criterion,
                                                               ops, num_samples=5, threshold=0.5)
                    accum_shaps = change_alpha(model, [shap_normal, shap_reduce], accum_shaps, momentum=0.8, step_size=0.1)

                torch.cuda.empty_cache()
                a_df = train_epoch(model, criterion, optimizer, train_loader, DEVICE, epoch)
                # ignore_index keeps row labels unique across epochs so the
                # column-wise joins below can align frames of different length.
                train_df = pd.concat([train_df, a_df], axis=0, ignore_index=True)

                torch.cuda.empty_cache()
                val_metrics_dict = validate_model(
                    model, criterion, validation_loader, DEVICE, output_folder_path, epoch
                )
                b = pd.DataFrame(val_metrics_dict)
                val_df = pd.concat([val_df, b], axis=0, ignore_index=True)

                # Step the schedule after the epoch's optimizer steps; stepping
                # first skips the initial learning rate.
                scheduler.step()
                epoch += 1

        train_val_df = pd.concat([train_df, val_df], axis=1)

        model.show_arch_parameters()
        del model


        # Test the model on best Mean Acc
        best_mean_accuracy_model = Network(16, 8, 50, criterion)
        best_mean_accuracy_model.load_state_dict(
            torch.load(os.path.join(output_folder_path, 'model_highest_mean_accuracy.pth'), map_location=DEVICE))
        best_mean_accuracy_model.to(DEVICE)
        mean_acc_test_metrics_dict = test_model(
            best_mean_accuracy_model, criterion, test_loader, DEVICE
        )
        del best_mean_accuracy_model


        # Test the model on best F1 Macro-Average
        best_f1_score_model = Network(16, 8, 50, criterion)
        best_f1_score_model.load_state_dict(
            torch.load(os.path.join(output_folder_path, 'model_highest_f1_score.pth'), map_location=DEVICE))
        best_f1_score_model.to(DEVICE)
        macro_F1_test_metrics_dict = test_model(
            best_f1_score_model, criterion, test_loader, DEVICE
        )
        del best_f1_score_model

        #accumulate results
        mean_acc_test_df = pd.DataFrame(mean_acc_test_metrics_dict)
        mean_F1_test_df = pd.DataFrame(macro_F1_test_metrics_dict)
        test_df = pd.concat([mean_acc_test_df, mean_F1_test_df ], axis=1)
        csv_df = pd.concat([train_val_df, test_df ], axis=1)

        #save results
        csv_df.to_csv(output_file_path, index=False)

        # Keep the better of the two checkpoints for each headline metric.
        overall_mean_acc = mean_acc_test_metrics_dict["Mean Accuracy (TEST)"][0]
        mac_f1 =  mean_acc_test_metrics_dict["F1 Score (TEST)"][0]
        overall_macro_F1 = macro_F1_test_metrics_dict["F1 Score (TEST)"][0]
        mac_acc = macro_F1_test_metrics_dict["Mean Accuracy (TEST)"][0]
        if mac_acc > overall_mean_acc:
            overall_mean_acc = mac_acc

        if mac_f1 > overall_macro_F1:
            overall_macro_F1 = mac_f1

        #Save Best Results to dictionary
        results_dict = {
        "hyperparameters": hyperparams,
        "test_metrics_mean_acc": overall_mean_acc,
        "test_metrics_macro_F1": overall_macro_F1
        }

        #Append dictionary to other results for same hyperparameter search
        # NOTE(review): the nested dict makes pandas emit one row per
        # hyperparameter key, and index=False below drops the key names --
        # confirm this layout is what the summary CSV should contain.
        a_df = pd.DataFrame(results_dict)
        results_df = pd.concat([results_df, a_df ], axis=0)

    #initialize hyperparameter results file
    output_file_name = "GridSearchResults"
    output_file_path = os.path.join(args.output_dir, output_file_name + ".csv")

    # Create a DataFrame from the results and save to a CSV file
    results_df.to_csv(output_file_path, index=False)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Arguments:
image_dir: /content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/Images
train_csv: /content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/train.csv
validation_csv: /content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/val.csv
test_csv: /content/drive/MyDrive/Medical data/Neck/Unlocalized/supervised data/test.csv
num_classes: 8
num_epochs: 1
batch_sizes: 32
learning_rates: [0.0075]
resolutions: [224]
output_dir: /content/drive/My Drive/ShapleyNeckResults
output_file_prefix: shap
output_file_extension: .csv

Batch size: 32
Learning Rate: 0.0002
Resolution: 64

i: 0
current_epochs: 5
e: 0




Epoch:  0
Number of batches: 23
Batch Number: 1
Batch Number: 2
Batch Number: 3
Batch Number: 4
Batch Number: 5
Batch Number: 6
Batch Number: 7
Batch Number: 8
Batch Number: 9
Batch Number: 10
Batch Number: 11
Batch Number: 12
Batch Number: 13
Batch Number: 14
Batch Number: 15
Batch Number: 16
Batch Number: 17
Batch Number: 18
Batch Number: 19
Batch Number: 20
Batch Number: 21
Batch Number: 22
Batch Number: 23
Batch Number: 24
Overall Accuracy: 0.1340
Train Loss: 0.6581
F1 Score: 21.8184

tensor([[1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 1., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 1., 0.],
  



Epoch:  1
Number of batches: 23
Batch Number: 1
Batch Number: 2
Batch Number: 3
Batch Number: 4
Batch Number: 5
Batch Number: 6
Batch Number: 7
Batch Number: 8
Batch Number: 9
Batch Number: 10
Batch Number: 11
Batch Number: 12
Batch Number: 13
Batch Number: 14
Batch Number: 15
Batch Number: 16
Batch Number: 17
Batch Number: 18
Batch Number: 19
Batch Number: 20
Batch Number: 21
Batch Number: 22
Batch Number: 23
Batch Number: 24
Overall Accuracy: 2.5469
Train Loss: 0.5760
F1 Score: 16.8178

tensor([[1., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 1., 0., 0., 1., 0., 1., 0.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 1., 0., 0., 1., 0., 1., 0.],
        [0., 1., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
  



Epoch:  2
Number of batches: 23
Batch Number: 1
Batch Number: 2
Batch Number: 3
Batch Number: 4
Batch Number: 5
Batch Number: 6
Batch Number: 7
Batch Number: 8
Batch Number: 9
Batch Number: 10
Batch Number: 11
Batch Number: 12
Batch Number: 13
Batch Number: 14
Batch Number: 15
Batch Number: 16
Batch Number: 17
Batch Number: 18
Batch Number: 19
Batch Number: 20
Batch Number: 21
Batch Number: 22
Batch Number: 23
Batch Number: 24
Overall Accuracy: 7.5067
Train Loss: 0.5132
F1 Score: 14.0703

tensor([[0., 0., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 1., 0.],
  



Epoch:  3
Number of batches: 23
Batch Number: 1
Batch Number: 2
Batch Number: 3
Batch Number: 4
Batch Number: 5
Batch Number: 6
Batch Number: 7
Batch Number: 8
Batch Number: 9
Batch Number: 10
Batch Number: 11
Batch Number: 12
Batch Number: 13
Batch Number: 14
Batch Number: 15
Batch Number: 16
Batch Number: 17
Batch Number: 18
Batch Number: 19
Batch Number: 20
Batch Number: 21
Batch Number: 22
Batch Number: 23
Batch Number: 24
Overall Accuracy: 10.9920
Train Loss: 0.4795
F1 Score: 9.9979

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
  



Epoch:  4
Number of batches: 23
Batch Number: 1
Batch Number: 2
Batch Number: 3
Batch Number: 4
Batch Number: 5
Batch Number: 6
Batch Number: 7
Batch Number: 8
Batch Number: 9
Batch Number: 10
Batch Number: 11
Batch Number: 12
Batch Number: 13
Batch Number: 14
Batch Number: 15
Batch Number: 16
Batch Number: 17
Batch Number: 18
Batch Number: 19
Batch Number: 20
Batch Number: 21
Batch Number: 22
Batch Number: 23
Batch Number: 24
Overall Accuracy: 8.8472
Train Loss: 0.4609
F1 Score: 6.3154

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
   



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0., 1



Epoch:  0
Number of batches: 23
Batch Number: 1


OutOfMemoryError: ignored