## Script Plan:

1. Imports
2. Define Training Parameters + Variables:
    - Epochs
    - Batch_size
    - Random Seed
    - Input Size
    - Patience
    - Learning Rate
    - Input Data Folder - Labelled Folders
3. Data Import/Organisation
    - TF import data (from folder)
    - Train/Val/Test Splitting
    - K-Fold Validation
4. Load and Prep Models
    - Pretrained ENB3 (Freezing AFIB Networks, Retraining Classifier)
    - Pretrained ENB3 (Unfrozen)
    - Fresh ENB3 (Frozen ImageNet Pretraining)
    - Fresh ENB3 (Unfrozen ImageNet Pretraining)
5. Training Iterations
    - Binary (0+1R VS 2R)
    - Binary (0 vs 1+2R)
    - Multiclass (0R VS 1R Vs 2R)
    - Train/Val/Test (60%/20%/20%) vs 5-Fold Cross-validation
6. Validation/Test Results
    - Write Results Function

Training plan: 
1. SWOIP: single with only image pretraining 

Tuning: 
1. Learning rate 
2. Patience 
3. Freezing vs. unfreezing a certain number of layers
4. change binary classifications 
5. multiclass 
6. ensembling 4 ENB3 models, feed each model output into 10 classifiers 

## IMPORTS

In [None]:
import os

import tensorflow as tf
from tensorflow import keras 
from keras.optimizers import Adam

from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
import seaborn as sns
import pandas as pd
import PIL
import PIL.Image
from datetime import datetime

from PIL import Image

import sklearn.metrics as metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay, roc_curve, RocCurveDisplay, roc_auc_score

from sklearn.utils import class_weight

import csv

from keras import regularizers
import time 
import itertools

print (tf.version.VERSION)

## TRAINING PARAMETERS & VARIABLES

In [None]:
### Training Parameters ###
seed = 2806
batch_size = 8 #16 
# learning_rate = 0.0001
max_epochs = 100
# patience = 15
image_size = 300
img_width, img_height = image_size, image_size

### Label Information ### 
# target_labels = ['mild rejection','moderate-to-severe rejection','no rejection']
# classifier = 'binary'
training_style = "Ensemble"
Dataset = "Binary-1"
Pretrained_model = "MI"

# === Class labels for the selected dataset ===
# Resolved before any directory is created so that a mistyped Dataset fails
# immediately with a clear message instead of a NameError further down.
if Dataset == "Binary-1":
    target_labels = ['0-1R', "2R"]
    nClasses = 2
elif Dataset == "Binary-2":
    target_labels = ['0', "1-2R"]
    nClasses = 2
else:
    raise ValueError(f"Unsupported Dataset {Dataset!r}; expected 'Binary-1' or 'Binary-2'.")

# ==== Create target directory for saving/loading models === 
project_dir = 'D://Rui//'

# === Create net directory for saving/loading models === 
date = datetime.now()
# Zero-padded DDMMYYYY_HHMM stamp: without padding the hour/minute, 01:23 and
# 12:03 would both be rendered as "123" and runs could not be told apart.
net_dir = project_dir + 'Transplant_prediction_models//' + training_style + Pretrained_model + '_' + Dataset + '_' + date.strftime('%d%m%Y_%H%M')
# makedirs also creates missing parent folders and tolerates a re-run within
# the same minute, where os.mkdir would raise.
os.makedirs(net_dir, exist_ok=True)

# === Pre-trained Net Location ===
# Only "AF" and "MI" have a pretrained-network folder; any other value leaves
# pretrained_net_dir undefined.
if Pretrained_model == "AF":
    pretrained_net_dir = project_dir + 'AFIBvsNOTNet(080120254x3cols)//'
elif Pretrained_model == "MI":
    pretrained_net_dir = project_dir + 'MulticlassMIPTBXLTraining(4x3)//'

print("target_labels =", target_labels)
print("nClasses = ", nClasses)
# print(pretrained_net_dir)


## DATA IMPORT

### TFDS Import & Organise Function

In [None]:
# Data Input Folders
# Select the image folder that matches the chosen binary task. Only the two
# Dataset values below are handled; any other value leaves dataset_dir
# undefined and the print at the end raises NameError.
if Dataset == "Binary-1":
    dataset_dir = "D://Rui//NIDACT2 Transplant ECGs//Multiclass_HTx_Rejection_Dataset_1(300_300)//Binary1_0-1Rvs2R"
elif Dataset == "Binary-2": 
    dataset_dir = "D://Rui//NIDACT2 Transplant ECGs//Multiclass_HTx_Rejection_Dataset_1(300_300)//Binary2_0vs1-2R"

# Echo the resolved folder so the notebook output records which data was used.
print("dataset_dir =", dataset_dir)

### Train/val/test split 

In [None]:
# train_dir = dataset_dir + '(Splits)//train//'
# train_ds = tf.keras.preprocessing.image_dataset_from_directory(
#     directory = train_dir,
#     # labels = "inferred"
#     label_mode='categorical', #(sparse -> image can only belong to one group)
#     class_names = target_labels, 
#     # colour_mode = "rgb" -> three colour channels image
#     batch_size = batch_size,
#     image_size = (img_height, img_width),
#     shuffle = True, 
#     seed = seed
#     # validation_split = 0.2, # Spliting dataset into 80% training, 20% validation set
#     # subset = "both", 
# )

In [None]:
# val_dir = dataset_dir + '(Splits)//val//'
# val_ds = tf.keras.preprocessing.image_dataset_from_directory(
#     directory = val_dir,
#     # labels = "inferred"
#     label_mode='categorical', #(sparse -> image can only belong to one group)
#     class_names = target_labels, 
#     # colour_mode = "rgb" -> three colour channels image
#     batch_size = batch_size,
#     image_size = (img_height, img_width),
#     shuffle = True, 
#     seed = seed,
#     # validation_split = 0.2, # Spliting dataset into 80% training, 20% validation set
#     # subset = "both", 
# )

In [None]:
# test_dir = dataset_dir + '(Splits)//test//'
# test_ds = tf.keras.preprocessing.image_dataset_from_directory(
#     directory = test_dir,
#     # labels = "inferred"
#     label_mode='categorical', #(sparse -> image can only belong to one group)
#     class_names = target_labels, 
#     # colour_mode = "rgb" -> three colour channels image
#     batch_size = batch_size,
#     image_size = (img_height, img_width),
#     shuffle = True, 
#     seed = seed
#     # validation_split = 0.2, # Spliting dataset into 80% training, 20% validation set
#     # subset = "both", 
# )

### Sixty Forty Split 

In [None]:
# train_dir = dataset_dir + '(6040)//train//'
# train_ds = tf.keras.preprocessing.image_dataset_from_directory(
#     directory = train_dir,
#     # labels = "inferred"
#     label_mode='categorical', #(sparse -> image can only belong to one group)
#     class_names = target_labels, 
#     # colour_mode = "rgb" -> three colour channels image
#     batch_size = batch_size,
#     image_size = (img_height, img_width),
#     shuffle = True, 
#     seed = seed
# )

In [None]:
# val_dir = dataset_dir + '(6040)//val//'
# val_ds = tf.keras.preprocessing.image_dataset_from_directory(
#     directory = val_dir,
#     # labels = "inferred"
#     label_mode='categorical', #(sparse -> image can only belong to one group)
#     class_names = target_labels, 
#     # colour_mode = "rgb" -> three colour channels image
#     batch_size = batch_size,
#     image_size = (img_height, img_width),
#     shuffle = True, 
#     seed = seed
# )

In [None]:
# Load the labelled image folders once and let Keras carve out the hold-out
# part: 60% of the files go to training, 40% to validation.
both_ds = tf.keras.preprocessing.image_dataset_from_directory(
    directory=dataset_dir,
    # labels are inferred from the sub-folder names (the default)
    label_mode='categorical',           # labels are returned one-hot encoded
    class_names=target_labels,          # fixes the class order to that of target_labels
    # color_mode is left at its default
    batch_size=batch_size,
    image_size=(img_height, img_width),
    shuffle=True,
    seed=seed,
    validation_split=0.4,               # fraction of the data reserved for validation
    subset="both",                      # return the training and validation datasets together
)

# With subset="both" the call returns the training dataset first and the
# validation dataset second.
train_ds, val_ds = both_ds

In [None]:
# # ==== George ==== #
# train_temp_labels = np.empty([0,2])

# for batch,label in train_ds:
#         train_temp_labels = np.concatenate([train_temp_labels, label], axis=0)

# # print(train_temp_labels)
# print(len(train_temp_labels))
# print(np.sum(train_temp_labels,axis=0))

# val_temp_labels = np.empty([0,2])

# for batch, label in val_ds:
#         val_temp_labels = np.concatenate([val_temp_labels, label], axis=0)

# # print(val_temp_labels)
# print(len(val_temp_labels))
# print(np.sum(val_temp_labels,axis=0))

## LOAD MODELS

### Load fresh ENB3 Function (SWOIP)

In [None]:
# ENB3_model = keras.applications.EfficientNetB3(
#     include_top = True # whether to include the fully-connected layer at the top of the network 
#     # weights = "imagenet" / "none" -> never seen imageNet 
#     # input_tensor = None
#     # input_shape = None 
#     # pooling = None
#     # classes = 1000
#     # classifier_activation = "softmax"
# )
# # Freezing 
# for layer in ENB3_model.layers[:385]: #frozen to layer 125, unfreezed the rest 
#     layer.trainable = True # false: freezing, true: unfreeze 
# # How to look into the structure of the model 
# # loaded_model.layers[385]
# # summary = loaded_model.summary()
# # print("model summary")

### Load four Pretrained ENB3 Function (AF OR MI model)

In [None]:
# Load the four pretrained networks that make up the ensemble. Each file is
# named after a layer index (<index>.h5) inside pretrained_net_dir.
model_names = [125, 198, 272, 385]
pretrained_ensemble_models = []
for name in model_names: 
    model = tf.keras.models.load_model(
            filepath = pretrained_net_dir + "//" + str(name) + ".h5"
            # custom_objects=None, 
            # compile=True, 
            # safe_mode=True
            )
    # Mark every layer below index `name` as trainable. As written nothing is
    # frozen; set the flag to False here to freeze these layers instead.
    for layer in model.layers[:name]:
        layer.trainable = True # false: freezing, true: unfreeze 
    pretrained_ensemble_models.append(model)

## TRAINING

### Adding classifier layer function

In [None]:
# USE THIS ONE script 
def add_classification_layers(cutModel): #l1_reg, l2_reg, dropout_rate_1, dropout_rate_2, nClasses
    """Attach a fresh classification head to a truncated backbone.

    The head is: global average pooling, then two regularised ReLU dense
    layers (128 and 64 units), each followed by dropout, then a softmax layer
    with ``nClasses`` outputs (``nClasses`` is read from module scope).

    Args:
        cutModel: Keras model whose ``output`` feeds the new head and whose
            ``input`` becomes the input of the returned model.

    Returns:
        A ``tf.keras.Model`` mapping ``cutModel.input`` to class probabilities.
    """
    # (units, L1/L2 strength, dropout rate) for each hidden dense stage, in order
    hidden_stages = (
        (128, 0.001, 0.4),
        (64, 0.0001, 0.5),
    )

    head = tf.keras.layers.GlobalAveragePooling2D()(cutModel.output)
    for units, penalty, drop_rate in hidden_stages:
        head = tf.keras.layers.Dense(
            units,
            activation="relu",
            kernel_regularizer=regularizers.L1L2(l1=penalty, l2=penalty),
        )(head)
        head = tf.keras.layers.Dropout(drop_rate)(head)

    probabilities = tf.keras.layers.Dense(nClasses, activation="softmax")(head)
    return tf.keras.Model(inputs=cutModel.input, outputs=probabilities)

### Training function 

In [None]:
# USE THIS ONE script 
def training_function(full_model, filepath, train_ds, val_ds): 
    """Compile and fit ``full_model``, checkpointing the best epoch to ``filepath``.

    Training uses Adam (initial learning rate 0.001), binary cross-entropy,
    learning-rate reduction on plateau, early stopping with best-weight
    restoration, and inverse-frequency class weights derived from ``train_ds``.
    The number of epochs is capped by the module-level ``max_epochs``.

    Args:
        full_model: Keras model to compile and train (modified in place).
        filepath: Destination passed to ``ModelCheckpoint`` for the best model.
        train_ds: Iterable of ``(inputs, labels)`` batches. The labels are
            treated as a 2-D array with one column per class (one-hot).
        val_ds: Validation data passed to ``fit``; its loss drives all callbacks.

    Returns:
        The ``History`` object returned by ``full_model.fit``.
    """
    # Define optimizer with an initial learning rate 
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) # This is my initial learning rate 

    # Define ReduceLROnPlateau callback (based on validation loss)
    reduce_lr_patience = 10
    lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2, # Reduce LR by a factor of 5 (1/5 = 0.2)
        patience=reduce_lr_patience,
        min_lr=0.00001, # Don't let the LR get ridiculously small
        verbose=1
    )

    # Early stopping. Taken from tf.keras like the other callbacks so that all
    # callbacks and the model come from the same Keras implementation.
    early_stopping_patience = 25 #50
    early_stopper = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=early_stopping_patience,
        verbose=1,
        restore_best_weights=True # Automatically restore the model from the best epoch
    )

    saving_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=filepath,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
        initial_value_threshold=None
    )

    full_model.compile(optimizer=optimizer,
                       loss=tf.keras.losses.BinaryCrossentropy(), #CategoricalFocalCrossentropy
                       metrics=['binary_accuracy'])

    # Calculate class weights: weight_i = n_samples / (n_classes * count_i),
    # so that rarer classes contribute more to the loss.
    labels_train = np.concatenate([np.array(label) for _, label in train_ds])
    n_samples, n_label_columns = labels_train.shape
    class_counts = labels_train.sum(axis=0)
    class_weights = {
        # A class with no training samples gets weight 0.0 rather than a
        # division by zero; it never occurs in a batch, so the value is unused.
        i: float(n_samples / (n_label_columns * count)) if count > 0 else 0.0
        for i, count in enumerate(class_counts)
    }
    # print(class_weights)

    # Training. batch_size is deliberately not passed: train_ds already yields
    # batches, and Keras documents that batch_size must not be specified for
    # dataset inputs.
    results = full_model.fit(train_ds,
                             validation_data=val_ds,
                             epochs=max_epochs,
                             callbacks=[early_stopper, lr_callback, saving_callback],
                             class_weight=class_weights
                             )
    return results 

### Loss Function

In [None]:
def graph_loss_function(results, filepath): 
    """Plot training vs. validation loss for one run and save it as a PNG.

    Prints the minimum validation loss and the (0-based) index of the epoch
    where it occurred, then writes the figure to ``filepath + 'loss graph.png'``
    and displays it.

    Args:
        results: Object with a ``history`` dict containing ``'loss'`` and
            ``'val_loss'`` sequences (as returned by ``Model.fit``).
        filepath: Path prefix for the saved image.
    """
    # Extract validation loss for easier access
    val_loss = results.history['val_loss']

    ### 1. Find the Best Epoch and Minimum Validation Loss
    # Use np.argmin to find the INDEX of the minimum loss
    best_epoch_idx = np.argmin(val_loss)
    min_val_loss = val_loss[best_epoch_idx]

    print(f"Minimum validation loss of {min_val_loss:.4f} was found at epoch {best_epoch_idx}.")

    # ### 2. Calculate a Suggested Patience ###
    # # Define how much worse the loss can get before we consider it overfitting
    # tolerance_percentage = 0.05 # e.g., 5% worse than the minimum
    # overfit_threshold = min_val_loss * (1 + tolerance_percentage)

    # # Find the first epoch *after* the best epoch where loss exceeds the threshold
    # patience_epoch = -1
    # for i in range(best_epoch_idx + 1, len(val_loss)):
    #     if val_loss[i] > overfit_threshold:
    #         patience_epoch = i
    #         break # Stop as soon as we find it

    # # The suggested patience is the number of epochs between the best and the overfit point
    # suggested_patience = -1
    # if patience_epoch != -1:
    #     suggested_patience = patience_epoch - best_epoch_idx

    ### 3. Plot the Results and Annotations ###
    plt.rcParams.update({'font.family':'sans-serif', 'font.sans-serif':['Arial']})
    fig = plt.figure(figsize=(10, 6))
    plt.plot(results.history['loss'], label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')

    # # Annotate the minimum validation loss
    # plt.axvline(x = best_epoch_idx, color = 'r', linestyle='--', label = f'Best Epoch: {best_epoch_idx}')
    # plt.scatter(best_epoch_idx, min_val_loss, color = 'red', zorder = 5) # Mark the best point

    # # Annotate the suggested patience point, if found
    # if suggested_patience > 0:
    #     plt.axvline(x=patience_epoch, color='g', linestyle='--', label=f'Recommended patience: {suggested_patience}')
    #     plt.scatter(patience_epoch, val_loss[patience_epoch], color='green', zorder=5)
    #     print(f"A good starting patience value could be around {suggested_patience}.")
    #     print(f"This is how many epochs it took for the validation loss to increase by more than {tolerance_percentage:.0%} from its minimum (at {patience_epoch} epoch).")
    # else:
    #     print("\nThe model did not appear to overfit within the training run (based on the defined tolerance).")

    plt.title('Training and Validation Loss Analysis')
    plt.xlabel('Epochs')
    plt.ylabel('Loss (Binary Cross Entropy)') 
    plt.legend()
    plt.grid(True)
    # Save BEFORE showing: once plt.show() has run, the current figure can be
    # released (e.g. with inline/non-interactive backends) and a later
    # plt.savefig() would write an empty image.
    fig.savefig(filepath + 'loss graph.png')
    plt.show()

In [None]:
def graph_all_losses(histories, split_layers, filepath):
    """Overlay the loss curves of every ensemble member on a single figure.

    For each model the validation loss is drawn as a solid line and the
    training loss as a dashed line of the same colour. The figure is written
    to ``filepath + '_combined_loss_graphs.png'`` and then displayed.

    Args:
        histories: Sequence of objects exposing a ``history`` dict with
            ``'loss'`` and ``'val_loss'`` entries.
        split_layers: Truncation layer index of each model, paired with
            ``histories`` by position.
        filepath: Path prefix for the saved image.
    """
    plt.figure(figsize=(10, 6))
    plt.rcParams.update({'font.family':'sans-serif', 'font.sans-serif':['Arial']})

    # One fixed colour per known truncation point; anything else is drawn grey.
    layer_colours = {125: 'red', 198: 'blue', 272: 'green', 385: 'purple'}

    for run, cut in zip(histories, split_layers):
        colour = layer_colours.get(cut, 'grey')
        label_prefix = f'truncated @{cut}'

        # Validation curve: solid
        plt.plot(run.history['val_loss'], color=colour, linestyle='-',
                 label=f'Val loss for model {label_prefix}')
        # Training curve: dashed
        plt.plot(run.history['loss'], color=colour, linestyle='--',
                 label=f'Train loss for model {label_prefix}')

    plt.title('Combined Training & Validation Loss for 4 Truncated ENB3 models')
    plt.xlabel('Epochs')
    plt.ylabel('Loss (Binary Cross Entropy)')
    plt.legend()
    # plt.grid(color='grey', linestyle='--', linewidth=0.5)
    plt.savefig(filepath + '_combined_loss_graphs.png')
    plt.show()

### Training (SWOIP/SWAFP)

#### GRID SEARCH

In [None]:
# ### ======== GRID SEARCH (SWAFP) ============= ###
# # Define the grid of hyperparameters to search
# hyperparameter_grid = {
#     'l1_reg': [0.0001, 0.001, 0.01],
#     'l2_reg': [0.0001, 0.001, 0.01],
#     'dropout_rate_1': [0.3, 0.4, 0.5],
#     'dropout_rate_2': [0.3, 0.4, 0.5]
# }

# # This is the main block for SWAFP
# if training_style == "SWAFP":
#     split_layers = [125, 198, 272, 385]

#     best_val_loss = float('inf')
#     best_hyperparameters = None

#     keys = hyperparameter_grid.keys()
#     values = hyperparameter_grid.values()

#     # The outer grid search loop
#     for combo in itertools.product(*values):
#         params = dict(zip(keys, combo))
#         print(f"--- Training with hyperparameters: {params} ---")

#         all_results = []
#         pretrained_ensemble_models_copy = list(pretrained_ensemble_models) # Create a copy to prevent overwriting

#         # This is the inner training loop for the ensemble
#         for idx, split_layer in enumerate(split_layers):
#             loaded_model = pretrained_ensemble_models_copy[idx]
#             cutModel = tf.keras.Model(loaded_model.input, loaded_model.layers[split_layer-1].output)

#             # Build the model with the current hyperparameters
#             full_model = add_classification_layers(
#                 cutModel,
#                 nClasses=nClasses,
#                 l1_reg=params['l1_reg'],
#                 l2_reg=params['l2_reg'],
#                 dropout_rate_1=params['dropout_rate_1'],
#                 dropout_rate_2=params['dropout_rate_2']
#             )

#             # Define a unique filepath for each hyperparameter combination and model
#             filepath = os.path.join(net_dir, f"hparams__{str(combo)}__model_{split_layer}")
#             modelpath = filepath + "_" + "model.keras"

#             print(f"--- Training model truncated at layer {split_layer} --- ")

#             results = training_function(full_model, modelpath, train_ds, val_ds)
#             all_results.append(results)

#         # Evaluate the performance of this hyperparameter combination
#         # For simplicity, let's take the average of all ensemble models' best validation loss
#         avg_min_val_loss = np.mean([min(res.history['val_loss']) for res in all_results])

#         # Check if this combination is the best so far
#         if avg_min_val_loss < best_val_loss:
#             best_val_loss = avg_min_val_loss
#             best_hyperparameters = params
#             print(f"New best hyperparameter combination found! Average Val Loss: {best_val_loss:.4f} with params: {best_hyperparameters}")

#     print("--- Grid search complete ---")
#     print(f"Best hyperparameters: {best_hyperparameters}")
#     print(f"Best average validation loss: {best_val_loss:.4f}")

#### TRAINING LOOP

In [None]:
# TRAINING ENSEMBLE MODELS Function 
def train_ensemble_models(base_models, split_layers, net_dir, training_style, train_ds, val_ds):
    """Train one truncated model per (base model, split layer) pair.

    Each base model is cut after layer ``split_layer - 1``, given a new
    classification head, trained with ``training_function`` and checkpointed
    under ``net_dir``. After all runs a combined loss figure is produced.

    Args:
        base_models: Models to truncate, paired by position with ``split_layers``.
        split_layers: Layer counts at which each base model is truncated.
        net_dir: Directory that receives the checkpoints and the combined graph.
        training_style: Prefix used in the output file names.
        train_ds: Training data forwarded to ``training_function``.
        val_ds: Validation data forwarded to ``training_function``.

    Returns:
        List with the training history of every model, in training order.
    """
    print("--- Starting Ensemble Training ---")
    histories = []

    # Pairs are taken positionally: the i-th model is cut at the i-th layer.
    for backbone, cut_at in zip(base_models, split_layers):
        print(f"--- Training model truncated at layer {cut_at} --- ")

        # Keep everything up to (and including) layer index cut_at - 1 ...
        truncated = tf.keras.Model(
            inputs=backbone.input,
            outputs=backbone.layers[cut_at - 1].output,
        )
        # ... and put a fresh classifier on top of it.
        classifier_model = add_classification_layers(truncated)

        checkpoint_path = os.path.join(net_dir, f"{training_style}_{cut_at}") + "_model.keras"

        started = time.time()
        histories.append(training_function(classifier_model, checkpoint_path, train_ds, val_ds))
        finished = time.time()
        print(f"--- Model training finished in {finished - started:.2f} seconds ---")

    print("--- All training complete. Generating combined graph. ---")
    graph_all_losses(histories, split_layers, os.path.join(net_dir, training_style))

    return histories

In [None]:
# if __name__ == "__main__":
# Layer counts at which each ensemble member is truncated.
split_layers = [125, 198, 272, 385]

# Determine the list of base models to use
# NOTE(review): the only visible definition of ENB3_model is in a commented-out
# cell, so the "ENB3" branch would raise NameError unless that cell is
# restored - confirm before selecting Pretrained_model = "ENB3".
# Also note that `[ENB3_model] * n` repeats the SAME model object n times.
if Pretrained_model == "ENB3":
    models_to_train = [ENB3_model] * len(split_layers)
else:
    models_to_train = pretrained_ensemble_models

# Call the training function
# Trains one model per split layer and keeps the per-model fit histories.
training_histories = train_ensemble_models(
    base_models=models_to_train,
    split_layers=split_layers,
    net_dir=net_dir,
    training_style=training_style,
    train_ds=train_ds,
    val_ds=val_ds
)


### Generate a smoothed loss curve 

In [None]:
# val_loss_series = pd.Series(val_loss)

# # Calculate a moving average over a window of 20 epochs
# smoothed_val_loss = val_loss_series.rolling(window=20, center=True).mean()

# # Plot both for comparison
# plt.figure(figsize=(10, 6))
# plt.plot(val_loss, label='Raw Validation Loss', color = 'green', alpha = 0.4)
# plt.plot(smoothed_val_loss, label='Smoothed Validation Loss (20-Epoch Avg)', color='orange', linewidth=2)

# plt.title('Smoothed Loss Curve')
# plt.xlabel('Epochs')
# plt.ylabel('Loss (Categorical Cross Entropy)') 
# plt.legend()
# plt.grid(True)
# plt.show()

## MODEL PERFORMANCE

### Write Results File Function

In [None]:
def writeResultsFile(labels_test,ensemble_prediction,ensemble_scores,nClasses, t_start,t_end,writeDir):
    """Compute ensemble metrics, save ROC plot(s) and write a results CSV.

    Binary case (``nClasses == 2``): one CSV row and one ``ROC.png``; the
    positive-class score is taken from column 1 of ``ensemble_scores``.
    Multiclass case (``nClasses > 2``): one CSV row and one ROC image per
    class, each class evaluated one-vs-rest on rounded scores; class names
    come from the module-level ``target_labels``.

    Args:
        labels_test: True labels (1-D 0/1 labels for binary; one column per
            class for multiclass).
        ensemble_prediction: Predicted class labels (used in the binary case).
        ensemble_scores: Per-class scores, one column per class.
        nClasses: Number of classes.
        t_start: Start timestamp; ``t_end - t_start`` is reported as time.
        t_end: End timestamp.
        writeDir: Existing directory that receives the CSV and ROC images.

    Returns:
        Tuple ``(labelResults, header)``: the per-class result rows (empty in
        the binary case) and the CSV header that was written.

    Raises:
        ValueError: If ``nClasses`` is smaller than 2.
    """
    # Fail early and clearly: previously this case fell through both branches
    # and crashed on unbound locals at the end of the function.
    if not nClasses >= 2:
        raise ValueError(f"nClasses must be at least 2, got {nClasses!r}")

    labelResults = []
    train_time = (t_end-t_start)

    # Binary Classifier
    if nClasses == 2:
        positive_labels = sum(labels_test)
        negative_labels = len(labels_test) - positive_labels
        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs in
        # the data, so the four-way unpacking below cannot fail.
        tn, fp, fn, tp = confusion_matrix(labels_test, ensemble_prediction, labels=[0, 1]).ravel()
        accuracy = accuracy_score(labels_test, ensemble_prediction)
        precision = precision_score(labels_test, ensemble_prediction)
        recall = recall_score(labels_test, ensemble_prediction)
        f1 = f1_score(labels_test, ensemble_prediction)
        fpr, tpr, _ = roc_curve(labels_test, ensemble_scores[:,1])
        roc_auc = roc_auc_score(labels_test, ensemble_scores[:,1])
        roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
        roc_display.figure_.savefig(writeDir + '/ROC.png')
        plt.close(roc_display.figure_) # release the figure once it is on disk

        header = ['Positive Labels', 'Negative Labels', 'TP', 'FP', 'TN', 'FN', 'Accuracy', 'Precision', 'Recall', 'F1-score', 'ROC AUC', 'Training Time']
        with open(writeDir + '/' + 'Ensemble results.csv', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerow([positive_labels, negative_labels, tp, fp, tn, fn, accuracy, precision, recall, f1, roc_auc, train_time])

    # Multiclass
    else:
        binary_truths = labels_test
        binary_preds = np.round(ensemble_scores) # using rounding to calculate confusion matrix
        label_names = target_labels
        header = ['Label', 'Positive Labels', 'Negative Labels', 'TP', 'FP', 'TN', 'FN', 'Accuracy', 'Precision', 'Recall', 'F1-score', 'ROC AUC', 'Training Time']

        for labelID in range(nClasses):
            name = label_names[labelID]
            # One-vs-rest view of this class
            y_truth = binary_truths[:,labelID]
            y_pred = binary_preds[:,labelID]
            positive_labels = sum(y_truth)
            negative_labels = len(y_truth) - positive_labels
            tn, fp, fn, tp = confusion_matrix(y_truth, y_pred, labels=[0, 1]).ravel()
            accuracy = accuracy_score(y_truth, y_pred)
            precision = precision_score(y_truth, y_pred)
            recall = recall_score(y_truth, y_pred)
            f1 = f1_score(y_truth, y_pred)
            fpr, tpr, _ = roc_curve(y_truth, ensemble_scores[:,labelID])
            roc_auc = roc_auc_score(y_truth, ensemble_scores[:,labelID])
            roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
            roc_display.figure_.savefig(writeDir + '/ROC ' + name + '.png')
            plt.close(roc_display.figure_) # avoid accumulating one open figure per class
            labelResults.append([name, positive_labels, negative_labels, tp, fp, tn, fn, accuracy, precision, recall, f1, roc_auc, train_time])

        # NOTE(review): this file name starts with a space, unlike the binary
        # branch ('Ensemble results.csv'). Kept unchanged in case downstream
        # code relies on it - confirm whether it is intentional.
        with open(writeDir + '/' + ' Ensemble results.csv', 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(labelResults)

    # for idx, matrix in enumerate(matrices):
    #     disp = ConfusionMatrixDisplay(matrix)
    #     disp.plot()
    #     plt.show()
    #     plt.savefig([writeDir + '\\' + label_code[idx] + ' Confusion Matrix.png'])
    #     plt.close()

    # No explicit f.close(): the `with` blocks above already closed the file.
    return labelResults, header

### Write Ensemble Results File Function 

In [None]:
# --- Calculate Metrics ---
def _calculate_metrics(labels_test, prediction, scores, nClasses, t_start, t_end, model_name=""):
    """
    Calculates and returns a dictionary of performance metrics.
    Does NOT write any files or plot anything.

    Args:
        labels_test (np.array): True labels.
        prediction (np.array): Predicted labels.
        scores (np.array): Prediction scores (probabilities), one column per
            class; column 1 is used as the positive-class score when binary.
        nClasses (int): Number of classes.
        t_start (float): Start timestamp supplied by the caller.
        t_end (float): End timestamp supplied by the caller.
        model_name (str, optional): Name of the model for identification in results. Defaults to "".

    Returns:
        dict: Calculated metrics. Always contains 'Model Name'; the remaining
        keys depend on whether ``nClasses`` is 2 or greater than 2. For
        ``nClasses`` below 2 only 'Model Name' is returned. 'ROC AUC' is NaN
        when it cannot be computed (e.g. only one class in ``labels_test``).
    """
    # Named `results` (not `metrics`) so the sklearn.metrics module imported
    # at file level under the name `metrics` is not shadowed in here.
    results = {'Model Name': model_name}
    # Reported under 'Training Time'; it is simply t_end - t_start.
    train_time = (t_end - t_start)

    if nClasses == 2:
        positive_labels = sum(labels_test)
        negative_labels = len(labels_test) - positive_labels
        # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent,
        # so the four-way unpacking cannot fail.
        tn, fp, fn, tp = confusion_matrix(labels_test, prediction, labels=[0, 1]).ravel()
        accuracy = accuracy_score(labels_test, prediction)
        # zero_division=0 matches the multiclass branch: undefined scores are
        # reported as 0 without emitting a warning.
        precision = precision_score(labels_test, prediction, zero_division=0)
        recall = recall_score(labels_test, prediction, zero_division=0)
        f1 = f1_score(labels_test, prediction, zero_division=0)

        # Same policy as the multiclass branch: an undefined AUC becomes NaN
        # instead of aborting the whole evaluation.
        try:
            roc_auc = roc_auc_score(labels_test, scores[:, 1])
        except ValueError:
            roc_auc = np.nan

        results.update({
            'Positive Labels': positive_labels,
            'Negative Labels': negative_labels,
            'TP': tp, 'FP': fp, 'TN': tn, 'FN': fn,
            'Accuracy': accuracy,
            'Precision': precision,
            'Recall': recall,
            'F1-score': f1,
            'ROC AUC': roc_auc,
            'Training Time': train_time
        })
    elif nClasses > 2:
        # For combined table, let's use overall averages
        accuracy = accuracy_score(labels_test, prediction)
        precision = precision_score(labels_test, prediction, average='macro', zero_division=0)
        recall = recall_score(labels_test, prediction, average='macro', zero_division=0)
        f1 = f1_score(labels_test, prediction, average='macro', zero_division=0)

        try:
            roc_auc = roc_auc_score(labels_test, scores, multi_class='ovr', average='macro')
        except ValueError:
            roc_auc = np.nan # Or some other indicator

        results.update({
            'Accuracy': accuracy,
            'Precision (Macro)': precision,
            'Recall (Macro)': recall,
            'F1-score (Macro)': f1,
            'ROC AUC (Macro OVR)': roc_auc,
            'Training Time': train_time
        })

    return results

# --- individual model performance: only saves ROC plots ---
def _save_individual_model_roc_plot(labels_test, scores, nClasses, writeDir, filename_identifier="", target_labels=None):
    """
    Render ROC curve image(s) for one model and store them in ``writeDir``.
    No CSV output is produced.

    Binary: a single ``<identifier>ROC.png`` built from column 1 of ``scores``.
    Multiclass: one ``<identifier>ROC_<class>.png`` per class, using the
    matching columns of ``labels_test`` and ``scores``.

    Args:
        labels_test (np.array): True labels.
        scores (np.array): Prediction scores (probabilities).
        nClasses (int): Number of classes.
        writeDir (str): Output directory (created if missing).
        filename_identifier (str, optional): Prefix for the file names.
                                            Defaults to "".
        target_labels (list, optional): Class names for the multiclass file
                                        names; generic names are used if omitted.
    """
    def _plot_and_store(y_true, y_score, png_name):
        # Draw one ROC curve, write it to disk, then close the figure to free memory.
        fpr, tpr, _ = roc_curve(y_true, y_score)
        display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
        display.figure_.savefig(os.path.join(writeDir, png_name))
        plt.close(display.figure_)

    plt.rcParams['font.family'] = 'Arial'

    # The output folder must exist before any figure is written.
    os.makedirs(writeDir, exist_ok=True)

    if nClasses == 2:
        # Binary: column 1 holds the positive-class score.
        _plot_and_store(labels_test, scores[:, 1], f"{filename_identifier}ROC.png")
    elif nClasses > 2:
        # Multiclass: fall back to generic names when none were supplied.
        if target_labels is None:
            print("Warning: target_labels not provided for multiclass ROC plotting. Using default names.")
            target_labels = [f"Class_{i}" for i in range(nClasses)]

        for labelID in range(nClasses):
            name = target_labels[labelID]
            _plot_and_store(labels_test[:, labelID], scores[:, labelID],
                            f"{filename_identifier}ROC_{name}.png")

    print(f"ROC plots saved with identifier '{filename_identifier}' in '{writeDir}'")

# --- Write Combined Ensemble Results to a Single CSV ---
def write_combined_ensemble_results_csv(all_metrics_data, writeDir, nClasses):
    """
    Collect the metric dictionaries of several models (individual members and
    the ensemble) in one CSV file named ``Combined_Ensemble_Results.csv``.

    Args:
        all_metrics_data (list): Metric dictionaries, one per model/ensemble;
                                 each becomes one CSV row.
        writeDir (str): Output directory (created if missing).
        nClasses (int): Number of classes; selects the binary or the
                        multiclass column set. Values below 2 only print a
                        warning and write nothing.
    """
    os.makedirs(writeDir, exist_ok=True) # Ensure the directory exists
    combined_csv_path = os.path.join(writeDir, "Combined_Ensemble_Results.csv")

    # Guard clause: there is no column layout for fewer than two classes.
    if not (nClasses == 2 or nClasses > 2):
        print(f"Warning: nClasses={nClasses} is not supported for combined CSV header.")
        return

    if nClasses == 2:
        columns = ['Model Name', 'Positive Labels', 'Negative Labels', 'TP', 'FP', 'TN', 'FN',
                   'Accuracy', 'Precision', 'Recall', 'F1-score', 'ROC AUC', 'Training Time']
    else:
        # Column set produced for the multiclass (macro-averaged) metrics
        columns = ['Model Name', 'Accuracy', 'Precision (Macro)', 'Recall (Macro)',
                   'F1-score (Macro)', 'ROC AUC (Macro OVR)', 'Training Time']

    with open(combined_csv_path, 'w', newline='') as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        table.writerows(all_metrics_data)

    print(f"Combined ensemble results saved to: {combined_csv_path}")

In [None]:
# Make sure the run's output directory is present before any evaluation
# artefacts are written; exist_ok=True makes this a no-op if it already exists.
os.makedirs(net_dir, exist_ok=True)
print(f"Ensured base directory exists: {net_dir}")


### Generate validation results - Sixty/forty Split 

#### ENSEMBLE

In [None]:
def evaluate_ensemble(models_to_evaluate, val_ds, split_layers, nClasses, net_dir, target_labels=None):
    """
    Evaluate each ensemble member and their averaged prediction on ``val_ds``.

    For every model an ROC plot is saved and a metrics dictionary is computed;
    the same is done for the mean of all model scores. All metric dictionaries
    are then written to one CSV in ``net_dir``.

    Labels and model scores are gathered in the SAME pass over ``val_ds``.
    Iterating the dataset once for labels and again for predictions pairs
    scores with the wrong labels whenever the dataset reshuffles between
    iterations (the tf.data default after ``shuffle``).

    Args:
        models_to_evaluate (list): Callable models, invoked as
                                   ``model(batch, training=False)``.
        val_ds: Iterable yielding ``(batch_data, batch_labels)`` pairs. Labels
                are reduced with ``argmax(axis=1)``, so they are expected to be
                2-D (e.g. one-hot encoded).
        split_layers (list): One identifier per model, used only to build the
                             model names and ROC plot file prefixes.
        nClasses (int): Number of classes, forwarded to the metric/plot helpers.
        net_dir (str): Output directory for ROC plots and the combined CSV.
        target_labels: Optional value forwarded unchanged to the ROC plot helper.

    Returns:
        tuple: ``(average_ensemble_score, binary_true_labels)`` - the mean of
        the per-model score arrays and the argmax-decoded true labels, both in
        the order the samples were yielded by ``val_ds``.

    Raises:
        ValueError: If no models are supplied or ``val_ds`` yields no batches.
    """
    print("--- Starting Ensemble Evaluation ---")
    all_ensemble_metrics = []

    models_to_evaluate = list(models_to_evaluate)
    if not models_to_evaluate:
        raise ValueError("evaluate_ensemble needs at least one model to evaluate.")

    # 1 + 2. Collect true labels and every model's scores in one pass so that
    # row i of each score array always belongs to label i.
    print("Gathering predictions for all models...")
    label_batches = []
    all_model_scores_by_batch = [[] for _ in models_to_evaluate]
    for batch_data, batch_labels in val_ds:
        label_batches.append(np.asarray(batch_labels))
        for idx, model in enumerate(models_to_evaluate):
            batch_score = model(batch_data, training=False)
            all_model_scores_by_batch[idx].append(np.array(batch_score))

    if not label_batches:
        raise ValueError("val_ds yielded no batches; nothing to evaluate.")

    true_labels = np.concatenate(label_batches, axis=0)
    binary_true_labels = np.argmax(true_labels, axis=1)
    final_model_scores = [np.concatenate(scores) for scores in all_model_scores_by_batch]
    print("All predictions gathered.")

    # 3. Evaluate each individual model
    print("\n--- Evaluating Individual Models ---")
    for idx, model_score in enumerate(final_model_scores):
        model_name = f"Ensemble Model {split_layers[idx]}"

        # NOTE(review): this interval only covers the argmax below, not the
        # forward pass above - confirm that is what the reported time should be.
        t_start_inference = time.time()
        model_prediction = np.argmax(model_score, axis=1)
        t_end_inference = time.time()

        _save_individual_model_roc_plot(binary_true_labels, model_score, nClasses, net_dir, f"Model_{split_layers[idx]}_", target_labels)
        print(f"'{model_name}' ROC plot saved.")

        model_metrics = _calculate_metrics(binary_true_labels, model_prediction, model_score, nClasses, t_start_inference, t_end_inference, model_name)
        all_ensemble_metrics.append(model_metrics)

    # 4. Evaluate the combined ensemble average (mean over the model axis)
    print("\n--- Evaluating Ensemble Average ---")
    t_start_ensemble = time.time()
    average_ensemble_score = np.mean(np.array(final_model_scores), axis=0)
    ensemble_prediction = np.argmax(average_ensemble_score, axis=1)
    t_end_ensemble = time.time()

    _save_individual_model_roc_plot(binary_true_labels, average_ensemble_score, nClasses, net_dir, "Ensemble_Average_", target_labels)
    print("'Ensemble Average' ROC plot saved.")

    ensemble_metrics = _calculate_metrics(binary_true_labels, ensemble_prediction, average_ensemble_score, nClasses, t_start_ensemble, t_end_ensemble, "Ensemble Average")
    all_ensemble_metrics.append(ensemble_metrics)

    # 5. Write all metrics to a single CSV file
    write_combined_ensemble_results_csv(all_ensemble_metrics, net_dir, nClasses)
    print(f"\nCombined ensemble results saved successfully to '{net_dir}'.")

    return average_ensemble_score, binary_true_labels

In [None]:
# Identifiers of the ensemble members; each value is part of the saved model's
# file name, so this list has to match the names the models were saved under.
split_layers = [125, 198, 272, 385] 
trained_models = []

print("--- Loading all trained ensemble models from disk ---")

# Loop through the split_layers to load each corresponding model
for layer in split_layers:
    # Re-create the exact filepath used during training:
    # <net_dir>/<training_style>_<layer>_model.keras
    filepath = os.path.join(net_dir, f"{training_style}_{layer}")
    modelpath = f"{filepath}_model.keras"
    
    print(f"Loading model: {modelpath}")
    
    # Load the model and add it to our list (order follows split_layers,
    # which evaluate_ensemble relies on to name each model)
    model = tf.keras.models.load_model(modelpath)
    trained_models.append(model)

print("--- All models loaded successfully. ---")

# Call the evaluation function with the list of loaded models.
# The returned scores/labels are reused by the histogram cells below.
# target_labels is optional: it is forwarded only if an earlier cell defined
# it (at cell top level locals() is the notebook's global namespace).
average_ensemble_score, binary_true_labels = evaluate_ensemble(
    models_to_evaluate=trained_models,
    val_ds=val_ds,
    split_layers=split_layers,
    nClasses=nClasses,
    net_dir=net_dir,
    target_labels=target_labels if 'target_labels' in locals() else None
)

#### Probability Score Histogram

In [None]:
# One figure, one axis: both class histograms are overlaid on it
fig, axs = plt.subplots(figsize=(10, 6), tight_layout=True)
plt.rcParams.update({'font.family':'sans-serif', 'font.sans-serif':['Arial']})

# Boolean masks over the true labels select the rows of each class;
# column 1 of the averaged scores is the score assigned to class index 1.
neg_mask = (binary_true_labels == 0)
pos_mask = (binary_true_labels == 1)
negative_scores = average_ensemble_score[neg_mask, 1]
positive_scores = average_ensemble_score[pos_mask, 1]

# Overlay: truly-negative samples first, truly-positive samples on top
axs.hist(negative_scores, bins=20, color='skyblue', range=(0, 1), edgecolor='black', alpha=0.7, label='Negative Labels')
axs.hist(positive_scores, bins=20, color='salmon', range=(0, 1), edgecolor='black', alpha=0.7, label='Positive Labels')

# Optional cleaner look (disabled): hide the top/right spines
# for s in ['top', 'right']:
#     axs.spines[s].set_visible(False)

# Grid: x ticks every 0.1, y ticks every 2 counts up to the tallest bar
axs.set_xticks(np.arange(0, 1.1, 0.1))
_, y_max = axs.get_ylim()
axs.set_yticks(np.arange(0, y_max + 2, 2))
axs.grid(True, linestyle='--', alpha=0.6)

# Decision threshold marker, axis labels, title and legend
axs.axvline(0.5, color='red', linestyle='--', linewidth=2, label='0.5 Threshold')
axs.set_xlabel('Predicted Probability Score')
axs.set_ylabel('Frequency')
axs.set_title('Distribution of Model Confidence')
axs.legend()

# Save before show(): the figure may be cleared once it has been displayed
save_path_pos = os.path.join(net_dir, 'histogram_all_labels_combined.png')
plt.savefig(save_path_pos, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# --- 1. Split the ensemble scores by TRUE label ---
# Truly NEGATIVE samples cover True Negatives and False Positives;
# truly POSITIVE samples cover True Positives and False Negatives.
neg_true_mask = (binary_true_labels == 0)
pos_true_mask = (binary_true_labels == 1)

# In both cases keep column 1: the predicted probability of being POSITIVE
scores_for_neg_labels = average_ensemble_score[neg_true_mask, 1]
scores_for_pos_labels = average_ensemble_score[pos_true_mask, 1]

save_path_neg = os.path.join(net_dir, 'histogram_negative_labels.png')
save_path_pos = os.path.join(net_dir, 'histogram_positive_labels.png')

# --- 2. One standalone figure per true class, drawn with identical styling ---
histogram_specs = [
    (scores_for_neg_labels, 'skyblue', 'Predictions for Negative Labels (TN + FP)', save_path_neg),
    (scores_for_pos_labels, 'salmon', 'Predictions for Positive Labels (TP + FN)', save_path_pos),
]

for class_scores, bar_color, plot_title, out_path in histogram_specs:
    plt.figure(figsize=(10, 6))
    plt.hist(class_scores, bins=20, range=(0, 1), color=bar_color, edgecolor='black')
    plt.axvline(0.5, color='red', linestyle='--', linewidth=2, label='0.5 Threshold')

    # x ticks every 0.1; y ticks every 2 counts, sized to the tallest bar
    plt.xticks(np.arange(0, 1.1, 0.1))
    _, y_max = plt.ylim()
    plt.yticks(np.arange(0, y_max + 2, 2))

    # Dashed grid on both axes
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.title(plot_title, fontweight='bold')
    plt.xlabel('Predicted Probability of Being Positive')
    plt.ylabel('Frequency')
    plt.legend()

    # Write the image to disk, then display it
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()

print(f"Gridded plots saved successfully to '{net_dir}'")

#### GET POSITIVE PREDICTIONS

In [None]:
# pos_val_dir = val_dir + "//2R"
# neg_val_dir = val_dir + "//0-1R"

# neg_filepath = os.walk(neg_val_dir)
# pos_filepath = os.walk(pos_val_dir)

# pos_labels = []
# neg_labels = []

# for path, _, file_names in os.walk(neg_val_dir): 
#     neg_file_names = file_names
#     for file in file_names:
#         ecg = Image.open(path+'//'+file)
#         ecg = np.expand_dims(ecg,axis=0)
#         all_predictions = []
#         for model in trained_models:
#             print(f"  - Predicting with model...")
#             preds = model.predict(ecg)
#             all_predictions.append(preds)
#         neg_labels.append(all_predictions)
# neg_labels = np.array(neg_labels)

# for file in pos_filepath: 
#     ecg = Image.open(file)
#     ecg = np.expand_dims(ecg,axis=0)
#     all_predictions = []
#     for model in trained_models:
#         print(f"  - Predicting with model...")
#         preds = model.predict(ecg)
#         all_predictions.append(preds)
#     pos_labels.append(all_predictions)
# pos_labels = np.array(pos_labels)


# # == TO FIND TRUE POSITIVES ==== 
# # Extract the true labels from the dataset so we can calculate true positives
# print(" - Extracting true labels from the dataset...")
# true_labels = []
# file_paths = []
# for batches, paths in unshuffled_val_ds:
#         true_labels.append(batches[1])
#         file_paths.append(paths)

# true_labels_categorical = np.concatenate([y for x, y in unshuffled_val_ds], axis=0) # This is in a shape of (batch_size, classes) 
# true_label_indices = np.argmax(true_labels_categorical, axis=1) # The same as ensemble indices 

# all_predictions = []
# print("\n--- Getting predictions from each loaded model ---")

# # 1. Loop through loaded models to get their individual predictions
# for model in trained_models:
#     print(f"  - Predicting with model...")
#     preds = model.predict(unshuffled_val_ds)
#     all_predictions.append(preds)

# print("\n--- Averaging predictions for the final ensemble result ---")
# # 2. Average the predictions to get the final ensemble decision
# ensemble_predictions = np.mean(np.array(all_predictions), axis=0) # The prediction has a shape of (models, images, classes), axis=0 is the models 
# ensemble_predicted_indices = np.argmax(ensemble_predictions, axis=1) # The input is a 2D array with a shape of (images, class_probabilities)

# # 4. Define the positive class index 
# print("\n--- Identifying positive class ---")

# positive_class_name = None
# if Dataset == "Binary-1":
#     positive_class_name = '2R'
# elif Dataset == "Binary-2":
#     positive_class_name = '1-2R'
# else: # This handles the case where the Dataset variable is something unexpected
#     print(f"FATAL: Unrecognised dataset name '{Dataset}'. Please check the variable. Exiting.")
#     exit()
# try: # This safely finds the index of the class name
#     positive_class_index = class_names.index(positive_class_name)
#     print(f"Successfully identified positive class '{positive_class_name}' at index {positive_class_index}.")
# except ValueError:
#     print(f"FATAL: The positive class '{positive_class_name}' was not found in the dataset's classes: {class_names}. Exiting.")
#     exit()

# # ======================== DEBUGGING SANITY CHECK ========================
# print("\n--- Sanity Check: Comparing first 15 predictions to true labels ---")
# for i in range(110):
#     # Get the data for the i-th image
#     filename = os.path.basename(image_filenames[i])
#     predicted_class_index = ensemble_predicted_indices[i]
#     true_class_index = true_label_indices[i]
    
#     # Convert indices to readable class names
#     predicted_class_name = class_names[predicted_class_index]
#     true_class_name = class_names[true_class_index]
    
#     # Check if this specific instance is a True Positive
#     is_tp_text = ""
#     if predicted_class_index == positive_class_index and true_class_index == positive_class_index:
#         is_tp_text = "TRUE POSITIVE"

#     print(f"File: {filename:<20} | Predicted: '{predicted_class_name}' | Actual: '{true_class_name}' {is_tp_text}")
# print("--- End of Sanity Check ---\n")
# # ========================================================================

# # 5. Identify which predictions were Positive and which were True Positive
# positive_images = []
# true_positive_images = []

# for i, filename in enumerate(image_filenames):
#     predicted_index = ensemble_predicted_indices[i]
#     true_index = true_label_indices[i]
#     # Check for any image the model PREDICTED as positive
#     if predicted_index == positive_class_index:
#         positive_images.append(os.path.basename(filename))
#     # Check for images that were PREDICTED positive AND ARE ACTUALLY positive
#     if predicted_index == positive_class_index and true_index == positive_class_index:
#         true_positive_images.append(os.path.basename(filename))

# # 6. Report final results and save the True Positive list to a file
# print("\n--- Final Results ---")

# # Get the final counts
# num_predicted_positive = len(positive_images)
# num_true_positive = len(true_positive_images)

# print(f"Total Subjects Classified as Positive: {num_predicted_positive}")
# print(f"Total Subjects that are True Positives: {num_true_positive}")

# # Check if any true positives were found and list them
# if true_positive_images:
#     print("\nThe following subjects were correctly identified as 'positive' (True Positives):")
#     # Sort the list for clean, repeatable output
#     for fname in sorted(true_positive_images):
#         print(f"- {fname}")
#     # Save the TRUE POSITIVES list to an Excel file
#     try:
#         df = pd.DataFrame(sorted(true_positive_images), columns=['True_Positive_Filename'])
#         output_excel_path = os.path.join(net_dir, 'TRUE_POSITIVES_results.xlsx')
#         df.to_excel(output_excel_path, index=False)
#         print(f"\nTrue Positive list successfully saved to: {output_excel_path}")
#     except Exception as e:
#         print(f"\nCould not save Excel file. Error: {e}")
# else:
#     print("\n No subjects were correctly identified as 'positive' (No True Positives).")

#### Original ensemble results cell

In [None]:
# # ORIGINAL SCRIPT:
# # ==== Ensemble ImageNet Pretrained & AFIB/MI Pretrained =====
# if Pretrained_model == "ENB3": 
#     split_layers = [125, 198, 272, 385]
#     results = []
    
#     for split_layer in split_layers:
#         cutModel = tf.keras.Model(ENB3_model.input, ENB3_model.layers[split_layer-1].output)
#         EnsembleENB3_model = add_classification_layers(cutModel)

#         filepath = os.path.join(net_dir, f"{training_style}_{split_layer}")
#         modelpath = filepath + "_" + "model.keras"
        
#         print(f"--- Training model truncated at layer {split_layer} --- ")
#         t_start_model = time.time() 
#         results.append(training_function(EnsembleENB3_model, modelpath, train_ds, val_ds)) #softmax 
#         t_end_model = time.time() 
#         # results = training_function(EnsembleENB3_model, modelpath, train_ds, val_ds)
    
#     print("--- All training complete. Generating combined graph. ---")
#     combined_filepath = os.path.join(net_dir, training_style)
#     # graph_all_losses(results, split_layers, combined_filepath)
#     # graph_loss_function(results, filepath)

# else:
#     split_layers = [125, 198, 272, 385]
#     results = [] # This will store all the history objects
#     graph_labels = [] # This will store labels for each plot line

#     print("--- Starting Ensemble Training ---")
    
#     for idx, split_layer in enumerate(split_layers): 
#         loaded_model = pretrained_ensemble_models[idx]
#         cutModel = tf.keras.Model(loaded_model.input, loaded_model.layers[split_layer-1].output)
#         full_model = add_classification_layers(cutModel)
#         pretrained_ensemble_models[idx] = full_model
        
#         filepath = os.path.join(net_dir, f"{training_style}_{split_layer}") # filepath = net_dir + "//" + training_style + "_" + str(split_layer)
#         modelpath = filepath + "_" + "model.keras"
        
#         # Run training and append the history object to the results list
#         print(f"--- Training model truncated at layer {split_layer} --- ")
#         t_start_model = time.time() # Start time for training
#         results.append(training_function(pretrained_ensemble_models[idx], modelpath, train_ds, val_ds)) #softmax 
#         t_end_model = time.time() # End time for training
    
#     print("--- All training complete. Generating combined graph. ---")
    
#     # Call the graphing function, passing the list of split_layers for color mapping
#     combined_filepath = net_dir + "//" + training_style
#     # graph_all_losses(results, split_layers, combined_filepath)

In [None]:
# # USE THIS SCRIPT FOR ALL ENSEMBLE MODEL RESULTS #
# print("--- Starting Ensemble Evaluation ---")

# all_ensemble_metrics = [] # List to store metrics dictionary for each model and the ensemble

# # Get all true labels from the validation set 
# true_labels = np.concatenate([y for x, y in val_ds], axis=0)
# binary_true_labels = np.argmax(true_labels, axis=1)

# # Get predictions for all models in a single pass over the data.
# print("Gathering predictions for all models...")
# if Pretrained_model == "ENB3": 
#     all_model_scores_by_batch = [[] for _ in EnsembleENB3_model]
#     for batch_data, _ in val_ds:
#         for idx, model in enumerate(EnsembleENB3_model):
#             batch_score = model(batch_data, training=False) 
#             all_model_scores_by_batch[idx].append(np.array(batch_score))
# else: 
#     all_model_scores_by_batch = [[] for _ in pretrained_ensemble_models]
#     for batch_data, _ in val_ds:
#         for idx, model in enumerate(pretrained_ensemble_models):
#             batch_score = model(batch_data, training=False) 
#             all_model_scores_by_batch[idx].append(np.array(batch_score))

# # Concatenate the batch scores for each model into a single array
# final_model_scores = [np.concatenate(scores) for scores in all_model_scores_by_batch]
# print("All predictions gathered.")

# # --- 1. Evaluate each individual model ---
# print("\n--- Evaluating Individual Models ---")
# for idx, model_score in enumerate(final_model_scores):
#     split_layer = split_layers[idx]
#     model_name = f"Ensemble Model {split_layer}"

#     # Time the inference step (argmax) for this model's predictions
#     t_start_inference = time.time()
#     model_prediction = np.argmax(model_score, axis=1)
#     t_end_inference = time.time()

#     # Save individual ROC plot for this model
#     _save_individual_model_roc_plot(
#         binary_true_labels, model_score, nClasses,
#         writeDir=net_dir, filename_identifier=f"Model_{split_layer}_",
#         target_labels=target_labels if 'target_labels' in locals() else None
#     )
#     print(f"'{model_name}' ROC plot saved.")

#     # Calculate metrics for this individual model
#     model_metrics = _calculate_metrics(
#         binary_true_labels, model_prediction, model_score, nClasses,
#         t_start=t_start_inference,
#         t_end=t_end_inference,
#         model_name=model_name
#     )
#     all_ensemble_metrics.append(model_metrics)

# # --- 2. Evaluate the combined ensemble average ---
# print("\n--- Evaluating Ensemble Average ---")

# # Time the ensemble averaging and prediction step
# t_start_ensemble = time.time()
# average_ensemble_score = np.mean(np.array(final_model_scores), axis=0)
# ensemble_prediction = np.argmax(average_ensemble_score, axis=1)
# t_end_ensemble = time.time()

# # Save ROC plot for the ensemble average
# _save_individual_model_roc_plot(
#     binary_true_labels, average_ensemble_score, nClasses,
#     writeDir=net_dir, filename_identifier="Ensemble_Average_",
#     target_labels=target_labels if 'target_labels' in locals() else None
# )
# print("'Ensemble Average' ROC plot saved.")

# # Calculate metrics for the ensemble average
# ensemble_metrics = _calculate_metrics(
#     binary_true_labels, ensemble_prediction, average_ensemble_score, nClasses,
#     t_start=t_start_ensemble,
#     t_end=t_end_ensemble,
#     model_name="Ensemble Average"
# )
# all_ensemble_metrics.append(ensemble_metrics)

# # --- 3. Write all collected metrics to a single CSV file ---
# write_combined_ensemble_results_csv(all_ensemble_metrics, net_dir, nClasses)
# print(f"\n Combined ensemble results saved successfully to '{net_dir}'.")

#### SWOIP OR ENSEMBLE 

In [None]:
# # original script: 
# if training_style == "SWOIP": 
#     temp_score = []
#     temp_labels = []
#     for batch, label in val_ds:
#         batch_score = SWOIP_model(batch)
#         temp_score.append (np.array(batch_score))
#         temp_labels.append (np.array (label))
#     model_score = np.concatenate(temp_score)
#     true_labels = np.concatenate(temp_labels)
#     model_prediction = np.argmax(model_score, axis=1)
#     binary_true_labels = np.argmax (true_labels, axis=1)
    
#     # Save individual results for SWOIP
#     writeResultsFile(
#         binary_true_labels, model_prediction, model_score, nClasses, 
#         t_start = 0 , t_end = 0, writeDir=net_dir,
#     )
#     print("SWOIP individual results saved.")

# elif training_style == "Ensemble":
#     all_ensemble_metrics = [] # List to store metrics for all models and ensemble
#     ensemble_scores_list = [] # To collect scores for ensemble averaging

#     # Assume true_labels and binary_true_labels are consistent across all batches for val_ds. Initialize outside the loop if val_ds is consistently processed
#     first_batch, first_label = next(iter(val_ds)) # Get first batch to determine true_labels structure
#     true_labels_template = np.concatenate([np.array(l) for _, l in val_ds]) # Get all true labels
#     binary_true_labels_template = np.argmax(true_labels_template, axis=1)

#     for idx, split_layer in enumerate(split_layers):
#         temp_score = []
#         for batch, label in val_ds: # Re-iterate val_ds to get scores for current model
#             batch_score = pretrained_ensemble_models[idx](batch)
#             temp_score.append (np.array(batch_score))
#         model_score = np.concatenate(temp_score)

#         # model_prediction is derived from model_score for this specific model
#         model_prediction = np.argmax(model_score, axis=1)

#         # Save individual ROC plot for this model
#         _save_individual_model_roc_plot(
#             binary_true_labels_template, model_score, nClasses,
#             writeDir= net_dir, filename_identifier=f"Model_{split_layer}_",
#             target_labels=target_labels if 'target_labels' in globals() else None
#         )
#         print(f"Ensemble Model {split_layer} ROC plot saved.")

#         # Calculate metrics for this individual model and add to list
#         model_metrics = _calculate_metrics(
#             binary_true_labels_template, model_prediction, model_score, nClasses,
#             t_start = t_start_model,    # Pass the actual start time
#             t_end = t_end_model,         # Pass the actual end time
#             model_name=f"Ensemble Model {split_layer}"
#         )
#         all_ensemble_metrics.append(model_metrics)
#         ensemble_scores_list.append(model_score)

#     # Calculate ensemble average performance
#     ensemble_scores_array = np.array(ensemble_scores_list)
#     average_ensemble_score = np.mean(ensemble_scores_array, axis=0) # Corrected axis
#     ensemble_prediction = np.argmax(average_ensemble_score, axis=1)

#     # Save individual ROC plot for the ensemble average
#     _save_individual_model_roc_plot(
#         binary_true_labels_template, average_ensemble_score, nClasses,
#         writeDir= net_dir, filename_identifier="Ensemble_Average_",
#         target_labels=target_labels if 'target_labels' in globals() else None
#     )
#     print("Ensemble Average ROC plot saved.")

#     # Calculate metrics for the ensemble average and add to list
#     ensemble_metrics = _calculate_metrics(
#         binary_true_labels_template, ensemble_prediction, average_ensemble_score, nClasses,
#         t_start=t_start_model, 
#         t_end=t_end_model,
#         model_name="Ensemble Average"
#     )
#     all_ensemble_metrics.append(ensemble_metrics)

#     # Finally, write all collected metrics to a single combined CSV
#     write_combined_ensemble_results_csv(all_ensemble_metrics, net_dir, nClasses)
#     print("Combined ensemble results CSV saved.")

### Test set

In [None]:
# # results = fullmodel.evaluate (test_ds)
# # print (results)

# temp_score = []
# temp_labels = []

# for batch, label in test_ds:
#     batch_score = SWOIP_model(batch)
#     temp_score.append (np.array(batch_score))
#     temp_labels.append (np.array (label))

# model_score = np.concatenate(temp_score)
# true_labels = np.concatenate(temp_labels)

# # print (model_score)
# # print (true_labels)
# # print (model_score.shape) 

# model_prediction = np.argmax(model_score, axis=1)
# binary_true_labels = np.argmax (true_labels, axis=1)

# writeResultsFile(binary_true_labels, model_prediction, model_score, nClasses, t_start = 0 , t_end = 0, writeDir=net_dir)

# # print(model_prediction)
# # print(len(model_prediction))