### Initial packages

In [None]:
import os
import requests
import zipfile
import datetime
from collections import defaultdict
import pickle
import numpy as np
import src.manipulate_dataset as md
from src import test_metrics
import src.models as models
from src.xai import generate_gradcam
from src.train import tf_model_train
from src import ecg_plot
from tensorflow.keras.models import load_model
from tensorflow.config import list_physical_devices
# Report the GPU devices TensorFlow can see (verifies GPU use)
print('GPUs Available: ', list_physical_devices(device_type='GPU'))

### Load data
#### Example for label: PACE

In [None]:
# ---- Input parameters ----
# Source dataset: 'mimic-iv' or 'ptb-xl'. Note: after inspection, some 'ptb-xl'
# label annotations were manually altered in the original experiments.
dataset = 'mimic-iv'
dataset_relative_dir = 'data/mimic-iv/'  # use 'data/ptb-xl/' for 'ptb-xl'

# Output locations
metadata_relative_dir = 'output/metadata/'
ecg_plots_relative_dir = 'output/imgs/'

# Number of samples requested per label. This example targets PACE; other
# conditions can be requested the same way, e.g. {'wpw': 100, 'neg': 200}.
target_labels_dict = dict(pace=1000, neg=1000)
batch_size = 128

In [None]:
# Load data
# Every requested label except 'neg' is treated as a target label
target_labels_list = [name for name in target_labels_dict if name != 'neg']

# Build the batched test set via md.tf_bal_dataset using the per-label sample counts
test_set_tf = md.tf_bal_dataset(
    ds_name=dataset,
    data_input_dir=dataset_relative_dir,
    metadata_dir=metadata_relative_dir,
    batch_size=batch_size,
    n_samples_per_label=target_labels_dict,
)

In [None]:
# Visualization of an ECG sample

# Take one batch from the test set and convert it to numpy arrays (signals, labels)
sample_np_raw_data, sample_np_labels = md.tf_dataset_to_numpy(test_set_tf.take(1), data_switch=True, labels_switch=True)
print(f'Dimensions of input: {sample_np_raw_data.shape[1:]}, and labels: {sample_np_labels.shape[1:]}')

# Make sure the output directory exists before the figure is saved there
os.makedirs(ecg_plots_relative_dir, exist_ok=True)

# Title text derived from the requested target labels (e.g. 'PACE'), so the
# figure title stays correct when target_labels_dict is changed
plot_label_text = ', '.join(label.upper() for label in target_labels_list)

# Visualize with fine plotting (ecg_plot.quick_plot is a faster alternative with fewer details)
ecg_plot.fine_plot(
    signal=sample_np_raw_data[0],
    id_label=f'Fine plot - Labels: {plot_label_text} - {sample_np_labels[0]}',
    dpi=300,
    save_path=ecg_plots_relative_dir,
    show=False, save=True)  # Save locally instead of viewing inline

### Load pre-trained model
#### Example for label: PACE

In [None]:
# Download the pre-trained model archive (.zip) from Zenodo

zenodo_url = 'https://zenodo.org/records/14968732/files/ecg_xplaim.zip'
save_path = 'output/models/ecg_xplaim_PRETRAINED.zip'
extract_dir = 'output/models/ecg_xplaim_PRETRAINED/'

# open() below fails if the target directory does not exist yet
os.makedirs(os.path.dirname(save_path), exist_ok=True)

print('Downloading from Zenodo...')
# The context manager releases the streamed connection when done; the timeout
# keeps the cell from hanging indefinitely on a stalled connection.
with requests.get(zenodo_url, stream=True, timeout=60) as response:
    if response.status_code != 200:
        # Fail loudly: continuing would let the unzip step run on a missing or stale archive
        raise RuntimeError(f'Failed to download. Status code: {response.status_code}')
    with open(save_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
print(f'Downloaded: {save_path}')

In [None]:
# Unpack the downloaded archive into extract_dir
print('Extracting...')
with zipfile.ZipFile(file=save_path, mode='r') as archive:
    archive.extractall(path=extract_dir)
print(f'Extracted to: {extract_dir}')

In [None]:
# Load one pre-trained, task-specific ECG-XPLAIM model from the extracted archive
# (this example uses the PACE model)
model_ecg_xplaim = load_model(filepath='output/models/ecg_xplaim_PRETRAINED/inter_model/ecg_xplaim_PACE.keras')

In [None]:
# Alternatively, load a locally trained model
# model_ecg_xplaim = load_model('output/models/path/to/ecg_xplaim/local/model.keras')

### Make predictions - diagnostic inference
#### Example for label: PACE

In [None]:
# Input parameters for diagnostic inference
labels = ['PACE']  # label names passed to calculate_metrics
metrics_decimals = 3  # rounding applied to the reported metrics
default_metrics_to_report = ['recall', 'specificity', 'auc']

In [None]:
# Extract y_true, y_pred
# Materialize inputs and labels in a single pass over test_set_tf, then predict on
# exactly those inputs. Iterating the tf.data pipeline twice (once for labels and
# once inside predict) keeps y_true and y_pred aligned only if every pass yields
# the same order. That fails if the pipeline reshuffles per iteration; whether
# md.tf_bal_dataset does so is not visible here.
# NOTE(review): this holds the whole test set in memory at once.
test_np_raw_data, y_true = md.tf_dataset_to_numpy(test_set_tf, data_switch=True, labels_switch=True)
y_pred = model_ecg_xplaim.predict(test_np_raw_data, batch_size=batch_size)
print(f'y_true, y_pred extracted, shapes: {y_true.shape} and {y_pred.shape}')

In [None]:
# Per-label performance of ECG-XPLAIM on the test set
metrics = test_metrics.calculate_metrics(
    y_true=y_true,
    y_pred=y_pred,
    label_names=labels,
    metrics_to_report=default_metrics_to_report,
    round_decimals=metrics_decimals,
)

### Explainability (Grad-CAM)
#### Example for label: PACE

In [None]:
# Grad-CAM input parameters
target_layer = 'conv1d_20'  # layer name handed to generate_gradcam as the target layer
sample_idx = 0  # index of the sample within the visualized batch
default_color_overlay = 'Reds'
ecg_gradcam_visual_relative_dir = 'output/imgs/'

In [None]:
# Grad-CAM for a single sample: add a leading batch axis, predict, compute the heatmap
sample_x = sample_np_raw_data[sample_idx][np.newaxis, ...]
sample_y_true = sample_np_labels[sample_idx]
# Binarize the first prediction row at a 0.5 threshold
sample_y_pred = (model_ecg_xplaim.predict(sample_x)[0] > 0.5).astype(int)
gradcam_activation = generate_gradcam(
    model_ecg_xplaim,
    sample_x,
    target_layer_name=target_layer,
    class_idx=None,
)
print(f'Sample with labels >> y_true: {sample_y_true}, y_pred: {sample_y_pred}')

In [None]:
# Visualize Grad-CAM; both figures are saved to disk rather than shown inline

# 12-lead figure with the Grad-CAM overlay
ecg_plot.gradcam_plot(
    signal=sample_x[0],
    gradcam=gradcam_activation,
    id_label=f'Grad-CAM: PACE - y_true: {sample_y_true}, y_pred: {sample_y_pred}',
    color_overlay=default_color_overlay,
    dpi=300,
    save_path=ecg_gradcam_visual_relative_dir,  # Must end with '/'
    save=True,
    show=False,
)

# Single-lead figure: lead II (lead_index=1)
ecg_plot.gradcam_plot_single(
    signal=sample_x[0],
    gradcam=gradcam_activation,
    lead_index=1,
    id_label=f'Grad-CAM-single: PACE - y_true: {sample_y_true}, y_pred: {sample_y_pred}',
    color_overlay=default_color_overlay,
    dpi=300,
    save_path=ecg_gradcam_visual_relative_dir,  # Must end with '/'
    save=True,
    show=False,
)

### Model comparison
#### Example against vanilla CNN model, for label: PACE

In [None]:
# ---- Input parameters for the vanilla-CNN comparison (example label: PACE) ----

# Model identity and training data
model_name = 'vanilla_CNN_PACE'
label_dict = {'pace': 5000, 'neg': 5000}  # samples per label; adjust for other labels
train_ds_name = 'mimic-iv'  # or 'ptb-xl'
train_ds_dir = 'data/mimic-iv/'  # or 'data/ptb-xl/'
metadata_dir = 'output/metadata/'
models_output_dir = 'output/models/'

# Batching: batch_size * (train_batches + val_batches + test_batches) must not exceed
# the total number of samples (here 128 * 78 = 9984, vs. 5000 + 5000 requested)
batch_size = 128
train_batches, val_batches, test_batches = 70, 5, 3

# Kept low for a quick demo; run time grows with the number of epochs, so raise it for real training
n_epochs = 2

# Generator object for the baseline CNN (passed to tf_model_train)
model_generator = models.Simple_CNN_generator()

# Evaluation / reporting
default_metrics_to_report = ['recall', 'specificity', 'auc']
metrics_decimals = 3
pval_decimals = 4  # rounding for the model-comparison statistics
metrics_labels = ['PACE']  # example for label: PACE (can be changed accordingly)

In [None]:
# Train the baseline (vanilla) CNN on the configured dataset subset and save the model
tf_model_train(
    model_name=model_name,
    model_generator=model_generator,
    label_dict=label_dict,
    train_ds_name=train_ds_name,
    train_ds_dir=train_ds_dir,
    metadata_dir=metadata_dir,
    batch_size=batch_size,
    train_batches=train_batches,
    val_batches=val_batches,
    test_batches=test_batches,
    n_epochs=n_epochs,
    models_output_dir=models_output_dir,
)

In [None]:
# Load the locally trained vanilla CNN (the version saved after completing all epochs).
# The path is built from model_name / models_output_dir so it follows changes to
# those input parameters instead of silently loading a different model.
# NOTE(review): assumes tf_model_train saves to
# <models_output_dir>/<model_name>_package/model_<model_name>.keras
model_vanilla_cnn = load_model(
    os.path.join(models_output_dir, f'{model_name}_package', f'model_{model_name}.keras'))

In [None]:
# Load the pre-saved, separate test set (created during training on the mimic-iv subset).
# NOTE(review): assumes tf_model_train saves it as
# <models_output_dir>/<model_name>_package/test_set_<model_name>.npz
test_data_path = os.path.join(models_output_dir, f'{model_name}_package', f'test_set_{model_name}.npz')
# np.load on an .npz returns an NpzFile that keeps the file handle open; read both
# arrays inside the context manager so the handle is closed afterwards.
with np.load(test_data_path) as test_data:
    test_x = test_data['samples']
    test_y_true = test_data['labels']

In [None]:
# Predict the same held-out samples with both models (ECG-XPLAIM and the vanilla CNN)
test_y_pred_ECG_XPLAIM, test_y_pred_VANILLA_CNN = (
    model_ecg_xplaim.predict(test_x),
    model_vanilla_cnn.predict(test_x),
)
print(
    'y_pred extracted for ECG-XPLAIM and vanilla CNN, shapes: '
    f'{test_y_pred_ECG_XPLAIM.shape} and {test_y_pred_VANILLA_CNN.shape}'
)

In [None]:
# Calculate performance metrics for both models on the shared held-out test set.
# Label names come from metrics_labels, this section's own input parameter, rather
# than `labels` from the diagnostic-inference section, so editing the comparison
# parameters is reflected here.

print('Metrics for ECG-XPLAIM:')
metrics_ECG_XPLAIM = test_metrics.calculate_metrics(
    y_true=test_y_true,
    y_pred=test_y_pred_ECG_XPLAIM,
    metrics_to_report=default_metrics_to_report,
    label_names=metrics_labels,
    round_decimals=metrics_decimals)

print('\n')
print('Metrics for vanilla CNN:')
metrics_VANILLA_CNN = test_metrics.calculate_metrics(
    y_true=test_y_true,
    y_pred=test_y_pred_VANILLA_CNN,
    metrics_to_report=default_metrics_to_report,
    label_names=metrics_labels,
    round_decimals=metrics_decimals)

In [None]:
# Compare metrics

def _comparison_metric_key(stats):
    """Return the metric prefix used by a per-label comparison stats dict.

    The stats dict is expected to hold '<metric>_model1', '<metric>_model2' and
    '<metric>_diff' entries. 'auc' is preferred when 'auc_model1' is present;
    otherwise the prefix of the first key ending in '_model1' is used.

    Raises:
        KeyError: if no '<metric>_model1' key is present.
    """
    if 'auc_model1' in stats:
        return 'auc'
    for key in stats:
        if key.endswith('_model1'):
            return key[:-len('_model1')]
    raise KeyError("no '<metric>_model1' entry found in comparison stats")


def metric_comparison_print(metric_comparison, metric=None):
    """Print a per-label comparison of two models for a single metric.

    Args:
        metric_comparison: mapping of label name -> stats dict holding
            '<metric>_model1', '<metric>_model2', '<metric>_diff', 'p_value'
            and 'better_model' entries.
        metric: prefix of the metric keys (e.g. 'auc', 'recall'). When None, it
            is inferred per label from the keys ('auc' if present), so the
            function works for AUC, recall and specificity comparisons alike.

    Returns:
        None; the report is written to stdout.
    """
    for label, stats in metric_comparison.items():
        key = metric if metric is not None else _comparison_metric_key(stats)
        print(f" Label: {label}")
        print(f"  - Metric - Model 1:  {stats[key + '_model1']}")
        print(f"  - Metric - Model 2:  {stats[key + '_model2']}")
        print(f"  - Metric - Diff:     {stats[key + '_diff']}")
        print(f"  - p-value:      {stats['p_value']}")
        print(f"  - Better Model: {stats['better_model']}")

# Statistical comparison of the two models, per label:
# AUC via bootstrap; recall and specificity via McNemar's test
auc_comparison = test_metrics.compare_auc_bootstrap(
    y_true=test_y_true,
    y_pred_1=test_y_pred_ECG_XPLAIM,
    y_pred_2=test_y_pred_VANILLA_CNN,
    label_names=metrics_labels,
    round_decimals=pval_decimals,
)
recall_comparison = test_metrics.compare_recall_mcnemar(
    y_true=test_y_true,
    y_pred_1=test_y_pred_ECG_XPLAIM,
    y_pred_2=test_y_pred_VANILLA_CNN,
    label_names=metrics_labels,
    round_decimals=pval_decimals,
)
specificity_comparison = test_metrics.compare_specificity_mcnemar(
    y_true=test_y_true,
    y_pred_1=test_y_pred_ECG_XPLAIM,
    y_pred_2=test_y_pred_VANILLA_CNN,
    label_names=metrics_labels,
    round_decimals=pval_decimals,
)

# Report: Model 1 is ECG-XPLAIM, Model 2 is the vanilla CNN
print('Model comparison: \n Model 1 - ECG-XPLAIM vs. Model 2 - Vanilla CNN')

print('\n\n >> AUC (bootstrap) \n')
metric_comparison_print(auc_comparison)

print('\n\n >> Recall (McNemar) \n')
metric_comparison_print(recall_comparison)

print('\n\n >> Specificity (McNemar) \n')
metric_comparison_print(specificity_comparison)


### End of file