Ideas and code are partly taken from https://www.kaggle.com/code/jraska1/ptb-xl-ecg-1d-convolution-neural-network. Most of the borrowed code is the filling of empty data, which is not essential for ECG signal classification.

The purpose of this notebook is to create a prediction model that takes into account both the patient metadata and the samples themselves, i.e. the ECG curves. The targets of the model are the superclasses defined by the dataset.

The superclasses enumerated in the dataset description are as follows:
```
Records | Superclass | Description
9528 | NORM | Normal ECG
5486 | MI | Myocardial Infarction
5250 | STTC | ST/T Change
4907 | CD | Conduction Disturbance
2655 | HYP | Hypertrophy
```

In [None]:
import os
import ast
import wfdb

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

In [None]:
!python --version

In [None]:
%pip list

In [None]:
import torch

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

First of all, I need to load the metadata provided by the dataset. Record-level metadata about patients and ECG samples is loaded into the **ECG_df** dataframe, and the descriptions of the SCP statements are loaded into the **SCP_df** dataframe.

In [None]:
PATH_TO_DATA = r'D:\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3/'

ECG_df = pd.read_csv(os.path.join(PATH_TO_DATA, 'ptbxl_database.csv'), index_col='ecg_id')
# scp_codes is stored as the string representation of a dict; parse it back into a dict
ECG_df.scp_codes = ECG_df.scp_codes.apply(ast.literal_eval)
ECG_df.patient_id = ECG_df.patient_id.astype(int)
# Nullable integer dtype keeps missing values as <NA> instead of forcing floats
ECG_df.nurse = ECG_df.nurse.astype('Int64')
ECG_df.site = ECG_df.site.astype('Int64')
ECG_df.validated_by = ECG_df.validated_by.astype('Int64')

SCP_df = pd.read_csv(os.path.join(PATH_TO_DATA, 'scp_statements.csv'), index_col=0)
# Keep only the statements flagged as diagnostic (used for the superclass mapping)
SCP_df = SCP_df[SCP_df.diagnostic == 1]

ECG_df

ECG samples are stratified into 10 folds (the *strat_fold* column). The authors of the PTB-XL ECG dataset suggest using the first 8 folds as the training set and the last two folds as the validation and test sets, respectively.
I follow this suggestion in the rest of this work.

In [None]:
ECG_df.strat_fold.value_counts()

I add one more column, **scp_classes**, to the ECG_df dataframe. It holds all superclasses (as a list of abbreviations) assigned to the sample by cardiologists.

In [None]:
def diagnostic_class(scp):
    res = set()
    for k in scp.keys():
        if k in SCP_df.index:
            res.add(SCP_df.loc[k].diagnostic_class)
    return list(res)
                    
ECG_df['scp_classes'] = ECG_df.scp_codes.apply(diagnostic_class)
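To make the mapping concrete, here is a minimal sketch of the same logic run on a hypothetical three-row toy table shaped like `scp_statements.csv`. The names `toy_scp` and `diagnostic_superclasses` are illustrative and not part of the notebook, and the toy values are made up. Codes that are not diagnostic (the toy `SR` row) drop out after filtering, and unlike `diagnostic_class` this version returns a sorted list so the output is deterministic.

```python
import pandas as pd

# Hypothetical toy stand-in for scp_statements.csv (illustrative values only).
# As in the notebook, only rows with diagnostic == 1 are kept.
toy_scp = pd.DataFrame(
    {"diagnostic": [1.0, 1.0, None], "diagnostic_class": ["NORM", "MI", None]},
    index=["NORM", "IMI", "SR"],
)
toy_scp = toy_scp[toy_scp.diagnostic == 1]


def diagnostic_superclasses(scp_codes: dict, scp_table: pd.DataFrame) -> list:
    """Map a record's SCP codes to a sorted list of diagnostic superclasses."""
    # Codes absent from the filtered table (here the toy 'SR') are ignored
    classes = {
        scp_table.loc[code, "diagnostic_class"]
        for code in scp_codes
        if code in scp_table.index
    }
    return sorted(classes)


record_codes = {"NORM": 100.0, "IMI": 80.0, "SR": 0.0}
print(diagnostic_superclasses(record_codes, toy_scp))  # ['MI', 'NORM']
```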

The first problem I want to deal with is null values in the metadata dataframe. Here is a quick look at the problem:

In [None]:
import missingno as msno

msno.matrix(ECG_df)
plt.show()

To view the problem from another angle, here is an overview of the number of unique values in each column of the metadata dataframe:

In [None]:
ECG_df[[col for col in ECG_df.columns if col not in ('scp_codes', 'scp_classes')]].nunique(dropna=True)

# Data preparation for modeling

First, I need to prepare the inputs and outputs (targets) for my models.

As inputs I will use both the patient metadata (now loaded in the ECG_df dataframe) and the ECG curves (loaded later into NumPy arrays). Both require some rework to be useful for modeling, which is done in the following steps.

As outputs I will create a new dataframe with one row per sample and one column per diagnostic superclass.

Because I will have two inputs and one output, I prefix all created dataframes as follows:
- X - prefix for patient and sample metadata
- Y - prefix for ECG curves
- Z - prefix for targets

## X dataframe ...

I won't use all columns of the ECG_df dataframe, only a subset of them.
The created dataframe **X** comprises only the columns related to the patient, their health condition, the path to the ECG file and the device with which the ECG was recorded.

In [None]:
X = pd.DataFrame(index=ECG_df.index)

X['patient_id'] = ECG_df.patient_id

# Note: fillna() returns a new Series, so its result must be assigned back
X['age'] = ECG_df.age.fillna(0)

X['sex'] = ECG_df.sex.astype(float).fillna(0)

X['height'] = ECG_df.height
X.loc[X.height < 50, 'height'] = np.nan  # treat implausible heights as missing
X['height'] = X.height.fillna(0)

X['device'] = ECG_df.device

X['weight'] = ECG_df.weight.fillna(0)

# Ordinal encoding of the infarction stadium; 'unknown' and missing values map to 0
X['infarction_stadium1'] = ECG_df.infarction_stadium1.map({
    'unknown': 0,
    'Stadium I': 1,
    'Stadium I-II': 2,
    'Stadium II': 3,
    'Stadium II-III': 4,
    'Stadium III': 5
}).fillna(0)

X['infarction_stadium2'] = ECG_df.infarction_stadium2.map({
    'unknown': 0,
    'Stadium I': 1,
    'Stadium II': 2,
    'Stadium III': 3
}).fillna(0)

X['pacemaker'] = (ECG_df.pacemaker == 'ja, pacemaker').astype(float)

X['filename_lr'] = ECG_df.filename_lr
X['filename_hr'] = ECG_df.filename_hr

X['total_count'] = X.groupby('patient_id')['patient_id'].transform('count')

# Count unique patients
unique_count = X['patient_id'].nunique()
print(f'Unique patients: {unique_count}')

X_sort = X.sort_values('total_count')
X_sort

In [None]:
# Inspect the last 1000 rows of the dataframe
print(X.tail(1000))

In [None]:
print(X['device'].unique())

In [None]:
def draw_unique_device_ECG(df: pd.DataFrame) -> None:
    # Count the unique values in the 'device' column
    device_counts = df['device'].value_counts()
    
    fig, ax = plt.subplots(figsize=(8,6))
    
    # Plot the bar chart on the given axis
    device_counts.plot(kind='bar', ax=ax)
    
    ax.set_facecolor('white')
    ax.spines['top'].set_color('gray')
    ax.spines['right'].set_color('gray')
    ax.spines['bottom'].set_color('gray')
    ax.spines['left'].set_color('gray')
    ax.grid(True, color='gray', alpha=0.65, linewidth=0.5, linestyle='dashed')
    
    # Add title and labels
    ax.set_title('ECGs Counts with a Different Device')
    ax.set_xlabel('Device')
    ax.set_ylabel('Count')

    # Annotate each bar with its value
    for i in range(len(device_counts)):
        ax.text(i, device_counts.iloc[i] + 50, str(device_counts.iloc[i]), ha='center')

    # Show the plot
    plt.show()

In [None]:
draw_unique_device_ECG(X)

## Z targets ...

I create the **Z** dataframe with one column per diagnostic superclass.

In [None]:
Z = pd.DataFrame(0, index=ECG_df.index, columns=['NORM', 'MI', 'STTC', 'CD', 'HYP'], dtype='int')
for i in Z.index:
    for k in ECG_df.loc[i].scp_classes:
        Z.loc[i, k] = 1

Z
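The loop above writes one cell at a time with `.loc`, which is slow for about 20 thousand records. A vectorized alternative is scikit-learn's `MultiLabelBinarizer`; fixing its `classes` argument keeps the column order stable and yields all-zero rows for records without any diagnostic superclass. This is a sketch, and `build_targets` / `SUPERCLASSES` are hypothetical names, not part of the notebook:

```python
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

SUPERCLASSES = ['NORM', 'MI', 'STTC', 'CD', 'HYP']


def build_targets(scp_classes: pd.Series) -> pd.DataFrame:
    """One-hot encode lists of superclasses into a (records x classes) 0/1 frame."""
    # A fixed class list gives a stable column order; empty lists become all-zero rows
    mlb = MultiLabelBinarizer(classes=SUPERCLASSES)
    matrix = mlb.fit_transform(scp_classes)
    return pd.DataFrame(matrix, index=scp_classes.index, columns=mlb.classes_)


# Toy usage with three records (illustrative data)
toy_classes = pd.Series(
    [['NORM'], ['MI', 'STTC'], []],
    index=pd.Index([1, 2, 3], name='ecg_id'),
)
print(build_targets(toy_classes))
```

In the notebook the equivalent call would be `build_targets(ECG_df.scp_classes)`.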

### Check for rows (signals) with multiple annotations

In [None]:
# Z has exactly the five superclass columns (ecg_id is the index, not a column), so count across all of them
mask = Z.sum(axis=1) > 1

# Get rows with multiple 1s
rows_with_multiple_ones = Z[mask]

print(f'Total signals with multiple annotations: {len(rows_with_multiple_ones)}')
print(rows_with_multiple_ones)
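The mask above only says *that* a record carries several superclasses. A hypothetical helper `label_combinations` (not part of the original notebook) can show *which* combinations occur and how many records have no superclass at all. A sketch, shown on a toy frame:

```python
import pandas as pd


def label_combinations(Z: pd.DataFrame) -> pd.Series:
    """Count how often each combination of superclasses occurs.

    Rows without any label are reported as '<none>'.
    """
    combos = Z.apply(
        lambda row: '+'.join(col for col in Z.columns if row[col] == 1) or '<none>',
        axis=1,
    )
    return combos.value_counts()


# Toy one-hot frame with the notebook's column layout (illustrative data)
toy_Z = pd.DataFrame(
    [[1, 0, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 0, 0]],
    columns=['NORM', 'MI', 'STTC', 'CD', 'HYP'],
)
print(label_combinations(toy_Z))
```

On the real data, `label_combinations(Z)` gives the full breakdown.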

## Splitting to train, validate and test datasets

As the authors of the PTB-XL ECG dataset suggest, I split all input and output datasets into training, validation and test subsets according to the *strat_fold* column:

In [None]:
X_train, Z_train = X[ECG_df.strat_fold <= 8], Z[ECG_df.strat_fold <= 8]
X_valid, Z_valid = X[ECG_df.strat_fold == 9], Z[ECG_df.strat_fold == 9]
X_test, Z_test  = X[ECG_df.strat_fold == 10], Z[ECG_df.strat_fold == 10]

In [None]:
print(f'Train samples: {len(X_train)}')
print(f'Valid samples: {len(X_valid)}')
print(f'Test samples: {len(X_test)}')
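Many patients have several ECGs (see `total_count` above), so it is worth checking that no patient appears in more than one split. The PTB-XL authors describe the folds as keeping all records of a patient together, so every count should be 0. The hypothetical helper `patient_overlap` is a sketch of this check:

```python
from itertools import combinations

import pandas as pd


def patient_overlap(*splits: pd.DataFrame, id_col: str = 'patient_id') -> dict:
    """Return the number of shared patient IDs for every pair of splits.

    Keys are (i, j) positions of the splits in the argument list.
    """
    ids = [set(df[id_col]) for df in splits]
    return {(i, j): len(ids[i] & ids[j]) for i, j in combinations(range(len(ids)), 2)}


# Toy usage: patient 3 appears in both the first and the third split
a = pd.DataFrame({'patient_id': [1, 2, 3]})
b = pd.DataFrame({'patient_id': [4, 5]})
c = pd.DataFrame({'patient_id': [3, 6]})
print(patient_overlap(a, b, c))
```

In the notebook, `patient_overlap(X_train, X_valid, X_test)` would run the check on the real splits.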

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm

def draw_ecg_counts(train_df: pd.DataFrame, valid_df: pd.DataFrame, test_df: pd.DataFrame, 
                    categories=None, top_text_offset=50, y_max=None) -> None:
    """
    Draws a grouped bar chart of total ECG counts for each class in Train, Validation, and Test sets.

    Parameters:
    train_df, valid_df, test_df (pd.DataFrame): One-hot label DataFrames (one 0/1 column per class).
    categories (list of str, optional): List of category names to include in the plot.
                                        Defaults to ['NORM', 'MI', 'STTC', 'CD', 'HYP'].
    top_text_offset (int, optional): Offset for the text annotation above each bar. Default is 50.
    y_max (int, optional): Maximum limit for the y-axis. Auto-scales if not provided.
    """
    if categories is None:
        categories = ['NORM', 'MI', 'STTC', 'CD', 'HYP']
    
    # Compute total counts for each set
    train_counts = train_df[categories].sum()
    valid_counts = valid_df[categories].sum()
    test_counts = test_df[categories].sum()

    # Compute max value dynamically if y_max is not given
    max_value = max(train_counts.max(), valid_counts.max(), test_counts.max())
    if y_max is None:
        y_max = max_value * 1.1  # Add 10% buffer to max value
    
    # Define bar width dynamically
    num_datasets = 3  # Train, Valid, Test
    width = 0.8 / num_datasets  # Adjust width dynamically

    # Define positions for grouped bars
    x = np.arange(len(categories))

    # Use the plasma colormap for colors
    plasma = cm.plasma
    colors_train = plasma(0.1)  # deep blue-purple
    colors_valid = plasma(0.5)  # magenta-pink
    colors_test = plasma(0.9)   # orange-yellow

    fig, ax = plt.subplots(figsize=(12, 7))

    # Plot bars for Train, Valid, Test with plasma colors and some transparency
    bars1 = ax.bar(x - width, train_counts, width, label='Train', color=colors_train, edgecolor='black', alpha=0.6)
    bars2 = ax.bar(x, valid_counts, width, label='Valid', color=colors_valid, edgecolor='black', alpha=0.8)
    bars3 = ax.bar(x + width, test_counts, width, label='Test', color=colors_test, edgecolor='black', alpha=0.8)

    # Customize plot
    ax.set_facecolor('white')
    ax.set_title('ECG Count for Each Class (Train, Valid, Test)', fontsize=19, fontweight='bold')
    ax.set_xlabel('ECG Class', fontsize=17)
    ax.set_ylabel('Total Count', fontsize=20)
    ax.set_xticks(x)
    ax.set_xticklabels(categories, fontsize=20)
    
    # Add both vertical and horizontal grid lines
    ax.grid(True, which='both', axis='both', linestyle='dashed', alpha=0.6, linewidth=0.7, color='gray')

    # Make legend bigger
    ax.legend(fontsize=22, loc='upper right', frameon=True)

    # Set y-axis limit if specified
    if y_max is not None:
        ax.set_ylim(0, y_max)

    # Set the y-tick label font size
    ax.tick_params(axis='y', labelsize=16)

    # Add text annotations on top of bars
    def add_labels(bars):
        for bar in bars:
            height = bar.get_height()
            offset = top_text_offset if height + top_text_offset < (y_max or float('inf')) else height * 0.05
            ax.text(bar.get_x() + bar.get_width() / 2, height + offset, f"{int(height)}",
                    ha='center', va='bottom', fontsize=20, fontweight='bold')

    add_labels(bars1)
    add_labels(bars2)
    add_labels(bars3)

    # Improve layout and show plot
    plt.tight_layout()
    plt.show()


In [None]:
draw_ecg_counts(Z_train, Z_valid, Z_test, y_max=8200)

### Sanity check: NORM labels in the test set

Count how many test samples are labelled NORM only and how many are labelled NORM together with another superclass in the same signal. Keep this in mind when investigating the confusion matrices.

In [None]:
# Samples labelled NORM only vs. NORM together with another superclass
other_classes = ['MI', 'STTC', 'CD', 'HYP']
is_norm = Z_test['NORM'] == 1
has_other = (Z_test[other_classes] != 0).any(axis=1)

norm_count = (is_norm & ~has_other).sum()
print(f'Test samples with NORM only: {norm_count}')

norm_with_other = Z_test[is_norm & has_other]
print(f'Test samples with NORM and some other annotation in the same signal: {len(norm_with_other)}')
print(norm_with_other)

## Load Y - ECG for further processing and checking

In [None]:
from tqdm import tqdm
import wfdb
import os
import numpy as np

sampling_rate = 100

def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        filenames = df.filename_lr
    else:
        filenames = df.filename_hr

    # Add tqdm here to show the progress
    data = [wfdb.rdsamp(os.path.join(path, f)) for f in tqdm(filenames, desc="Loading Data", unit="file")]
    
    # Extracting signals
    data = np.array([signal for signal, meta in data])
    
    return data

In [None]:
%%time
print('Loading training ECGs...')
Y_train = load_raw_data(X_train, sampling_rate, PATH_TO_DATA)
print(f'Loaded {len(Y_train)} samples')

In [None]:
%%time
print('Loading validation ECGs...')
Y_valid = load_raw_data(X_valid, sampling_rate, PATH_TO_DATA)
print(f'Loaded {len(Y_valid)} samples')

In [None]:
%%time
print('Loading testing ECGs...')
Y_test = load_raw_data(X_test, sampling_rate, PATH_TO_DATA)
print(f'Loaded {len(Y_test)} samples')

### Very good data-wrangling work has also been done in the [PTB XL Dataset Wrangling](https://www.kaggle.com/code/khyeh0719/ptb-xl-dataset-wrangling) notebook.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

def plot_ecg(ecg: np.array, trend: np.array = None, title: str = 'ECG signal', y_title='Voltage (mV)', sampling_rate: int = 100):
    # Standard lead names for 12-lead ECG
    lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']  # PTB-XL channel order
    
    fig, axs = plt.subplots(12, 1, figsize=(14, 15), sharex=True)  # Share x-axis for alignment

    for i, ax in enumerate(axs.flat):
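        if trend is not None:
            ax.plot(trend[:, i], alpha=0.85, linewidth=1.5, linestyle='dashed', color='coral', label='Trend')

        ax.plot(ecg[:, i], alpha=0.85, linewidth=0.85, color='indigo', label='Signal')
        if trend is not None and i == 0:
            # Add the figure legend once (from the first lead's lines), not once per lead
            fig.legend(*ax.get_legend_handles_labels(), loc='upper right', bbox_to_anchor=(0.92, 1.03), ncol=2, prop=FontProperties(size=16))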
        ax.set_facecolor('white')
        ax.spines['top'].set_color('gray')
        ax.spines['right'].set_color('gray')
        ax.spines['bottom'].set_color('gray')
        ax.spines['left'].set_color('gray')

        ax.grid(True, color='gray', alpha=0.65, linewidth=0.5, linestyle='dashed')
        
        # Set evenly spaced x-ticks and x-limits
        ax.set_xticks(np.arange(0, ecg.shape[0] + 1, ecg.shape[0] / 10))  # Customize the number of ticks
        ax.set_xlim([0, ecg.shape[0]])  # Set xlim to prevent offset
        #if i != len(ecg[1]) - 1:
        #    ax.set_xticklabels([])  # Remove x-axis labels

        ax.tick_params(axis='both', which='major', labelsize=12)

        # Compute min and max
        y_min, y_max = np.min(ecg[:, i]), np.max(ecg[:, i])
        ax.set_ylim([y_min, y_max])  # Set y-axis limits

        # Move voltage labels to the left
        #ax.text(-ecg.shape[0] * 0.02, (y_max + y_min) / 2, f"{y_min:.1f} to {y_max:.1f} mV", 
        #        fontsize=12, color='black', ha='center', va='center')

        # Move lead annotations to the right
        ax.text(ecg.shape[0] * 1.025, (y_max + y_min) / 2, lead_names[i], 
                fontsize=15, color='black', ha='center', va='center')

    # Adjust the main title position (moved up slightly)
    plt.suptitle(title, fontsize=20, y=1.02)

    # Adjust x and y-axis labels positioning
    fig.text(0.5, 0.01, f'Time (in samples)', ha='center', va='center', fontsize=16)  # Moved down
    fig.text(0.04, 0.5, y_title, ha='center', va='center', rotation='vertical', fontsize=16)  # Moved left

    # Adjust layout to ensure labels are well positioned
    plt.subplots_adjust(left=0.08, right=0.90, top=0.99, bottom=0.04)

    plt.show()


In [None]:
plot_ecg(Y_train[0], title='Original Train ECG Sample')

In [None]:
plot_ecg(Y_valid[0], title='Original Validation ECG Sample')

In [None]:
plot_ecg(Y_test[0], title='Original Test ECG Sample')

In [None]:
def detrend_signal_polynom(ecg_signal, poly_degree=10):
    """
    Detrend a single ECG record using polynomial baseline fitting.

    Args:
        ecg_signal (ndarray): 2D NumPy array with shape (timesteps, leads).
        poly_degree (int): Degree of the polynomial fitted to each lead.

    Returns:
        tuple[ndarray, ndarray]: (detrended_signals, trend_curves), both with the
        same shape as ecg_signal.
    """
    detrended_signals = np.zeros_like(ecg_signal)
    trend_curves = np.zeros_like(ecg_signal)
    # Fit abscissa, created once: the sample index mapped to [0, 1] (not physical time).
    # A unit interval keeps the degree-10 fit well conditioned; the fitted trend is the
    # same polynomial as with raw sample indices, up to numerical precision.
    x = np.linspace(0, 1, ecg_signal.shape[0])

    # Fit and subtract a polynomial trend for each lead independently
    for i in range(ecg_signal.shape[1]):
        coeffs = np.polyfit(x, ecg_signal[:, i], poly_degree)
        trend = np.polyval(coeffs, x)
        trend_curves[:, i] = trend
        detrended_signals[:, i] = ecg_signal[:, i] - trend
    return detrended_signals, trend_curves

## Detrend signal

In [None]:
detrended_train_signal, trend_train_curves = detrend_signal_polynom(Y_train[0])
plot_ecg(detrended_train_signal, trend_train_curves, title='Detrended Train ECG Sample')

In [None]:
#plot_ecg(trend_train_curves, title='Train Trend Curves')

In [None]:
detrended_valid_signal, trend_valid_curves = detrend_signal_polynom(Y_valid[0])
plot_ecg(detrended_valid_signal, trend_valid_curves, title='Detrended Valid ECG Sample')

In [None]:
#plot_ecg(trend_valid_curves, title='Valid Trend Curves')

In [None]:
detrended_test_signal, trend_test_curves = detrend_signal_polynom(Y_test[0])
plot_ecg(detrended_test_signal, trend_test_curves, title='Detrended Test ECG Sample')

In [None]:
#plot_ecg(trend_test_curves, title='Test Trend Curves')

## Let's choose the polynomial approach for baseline removal on all ECGs: it looks better and it is fast (TODO: inspect other approaches)

In [None]:
%%time
from tqdm import tqdm
# Preallocated arrays for the detrended signals
Y_train_detrend = np.zeros_like(Y_train)
Y_test_detrend = np.zeros_like(Y_test)
Y_valid_detrend = np.zeros_like(Y_valid)

for i in tqdm(range(Y_train.shape[0]), desc ='Detrending Training'):
    Y_train_detrend[i], _ = detrend_signal_polynom(Y_train[i])

for i in tqdm(range(Y_test.shape[0]), desc ='Detrending Testing'):
    Y_test_detrend[i], _ = detrend_signal_polynom(Y_test[i])
    
for i in tqdm(range(Y_valid.shape[0]), desc ='Detrending Valid'):
    Y_valid_detrend[i], _ = detrend_signal_polynom(Y_valid[i])
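Detrending record by record calls `np.polyfit` once per lead, i.e. 12 times per ECG. `np.polyfit` also accepts a 2D `y` and fits each column independently, so a whole batch can be fitted in one call. The hypothetical `detrend_batch_polynom` is a sketch of that idea under the same polynomial model; its abscissa is the sample index rescaled to [0, 1], which spans the same polynomial family as a fit on raw sample indices. Stacking the entire training set at once needs several GB of float64 memory, so processing in slices is safer.

```python
import numpy as np


def detrend_batch_polynom(ecg_batch: np.ndarray, poly_degree: int = 10) -> np.ndarray:
    """Remove a per-lead polynomial baseline from a batch of ECGs.

    ecg_batch has shape (num_samples, num_timesteps, num_leads).
    """
    n_samples, n_steps, n_leads = ecg_batch.shape
    # Assumption: sample index rescaled to [0, 1] as the fit abscissa
    x = np.linspace(0.0, 1.0, n_steps)
    # Stack every (sample, lead) series as one column: (n_steps, n_samples * n_leads)
    y = ecg_batch.transpose(1, 0, 2).reshape(n_steps, -1)
    # polyfit fits each column of a 2D y independently
    coeffs = np.polyfit(x, y, poly_degree)           # (poly_degree + 1, n_columns)
    # np.vander uses decreasing powers, matching polyfit's coefficient order
    trend = np.vander(x, poly_degree + 1) @ coeffs   # (n_steps, n_columns)
    trend = trend.reshape(n_steps, n_samples, n_leads).transpose(1, 0, 2)
    return ecg_batch - trend
```

For example, apply it to `Y_train` in slices of about 1000 records and join the pieces with `np.concatenate(..., axis=0)`.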

## Let's scale the ECGs

In [None]:
#from sklearn.preprocessing import StandardScaler 
#Y_scaler = StandardScaler()
#Y_scaler.fit(Y_train_detrend.reshape(-1, Y_train_detrend.shape[-1]))
#print(f'Mean: {Y_scaler.mean_}')
#print(f'Variance: {Y_scaler.var_}')
#print(f'Scale: {Y_scaler.scale_}')

In [None]:
# scale and reshape back
#Y_train_scaled = Y_scaler.transform(Y_train_detrend.reshape(-1, Y_train_detrend.shape[-1])).reshape(Y_train_detrend.shape)
#Y_valid_scaled = Y_scaler.transform(Y_valid_detrend.reshape(-1, Y_valid_detrend.shape[-1])).reshape(Y_valid_detrend.shape)
#Y_test_scaled  = Y_scaler.transform(Y_test_detrend.reshape(-1, Y_test_detrend.shape[-1])).reshape(Y_test_detrend.shape)

In [None]:
Y_train_detrend.shape

In [None]:
def z_normalize_sample_lead_wise(ecg_data):
    """
    Apply Z-normalization sample-wise and lead-wise to an ECG signal array.

    Parameters:
        ecg_data (numpy.ndarray): Input ECG data array of shape (num_samples, num_timesteps, num_leads).

    Returns:
        numpy.ndarray: Z-normalized ECG data array of the same shape as input.
    """
    # Initialize an array to store the normalized data
    normalized_ecg_data = np.zeros_like(ecg_data)

    # Loop through each sample
    for sample_idx in range(ecg_data.shape[0]):
        # Loop through each lead (signal)
        for lead_idx in range(ecg_data.shape[2]):
            # Extract the lead data for the current sample (shape: (num_timesteps,))
            lead_data = ecg_data[sample_idx, :, lead_idx]

            # Calculate the mean and standard deviation for the lead within the sample
            mean = np.mean(lead_data)
            std = np.std(lead_data)

            # Apply Z-normalization to the lead
            if std != 0:  # Avoid division by zero
                normalized_lead_data = (lead_data - mean) / std
            else:
                normalized_lead_data = np.zeros_like(lead_data)  # Handle constant signals

            # Store the normalized lead data back into the array
            normalized_ecg_data[sample_idx, :, lead_idx] = normalized_lead_data

    return normalized_ecg_data
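The double loop above visits every (sample, lead) pair in Python. The same statistics can be taken with NumPy reductions over the time axis (`axis=1`, `keepdims=True`), and `np.divide(..., where=...)` reproduces the zero output for constant leads. A sketch with the hypothetical name `z_normalize_vectorized`:

```python
import numpy as np


def z_normalize_vectorized(ecg_data: np.ndarray) -> np.ndarray:
    """Vectorized per-sample, per-lead z-normalization.

    ecg_data has shape (num_samples, num_timesteps, num_leads); statistics are
    taken over the time axis, so each (sample, lead) pair is normalized
    independently. Constant leads map to all zeros.
    """
    mean = ecg_data.mean(axis=1, keepdims=True)   # (num_samples, 1, num_leads)
    std = ecg_data.std(axis=1, keepdims=True)
    centered = ecg_data - mean
    # Divide only where std != 0; elsewhere keep the preallocated zeros
    return np.divide(centered, std, out=np.zeros_like(centered), where=std != 0)
```

It should give the same result as `z_normalize_sample_lead_wise(Y_train_detrend)` in a single pass.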

In [None]:
def robust_scale_sample_lead_wise(ecg_data):
    """
    Apply Robust Scaling sample-wise and lead-wise to an ECG signal array.

    Parameters:
        ecg_data (numpy.ndarray): Input ECG data array of shape (num_samples, num_timesteps, num_leads).

    Returns:
        numpy.ndarray: Robust-scaled ECG data array of the same shape as input.
    """
    # Initialize an array to store the scaled data
    scaled_ecg_data = np.zeros_like(ecg_data)

    # Loop through each sample
    for sample_idx in range(ecg_data.shape[0]):
        # Loop through each lead (signal)
        for lead_idx in range(ecg_data.shape[2]):
            # Extract the lead data for the current sample (shape: (num_timesteps,))
            lead_data = ecg_data[sample_idx, :, lead_idx]

            # Calculate the median and IQR for the lead within the sample
            median = np.median(lead_data)
            q1 = np.percentile(lead_data, 25)  # 25th percentile (Q1)
            q3 = np.percentile(lead_data, 75)  # 75th percentile (Q3)
            iqr = q3 - q1  # Interquartile range

            # Apply Robust Scaling to the lead
            if iqr != 0:  # Avoid division by zero
                scaled_lead_data = (lead_data - median) / iqr
            else:
                scaled_lead_data = np.zeros_like(lead_data)  # Handle constant signals

            # Store the scaled lead data back into the array
            scaled_ecg_data[sample_idx, :, lead_idx] = scaled_lead_data

    return scaled_ecg_data
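Robust scaling vectorizes the same way: `np.percentile` accepts a list of percentiles together with `axis` and `keepdims`, so both quartiles come from one call. A sketch with the hypothetical name `robust_scale_vectorized`, using NumPy's default linear interpolation just like the loop version:

```python
import numpy as np


def robust_scale_vectorized(ecg_data: np.ndarray) -> np.ndarray:
    """Vectorized per-sample, per-lead robust scaling: (x - median) / IQR.

    ecg_data has shape (num_samples, num_timesteps, num_leads); statistics are
    computed over the time axis. Leads with zero IQR map to all zeros.
    """
    median = np.median(ecg_data, axis=1, keepdims=True)
    # Both quartiles in one call: result has shape (2, num_samples, 1, num_leads)
    q1, q3 = np.percentile(ecg_data, [25, 75], axis=1, keepdims=True)
    iqr = q3 - q1
    centered = ecg_data - median
    return np.divide(centered, iqr, out=np.zeros_like(centered), where=iqr != 0)
```

For a lead with values 0, 1, 2, 3, 4 the median is 2 and the IQR is 3 - 1 = 2, so the scaled values are -1, -0.5, 0, 0.5, 1.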

In [None]:
def min_max_scale_sample_lead_wise(ecg_data):
    """
    Apply Min-Max Scaling sample-wise and lead-wise to an ECG signal array, scaling data to the range [-1, 1].

    Parameters:
        ecg_data (numpy.ndarray): Input ECG data array of shape (num_samples, num_timesteps, num_leads).

    Returns:
        numpy.ndarray: Min-Max scaled ECG data array of the same shape as input, with values scaled to [-1, 1].
    """
    # Initialize an array to store the scaled data
    scaled_ecg_data = np.zeros_like(ecg_data)

    # Loop through each sample
    for sample_idx in range(ecg_data.shape[0]):
        # Loop through each lead (signal)
        for lead_idx in range(ecg_data.shape[2]):
            # Extract the lead data for the current sample (shape: (num_timesteps,))
            lead_data = ecg_data[sample_idx, :, lead_idx]

            # Calculate the min and max for the lead within the sample
            lead_min = np.min(lead_data)
            lead_max = np.max(lead_data)

            # Apply Min-Max Scaling to the lead and scale to [-1, 1]
            if lead_max != lead_min:  # Avoid division by zero
                scaled_lead_data = 2 * (lead_data - lead_min) / (lead_max - lead_min) - 1
            else:
                scaled_lead_data = np.zeros_like(lead_data)  # Handle constant signals

            # Store the scaled lead data back into the array
            scaled_ecg_data[sample_idx, :, lead_idx] = scaled_lead_data

    return scaled_ecg_data

In [None]:
def min_max_scale_centered_to_0(ecg_data):
    """
    Scale ECG data to the range [-1, 1] while keeping the baseline at 0.

    Each lead of each sample is centered by subtracting its median (or mean) and then
    divided by its maximum absolute value (max-abs scaling rather than classic min-max scaling).

    Parameters:
        ecg_data (numpy.ndarray): Input ECG data array of shape (num_samples, num_timesteps, num_leads).

    Returns:
        numpy.ndarray: Scaled ECG data array of the same shape as input, with values in [-1, 1] and the median at 0.
    """
    # Initialize an array to store the scaled data
    scaled_ecg_data = np.zeros_like(ecg_data)

    # Loop through each sample
    for sample_idx in range(ecg_data.shape[0]):
        # Loop through each lead (signal)
        for lead_idx in range(ecg_data.shape[2]):
            # Extract the lead data for the current sample (shape: (num_timesteps,))
            lead_data = ecg_data[sample_idx, :, lead_idx]

            # Center the data by subtracting the median (or mean) for each lead to preserve baseline around 0
            median = np.median(lead_data)  # You can use np.mean if preferred
            centered_data = lead_data - median

            # Find the max absolute value of the centered data
            max_abs_value = np.max(np.abs(centered_data))

            # Divide by the max absolute value so that the result lies in [-1, 1]
            if max_abs_value != 0:  # Avoid division by zero
                scaled_lead_data = centered_data / max_abs_value
            else:
                scaled_lead_data = np.zeros_like(centered_data)  # Handle constant signals

            # Store the scaled lead data back into the array
            scaled_ecg_data[sample_idx, :, lead_idx] = scaled_lead_data

    return scaled_ecg_data
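Since `min_max_scale_centered_to_0` is the scaler actually used below, it is the one that benefits most from vectorization. Median-centering along `axis=1` and dividing by the per-(sample, lead) maximum absolute value should reproduce it, including the all-zero output for constant leads. A sketch with the hypothetical name `max_abs_scale_centered_vectorized`:

```python
import numpy as np


def max_abs_scale_centered_vectorized(ecg_data: np.ndarray) -> np.ndarray:
    """Vectorized median-centering followed by max-abs scaling.

    ecg_data has shape (num_samples, num_timesteps, num_leads). Each
    (sample, lead) series is centered on its median over time and divided by
    its largest absolute deviation, so the result lies in [-1, 1].
    """
    centered = ecg_data - np.median(ecg_data, axis=1, keepdims=True)
    max_abs = np.abs(centered).max(axis=1, keepdims=True)
    # Constant leads (max_abs == 0) stay all zeros
    return np.divide(centered, max_abs, out=np.zeros_like(centered), where=max_abs != 0)
```

For a lead with values 0, 1, 4 the median is 1, the centered values are -1, 0, 3, and dividing by 3 gives -1/3, 0, 1.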

In [None]:
%%time
Y_train_scaled = min_max_scale_centered_to_0(Y_train_detrend)

In [None]:
%%time
Y_test_scaled = min_max_scale_centered_to_0(Y_test_detrend)

In [None]:
%%time
Y_valid_scaled = min_max_scale_centered_to_0(Y_valid_detrend)

In [None]:
plot_ecg(Y_train_scaled[0], title='Train ECG Sample Scaled', y_title='Normalized scale')

In [None]:
plot_ecg(Y_valid_scaled[0], title='Validation ECG Sample Scaled', y_title='Normalized scale')

In [None]:
plot_ecg(Y_test_scaled[0], title='Test ECG Sample Scaled', y_title='Normalized scale')

## Augment!

In [None]:
from matplotlib.font_manager import FontProperties
import numpy as np
import matplotlib.pyplot as plt

def plot_2_ecgs(ecg_1: np.array, title_ecg_1: str, ecg_2: np.array, title_ecg_2: str, title: str = '2 ECG signals', sampling_rate: int = 1000, render_title = True):
    fig, axs = plt.subplots(12, 1, figsize=(14, 15))
    # Standard lead names for 12-lead ECG
    lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']  # PTB-XL channel order
    for i, ax in enumerate(axs.flat):
        ax.plot(ecg_1[:,i], alpha=0.85, linewidth=0.85, color='indigo', label=title_ecg_1)
        ax.plot(ecg_2[:,i], alpha=0.75, linewidth=0.85, color='coral', label=title_ecg_2) 
        ax.set_facecolor('white')
        ax.spines['top'].set_color('gray')
        ax.spines['right'].set_color('gray')
        ax.spines['bottom'].set_color('gray')
        ax.spines['left'].set_color('gray')
        ax.grid(True, color='gray', alpha=0.65, linewidth=0.5, linestyle='dashed')

        # Set xlim to prevent offset and remove any gap
        ax.set_xlim([0, ecg_1.shape[0]])

        # Increase the font size of ticks
        ax.tick_params(axis='both', which='major', labelsize=13)  # You can adjust this value as needed

        ax.set_xticks(np.arange(0, ecg_1.shape[0] + 1, ecg_1.shape[0] / 10))  # Add more ticks
        if i != ecg_1.shape[1] - 1:
            ax.set_xticklabels([])  # Keep x-axis labels only on the bottom lead

        # Compute min and max
        y_min, y_max = np.min(ecg_1[:, i]), np.max(ecg_1[:, i])
        ax.set_ylim([y_min, y_max])  # Set y-axis limits

        # Move lead annotations to the right
        ax.text(ecg_1.shape[0] * 1.025, (y_max + y_min) / 2, lead_names[i], 
                fontsize=15, color='black', ha='center', va='center')
        
    fig.text(0.5, 0.00, f'Time (in samples)', ha='center', va='center', fontsize=16)
    fig.text(0.0, 0.5, 'Normalized scale', ha='center', va='center', rotation='vertical', fontsize=16)
    # Add a main title
    if render_title:
        plt.suptitle(title, fontsize=25, y=0.99)
    # Add legend at the top
    fig.legend([title_ecg_1, title_ecg_2], loc='upper right', bbox_to_anchor=(1.0, 1.0), ncol=2, prop=FontProperties(size=16))

    # Remove space between subplots
    plt.tight_layout()
    plt.show()


In [None]:
import ecgmentations as E

In [None]:
transform = E.Sequential([
    E.GaussNoise(p=1.0,
                variance=0.002  # normalized units: the signal was scaled to [-1, 1]
                )
])
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Gaussian Noise', 'Augmentation Gaussian Noise')
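For intuition about the magnitude, assume that `variance` is the variance of zero-mean Gaussian noise added to every sample. This is an assumption about ecgmentations, not checked here. Under it, the noise standard deviation is sqrt(0.002) ≈ 0.045 in the normalized units of `Y_train_scaled`. The hypothetical `add_gaussian_noise` is a plain-NumPy sketch of that model, not the library's implementation:

```python
import numpy as np


def add_gaussian_noise(ecg: np.ndarray, variance: float, rng=None) -> np.ndarray:
    """Add zero-mean Gaussian noise with the given variance to every sample.

    Plain-NumPy illustration of an assumed noise model; not ecgmentations code.
    """
    rng = np.random.default_rng() if rng is None else rng
    # NumPy's `scale` is the standard deviation, i.e. sqrt(variance)
    return ecg + rng.normal(loc=0.0, scale=np.sqrt(variance), size=ecg.shape)


# Toy usage on a flat 10 s, 12-lead record sampled at 100 Hz
noisy = add_gaussian_noise(np.zeros((1000, 12)), variance=0.002, rng=np.random.default_rng(0))
print(noisy.std())
```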

In [None]:
transform = E.Sequential([
    E.SinePulse(p=1.0,
                ecg_frequency=100., # 100 Hz 
                amplitude_limit=0.2,  # normalized units: the signal was scaled to [-1, 1]
                )
])
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Sine Pulse', 'Augmentation Sine Pulse')

In [None]:
transform = E.Sequential([
    E.RespirationNoise(p=1.0,
                ecg_frequency=100., # 100Hz
                breathing_rate_range=(12, 18), # breathing rate range in bpm
                amplitude_limit=0.2,  # normalized units: the signal was scaled to [-1, 1]
                )
])
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Breathing', 'Augmentation Breathing')

In [None]:
transform = E.Sequential([
    E.AmplitudeScale(p=1.0,
                    scaling_range=(-0.2,0.2)
    )
])
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Amplitude', 'Augmentation Amplitude')

In [None]:
transform = E.Sequential([
    E.TimeShift(p=1.0,
                shift_limit=0.02
    )
])
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Shifted', 'Augmentation Time Shift')

In [None]:
transform = E.Sequential([
    E.RandomTimeWrap(p=1.0,
                num_steps=4,
                wrap_limit=0.1,
    )
])
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Warped', 'Augmentation Time Warp')

In [None]:
# Build the augmentation pipeline from the transforms shown above
transform = E.Sequential([
        E.TimeShift(p=0.5,
                shift_limit=0.03),
        E.AmplitudeScale(p=0.5,
                    scaling_range=(-0.2,0.2)),
        E.GaussNoise(p=0.5,
                variance=0.001),
        E.OneOf([
            E.SinePulse(p=0.5,
                ecg_frequency=100., # 100 Hz 
                amplitude_limit=0.2),
            E.RespirationNoise(p=0.5,
                ecg_frequency=100., # 100Hz
                breathing_rate_range=(12, 18), # breathing rate range in bpm
                amplitude_limit=0.2)
                ])
        ])

In [None]:
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Augmented', 'Augmentation Pipeline', render_title=True)

In [None]:
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Augmented', 'Augmentation Pipeline', render_title=True)

In [None]:
transformed_ecg = transform(ecg=Y_train_scaled[0])['ecg']
plot_2_ecgs(Y_train_scaled[0], 'Original', transformed_ecg, 'Augmented', 'Augmentation Pipeline', render_title=True)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
#from models import SignalCNNTransformer
import ecgmentations
from sklearn.metrics import accuracy_score

In [None]:
from torch import nn
import torch.nn.functional as F
from torchsummary import summary

class SignalCNN(nn.Module):
    def __init__(self,
                 signal_channels=12,
                 initial_conv_kernel_size=11,
                 conv_kernel_size = 7,
                 conv_filters=16,
                 downscale_cnn_blocks=2,
                 number_of_classes=2,
                 add_cam=False) -> None:
            super(SignalCNN, self).__init__()
            self.downscale_cnn_blocks = downscale_cnn_blocks
            self.conv_filters = conv_filters
            self.add_cam = add_cam
            # initial convolution on the raw 12-lead input, followed by BN, LeakyReLU and 2x max-pooling
            self.pre_cnn_extractor =  nn.Sequential(
                nn.Conv1d(in_channels=signal_channels,
                          out_channels=conv_filters,
                          kernel_size=initial_conv_kernel_size,
                          padding='same'),
                nn.BatchNorm1d(conv_filters),
                nn.LeakyReLU(negative_slope=0.1),
                nn.MaxPool1d(2, stride=2))
            # downscaling blocks: block i outputs 2**i * conv_filters channels and halves the sequence length
            self.cnn_extractor = nn.ModuleList([
                nn.Sequential(
                nn.Conv1d(in_channels=conv_filters if i == 0 else 2 ** (i - 1) * conv_filters,
                          out_channels=2 ** i * conv_filters,
                          kernel_size=conv_kernel_size,
                          padding='same'),
                nn.BatchNorm1d(2 ** i * conv_filters),
                nn.LeakyReLU(negative_slope=0.1),
                nn.MaxPool1d(2, stride=2)
            )
            for i in range(downscale_cnn_blocks)
            ])
            # keep the last convolution as a separate layer so its feature maps are easy to access for CAM
            self.last_conv = nn.Conv1d(in_channels=2 ** (downscale_cnn_blocks - 1) * conv_filters,
                          out_channels=2 ** downscale_cnn_blocks * conv_filters,
                          kernel_size=conv_kernel_size,
                          padding='same')
            self.last_batch_norm = nn.BatchNorm1d(2 ** downscale_cnn_blocks * conv_filters)
            self.last_activation = nn.LeakyReLU(negative_slope=0.1)
            self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
            number_of_classes = number_of_classes if number_of_classes > 2 else 1  # binary task: one sigmoid output
            self.pred_layer = nn.Linear(2 ** (downscale_cnn_blocks) * conv_filters, number_of_classes)
            self.sigmoid = nn.Sigmoid()
            
    def forward(self, x_input):
        # build CNN part of network
        x = self.pre_cnn_extractor(x_input)
        for layer in self.cnn_extractor:
            x = layer(x)
        x = self.last_conv(x)
        x = self.last_batch_norm(x)
        # the CAM must be built from the same feature maps that are globally average-pooled
        features = self.last_activation(x)                       # (B, C, T)
        x_classification = self.global_avg_pool(features)
        x_classification = x_classification.view(x_classification.size(0), -1)  # flatten to (B, C)
        x_classification = self.pred_layer(x_classification)
        x_classification = self.sigmoid(x_classification)
        if self.add_cam:
            # weights of the single-logit head, shape (1, C, 1)
            weights = self.pred_layer.weight.detach().unsqueeze(-1)
            features_reshaped = features.permute(0, 2, 1)        # (B, T, C)
            result = torch.matmul(features_reshaped, weights).unsqueeze(0)  # (1, B, T, 1)
            # upsample the CAM along time to the input length (bilinear on a width-1 "image")
            interpolated_result = F.interpolate(result,
                                                size=(x_input.shape[2], 1),
                                                mode='bilinear',
                                                align_corners=False)
            # min-max normalize to [0, 1] over the whole batch (use batch_size=1 for per-record maps)
            cam_min, cam_max = interpolated_result.min(), interpolated_result.max()
            interpolated_result = (interpolated_result - cam_min) / (cam_max - cam_min + 1e-8)
            return x_classification, interpolated_result.squeeze(0).squeeze(-1)  # (B, input_length)
        return x_classification
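
Class activation mapping rests on a simple identity: if the CAM is built from the same feature maps that enter global average pooling, averaging the CAM over time and adding the classifier bias reproduces the logit. The sketch below checks this with random tensors (the shapes, 64 channels by 125 time steps, assume the default configuration with 1000-sample inputs; all `demo_*` names are hypothetical) and shows per-sample upsampling and min-max normalization with `F.interpolate(mode='linear')`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed shapes: batch of 2, 64 channels, 125 time steps (default SignalCNN, 1000-sample input)
demo_features = torch.randn(2, 64, 125)    # feature maps that enter global average pooling
demo_fc = torch.nn.Linear(64, 1)           # single-logit classifier head

with torch.no_grad():
    # Classification path: global average pooling over time, then the linear layer
    demo_logits = demo_fc(demo_features.mean(dim=2))                        # (2, 1)
    # CAM: weight each channel by its classifier weight and sum over channels
    demo_cam = torch.einsum('c,bct->bt', demo_fc.weight[0], demo_features)  # (2, 125)
    # Identity: time-average of the CAM + bias == logit
    identity_holds = torch.allclose(demo_cam.mean(dim=1) + demo_fc.bias, demo_logits[:, 0], atol=1e-5)
    # Upsample each CAM to the input length and min-max normalize per sample
    demo_cam_up = F.interpolate(demo_cam.unsqueeze(1), size=1000, mode='linear',
                                align_corners=False).squeeze(1)             # (2, 1000)
    demo_cam_min = demo_cam_up.min(dim=1, keepdim=True).values
    demo_cam_max = demo_cam_up.max(dim=1, keepdim=True).values
    demo_cam_norm = (demo_cam_up - demo_cam_min) / (demo_cam_max - demo_cam_min + 1e-8)

print(identity_holds, demo_cam_norm.shape)
```

Normalizing per sample keeps each record's map independent of the rest of the batch; with the `batch_size=1` test loaders used for visualization in this notebook, it gives the same result as normalizing over the whole batch.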

In [None]:
def acc_metrics(op, labels):
    # Detach from the computational graph and convert to a flat numpy array
    op_probs = op.detach().cpu().numpy().reshape(-1)
    
    # Convert probabilities to predicted class labels (0 or 1)
    op_labels = (op_probs > 0.5).astype(int)
    
    # Convert labels to a flat integer numpy array if it's a tensor
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    labels = np.asarray(labels).reshape(-1).astype(int)
    
    # Calculate accuracy (scikit-learn expects y_true first, then y_pred)
    acc = accuracy_score(labels, op_labels)
    
    return acc

In [None]:
# PyTorch Dataset wrapping the scaled signals and their multi-hot superclass labels
class PTBLoader(Dataset):
    def __init__(self, signal_array: np.ndarray,
                 labels: np.ndarray,
                 augmentation: ecgmentations.core.compositions.Sequential = None) -> None:
        super().__init__()
        self.data: np.ndarray = signal_array
        self.labels: np.ndarray = labels
        self.augmentation: ecgmentations.core.compositions.Sequential = augmentation
        print(f'Dataset with {self.data.shape[0]} signals and {self.labels.shape[0]} labels')
    
    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
        try:
            ecg = self.data[idx]
            if self.augmentation is not None:
                ecg_ = self.augmentation(ecg=ecg)['ecg']
            else:
                ecg_ = ecg
            # Binary target: NORM-only records -> 0, any other superclass combination -> 1
            label = 0 if np.all(self.labels[idx] == np.array([1, 0, 0, 0, 0])) else 1
            # (samples, leads) -> (leads, samples), the layout expected by Conv1d
            data_final = torch.tensor(np.ascontiguousarray(ecg_.T), dtype=torch.float32)
        except Exception:
            # Re-raise: returning (None, None) would only fail later in the DataLoader's collate step
            print(f'Failed to fetch index {idx}')
            raise
        return data_final, label
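
The binarization rule in `__getitem__` can also be applied to the whole label matrix at once, for example to check class balance before training. A small sketch on a hypothetical label matrix whose columns follow the order used in this notebook (NORM, MI, STTC, CD, HYP):

```python
import numpy as np

# Hypothetical multi-hot labels, columns: NORM, MI, STTC, CD, HYP
demo_labels = np.array([
    [1, 0, 0, 0, 0],   # NORM only      -> 0
    [1, 1, 0, 0, 0],   # NORM + MI      -> 1
    [0, 0, 1, 0, 0],   # STTC           -> 1
    [0, 0, 0, 0, 0],   # no superclass  -> 1 (same outcome as the rule in PTBLoader)
])

# A row is "normal" only if it matches the NORM-only pattern exactly
norm_only_pattern = np.array([1, 0, 0, 0, 0])
demo_binary = (~np.all(demo_labels == norm_only_pattern, axis=1)).astype(int)
print(demo_binary)               # [0 1 1 1]
print(np.bincount(demo_binary))  # counts per class: [1 3]
```

Note that a record without any superclass falls on the "diseased" side under this rule.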

In [None]:
# number of training epochs
epochs = 25
batch_size = 128
# training (augmented) and validation datasets and loaders
train_dataset = PTBLoader(Y_train_scaled, Z_train.to_numpy(), transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = PTBLoader(Y_valid_scaled, Z_valid.to_numpy())
valid_loader = DataLoader(valid_dataset, batch_size=batch_size)

In [None]:
# output path for model
signal_model_output_path = 'model_0.pt'
# create neural model
ecg_leads_count = 12
model = SignalCNN(ecg_leads_count, conv_filters=16, number_of_classes=2)

In [None]:
%%time
# Optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0005)
# Learning rate scheduling: halve the LR when validation accuracy plateaus.
# mode='max' because the monitored metric (accuracy) should increase.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='max',
                                                       factor=0.5,
                                                       patience=8,
                                                       threshold=0.0005,
                                                       threshold_mode='rel',
                                                       cooldown=0,
                                                       min_lr=0,
                                                       eps=1e-08)
# Loss function (the model already applies a sigmoid, so plain BCE is used)
loss = torch.nn.BCELoss()
# best validation accuracy seen so far
best_accuracy = 0.0
best_accuracy_epoch = -1

statistics = {'train_loss': [],
              'train_acc': [],
              'test_loss': [],
              'test_acc': [],
              'epoch': [],
              'lr': [],
              'best_acc': 0.0,
              'best_acc_epoch': -1}
    
for epoch_index in range(epochs):
    # ---- training stage ----
    model.train()  # BatchNorm uses batch statistics and updates its running estimates
    total_acc = 0.0
    total_loss = 0.0
    total_acc_avg = 0.0
    total_loss_avg = 0.0
    for idx, (signal, label) in enumerate(train_loader):
        optimizer.zero_grad()
        # view(-1) keeps a 1-D batch even if the last mini-batch holds a single sample
        output = model(signal).view(-1)
        loss_ = loss(output, label.float())
        loss_.backward()
        optimizer.step()
        acc_score = acc_metrics(output, label)
        # accumulate accuracy and loss, then update the running averages
        total_acc += acc_score
        total_loss += loss_.item()
        total_acc_avg = total_acc / (idx + 1)
        total_loss_avg = total_loss / (idx + 1)
        print('[Training] Epoch: {}/{} Mini Batch: {}/{} Accuracy: {:.4f}, Loss: {:.4f}'.format(
            epoch_index + 1, epochs, idx + 1, len(train_loader), total_acc_avg, total_loss_avg), end='\r')
    # ---- validation stage ----
    model.eval()  # BatchNorm uses its running statistics and does not update them
    total_acc = 0.0
    total_loss = 0.0
    total_val_acc_avg = 0.0
    total_val_loss_avg = 0.0
    with torch.no_grad():
        for idx, (signal, label) in enumerate(valid_loader):
            output = model(signal).view(-1)
            acc_score = acc_metrics(output, label)
            loss_ = loss(output, label.float())
            # accumulate accuracy and loss, then update the running averages
            total_acc += acc_score
            total_loss += loss_.item()
            total_val_acc_avg = total_acc / (idx + 1)
            total_val_loss_avg = total_loss / (idx + 1)
            print('[Validation] Epoch: {}/{} Mini Batch: {}/{} Accuracy: {:.4f}, Loss: {:.4f}'.format(
                epoch_index + 1, epochs, idx + 1, len(valid_loader), total_val_acc_avg, total_val_loss_avg), end='\r')
    current_accuracy = total_acc / len(valid_loader)
    # let the scheduler reduce the learning rate when validation accuracy stops improving
    scheduler.step(current_accuracy)
    if current_accuracy > best_accuracy:
        print(f'New best weights found with accuracy {current_accuracy:.4f} at epoch {epoch_index}. Saving to {signal_model_output_path}...')
        torch.save(model.state_dict(), signal_model_output_path)
        best_accuracy = current_accuracy
        best_accuracy_epoch = epoch_index
    lr = optimizer.param_groups[0]['lr']
    print(f'Epoch {epoch_index} learning rate: {lr}')
    # save per-epoch statistics
    statistics['train_loss'].append(total_loss_avg)
    statistics['train_acc'].append(total_acc_avg)
    statistics['test_loss'].append(total_val_loss_avg)
    statistics['test_acc'].append(total_val_acc_avg)
    statistics['epoch'].append(epoch_index)
    statistics['lr'].append(lr)
    statistics['best_acc'] = best_accuracy
    statistics['best_acc_epoch'] = best_accuracy_epoch

In [None]:
from matplotlib.ticker import ScalarFormatter

def plot_train_metrics(metrics):
    epochs = range(len(metrics['train_loss']))
    best_acc_epoch = metrics['best_acc_epoch']
    best_acc = metrics['best_acc']
    
    plt.figure(figsize=(12, 8))
    
    # First subplot
    ax1 = plt.subplot(2, 1, 1)
    ax1.set_facecolor('white')
    ax1.spines['top'].set_color('gray')
    ax1.spines['right'].set_color('gray')
    ax1.spines['bottom'].set_color('gray')
    ax1.spines['left'].set_color('gray')
    ax1.grid(True, color='gray', alpha=0.65, linewidth=0.5, linestyle='dashed')
    ax1.set_xticks(np.arange(0, len(metrics['train_loss']) + 1, len(metrics['train_loss']) / 10).astype(int))
    
    ax1.plot(epochs, metrics['train_loss'], 'b--', label='Training loss')
    ax1.plot(epochs, metrics['test_loss'], 'r--', label='Validation loss')
    ax1.plot(epochs, metrics['train_acc'], 'b-', label='Training accuracy')
    ax1.plot(epochs, metrics['test_acc'], 'r-', label='Validation accuracy')
    ax1.plot(best_acc_epoch, best_acc, 'g*', markersize=10, label=f'Best accuracy: {best_acc:.4f}')  # Mark best accuracy with a star
    ax1.axvline(x=best_acc_epoch, color='gray', linestyle='--')  # Vertical line at best accuracy epoch
    ax1.set_title('Training and Validation Metrics')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Metrics')
    ax1.legend()
    
    # Annotate best accuracy
    #ax1.annotate(f'Best accuracy: {best_acc:.4f}', xy=(best_acc_epoch, best_acc), xytext=(best_acc_epoch+5, best_acc-0.02),
    #             arrowprops=dict(facecolor='black', arrowstyle='->'))
    
    # Second subplot
    ax2 = plt.subplot(2, 1, 2)
    ax2.set_facecolor('white')
    ax2.spines['top'].set_color('gray')
    ax2.spines['right'].set_color('gray')
    ax2.spines['bottom'].set_color('gray')
    ax2.spines['left'].set_color('gray')
    ax2.grid(True, color='gray', alpha=0.65, linewidth=0.5, linestyle='dashed')
    ax2.set_xticks(np.arange(0, len(metrics['train_loss']) + 1, len(metrics['train_loss']) / 10).astype(int))
    
    ax2.plot(epochs, metrics['lr'], 'g-', label='Learning rate')
    ax2.set_title('Learning Rate Scheduling')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Learning Rate')
    ax2.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
    ax2.legend()  # Add legend to the second subplot
    
    plt.tight_layout()
    plt.show()

In [None]:
plot_train_metrics(statistics)
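
The training cell passes validation accuracy to `scheduler.step`. Because accuracy should increase, `ReduceLROnPlateau` needs `mode='max'`; with `mode='min'`, every accuracy gain counts as a non-improvement and the learning rate is cut while the model is still getting better. The sketch below uses a dummy parameter and optimizer (hypothetical `demo_*` names) and `patience=2` instead of the notebook's 8, to show when the halving happens:

```python
import torch

# Dummy parameter and optimizer standing in for the real model (illustration only)
demo_param = torch.zeros(1, requires_grad=True)
demo_optimizer = torch.optim.Adam([demo_param], lr=0.0005)
# patience=2 instead of the notebook's 8, to keep the example short
demo_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(demo_optimizer, mode='max',
                                                            factor=0.5, patience=2)

demo_lrs = []
for val_acc in [0.70, 0.70, 0.70, 0.70, 0.70]:   # accuracy stops improving after the first epoch
    demo_scheduler.step(val_acc)
    demo_lrs.append(demo_optimizer.param_groups[0]['lr'])

# The LR stays at 0.0005 for the first three steps and is halved on the fourth,
# once the number of non-improving epochs exceeds the patience.
print(demo_lrs)
```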

# Validation

In [None]:
test_dataset = PTBLoader(Y_test_scaled, Z_test.to_numpy())
test_loader = DataLoader(test_dataset, batch_size=1)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
def plot_jet_colormap(data):
    """
    Plot data using the jet colormap.
    
    Parameters:
        data (numpy.ndarray): Input data array in the range [0, 1].
    """
    # Create a figure and axis
    fig, ax = plt.subplots(figsize=(8, 1))

    # Plot the data using the jet colormap
    ax.imshow(data.reshape(1, -1), cmap='jet', aspect='auto')

    # Remove y-axis ticks and labels
    ax.set_yticks([])
    ax.set_yticklabels([])

    # Set x-axis ticks and labels
    ax.set_xticks([0, len(data) // 2, len(data) - 1])
    ax.set_xticklabels(['0', str(len(data) // 2), str(len(data) - 1)])

    # Set labels and title
    plt.xlabel('Data Points')
    plt.title('Jet Colormap')

    # Show the plot
    plt.show()

In [None]:
# Re-create the model with CAM output enabled
ecg_leads_count = 12
model = SignalCNN(ecg_leads_count, conv_filters=16, number_of_classes=2, add_cam=True)

# Load the best weights saved during training (mapped to CPU)
model_path = 'model_0.pt'
model.load_state_dict(torch.load(model_path, map_location='cpu'))

# Put the model in evaluation mode
model.eval()

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score

def plot_confusion_matrix(y_true_list, y_pred_list, classes, title):
    # Build confusion matrix
    cf_matrix = confusion_matrix(y_true_list, y_pred_list)
    
    # Normalize confusion matrix
    normalized_cf_matrix = cf_matrix / np.sum(cf_matrix, axis=1)[:, None]

    # Create DataFrame
    df_cm = pd.DataFrame(normalized_cf_matrix, index=classes, columns=classes)

    # Set up figure size
    plt.figure(figsize=(10, 6))

    # Draw the heatmap without built-in annotations; custom text is added below
    heatmap = sns.heatmap(df_cm, annot=False, cmap='YlGnBu',
                          cbar_kws={'shrink': 1.0})

    # Increase font size of axis labels and tick labels
    heatmap.set_xticklabels(heatmap.get_xticklabels(), fontsize=25)
    heatmap.set_yticklabels(heatmap.get_yticklabels(), fontsize=25)
    
    # Increase font size of colorbar tick labels
    cbar = heatmap.collections[0].colorbar
    cbar.ax.tick_params(labelsize=26)

    # Calculate accuracy
    accuracy = accuracy_score(y_true_list, y_pred_list)
    accuracy_percent = accuracy * 100.0

    # Add title with increased font size
    plt.title(f'{title} (Accuracy={accuracy_percent:.2f}%)', fontsize=26)

    # Annotate each cell with the row-normalized percentage and the raw count
    for i in range(len(classes)):
        for j in range(len(classes)):
            count = cf_matrix[i, j]
            value = df_cm.iloc[i, j]
            text = f'{value*100:.2f}%\n({count})'

            # Use a contrasting text color based on the cell value
            if value > 0.5:
                text_color = 'white'        # Light text for higher values
            else:
                text_color = 'black'        # Dark text for lower values

            # Add the text annotation (bbox alpha=0.0, i.e. no visible background)
            heatmap.text(j + 0.5, i + 0.5, text, ha='center', va='center', 
                         fontsize=34, color=text_color,
                         bbox=dict(facecolor='white', alpha=0.0, edgecolor='none', boxstyle='round,pad=0.1'))

    # Display plot
    plt.show()
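
`plot_confusion_matrix` divides each row by the number of records of that true class, so each diagonal entry is the per-class recall (specificity for Normal, sensitivity for Diseased). A toy example with made-up labels (hypothetical `demo_*` names), checked against scikit-learn:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, recall_score

# Made-up labels: 0 = Normal, 1 = Diseased
demo_y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
demo_y_pred = [0, 0, 0, 1, 1, 1, 1, 1, 0, 0]

demo_cm = confusion_matrix(demo_y_true, demo_y_pred)      # rows: true class, columns: predicted
# Row normalization, as in plot_confusion_matrix
demo_cm_norm = demo_cm / demo_cm.sum(axis=1, keepdims=True)
# The diagonal matches per-class recall
demo_per_class_recall = recall_score(demo_y_true, demo_y_pred, average=None)

print(demo_cm)        # [[3 1]
                      #  [2 4]]
print(np.diag(demo_cm_norm), demo_per_class_recall)
```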


In [None]:
def plot_jet_colormap(data, ax, y_min, y_max):
    """
    Overlay a CAM as a jet-colormap strip on a given axis.
    
    Parameters:
        data (numpy.ndarray): CAM of shape (1, N) with values in [0, 1].
        ax (matplotlib.axes.Axes): Axis to plot the colormap on.
        y_min, y_max (float): Vertical extent of the overlay in data coordinates.
    """
    # Plot the data using the jet colormap
    ax.imshow(data.reshape(1, -1), cmap='jet', aspect='auto', alpha=0.4, extent=[0, data.shape[1], y_min, y_max])
    
def plot_ecg_high(ecg: np.ndarray, title: str = 'ECG signal', data_1000: np.ndarray = None):
    if data_1000 is None:
        raise ValueError("Input 'data_1000' is required.")
    
    if ecg.shape != (1000, 12):
        raise ValueError("Input 'ecg' must be of shape (1000, 12)")

    # PTB-XL lead order
    lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    
    fig, axs = plt.subplots(12, 1, figsize=(14, 15))
    for i, ax in enumerate(axs.flat):
        # Overlay the ECG signal on top of the colormap
        ax.plot(ecg[:,i], alpha=0.95, linewidth=0.95, color='black')
        
        ax.set_facecolor('white')
        ax.spines['top'].set_color('gray')
        ax.spines['right'].set_color('gray')
        ax.spines['bottom'].set_color('gray')
        ax.spines['left'].set_color('gray')
        
        ax.grid(True, color='gray', alpha=0.65, linewidth=0.5, linestyle='dashed')
        ax.set_xticks(np.arange(0, ecg.shape[0] + 1, ecg.shape[0] / 10))  # Add more ticks
        if i != ecg.shape[1] - 1:
            ax.set_xticklabels([])  # keep x-axis labels only on the bottom lead
        ylim = ax.get_ylim()

        # Compute min and max
        y_min, y_max = np.min(ecg[:, i]), np.max(ecg[:, i])
        ax.set_ylim([y_min, y_max])  # Set y-axis limits

        # Move lead annotations to the right
        ax.text(ecg.shape[0] * 1.025, (y_max + y_min) / 2, lead_names[i], 
                fontsize=15, color='black', ha='center', va='center')
        
        # Invert y-axis labels
        #ax.set_yticklabels(reversed(ax.get_yticklabels()))
        # Plot the colormap
        plot_jet_colormap(data_1000, ax, ylim[0], ylim[1])
        # Increase y-axis tick size
        ax.tick_params(axis='y', labelsize=16)  # Adjust the labelsize as needed
        ax.tick_params(axis='x', labelsize=16) 
            
    fig.text(0.5, 0.00, 'Time (in samples)', ha='center', va='center', fontsize=16)
    fig.text(0.00, 0.5, 'Normalized scale', ha='center', va='center', rotation='vertical', fontsize=16)
    plt.suptitle(title, fontsize=20, y=0.98)
    plt.tight_layout()
    plt.show()

In [None]:
sample_to_visualize = 25

# General Normal vs Abnormal

In [None]:
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, classification_report, confusion_matrix, f1_score, precision_score, recall_score

# Evaluate
y_pred = []
y_true = []
y_prob = []  # Store probabilities for ROC curve
threshold = 0.5
counter = 0

for idx, data in tqdm(enumerate(test_loader), total=len(test_loader)):
    signal, label = data
    
    with torch.no_grad():
        output, cam = model(signal)  # forward pass without building a graph
    output = output.squeeze()
    
    op_prob = torch.detach(output).numpy().item()  # Convert tensor to numpy
    y_prob.append(op_prob)  # Store probability

    # Convert probabilities to predicted class labels (0 or 1)
    op_label = 1 if op_prob > threshold else 0
    y_pred.append(op_label)

    label_ = label.data.cpu().numpy().item()
    y_true.append(label_)  # Save Truth

    # View abnormal ECGs
    if op_label == 1 and label_ == 1 and counter < sample_to_visualize:
        cam_numpy = cam.detach().numpy()
        signal_numpy = signal.detach().numpy().squeeze(0)
        plot_ecg_high(signal_numpy.T, title='Abnormal', data_1000=cam_numpy)
        counter += 1

# Calculate confusion matrix
conf_matrix = confusion_matrix(y_true, y_pred)

# Precision, Recall, and F1 Score calculations
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

# Print the results
print("Confusion Matrix:")
print(conf_matrix)
print("\nClassification Report:")
print(classification_report(y_true, y_pred))

print(f"\nPrecision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

# Define classes
classes = ('Normal', 'Diseased')
title = f'Normal and Diseased'
plot_confusion_matrix(y_true, y_pred, classes, title)

# Calculate and plot ROC curve, AUC
fpr, tpr, _ = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)

print(f"\nAUC: {roc_auc:.4f}")
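
The precision, recall and F1 scores printed above follow directly from the four confusion-matrix counts, with Diseased (label 1) as the positive class. A small sketch on made-up labels (hypothetical `demo_*` names), checked against scikit-learn:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# Made-up labels: 0 = Normal, 1 = Diseased
demo_true = np.array([1, 1, 1, 0, 0, 1, 0, 1])
demo_pred = np.array([1, 0, 1, 0, 1, 1, 0, 1])

# For binary labels, ravel() returns the counts in the order TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(demo_true, demo_pred).ravel()

demo_precision = tp / (tp + fp)      # share of "Diseased" predictions that are correct
demo_recall = tp / (tp + fn)         # share of diseased records that are detected (sensitivity)
demo_specificity = tn / (tn + fp)    # share of normal records recognized as normal
demo_f1 = 2 * demo_precision * demo_recall / (demo_precision + demo_recall)

print(tn, fp, fn, tp)  # 2 1 1 4
print(demo_precision, demo_recall, demo_specificity, demo_f1)
```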

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

# Compute ROC curve and AUC
fpr, tpr, _ = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)

# Create the ROC plot
plt.figure(figsize=(8, 6))

# ROC curve
plt.plot(fpr, tpr, color='darkorchid', linewidth=2, label=f'ROC curve (AUC = {roc_auc:.3f})')

# Transparent fill under the ROC curve
plt.fill_between(fpr, tpr, alpha=0.05, color='darkorchid')

# Diagonal line
plt.plot([0, 1], [0, 1], color='gray', linestyle='--', linewidth=2)  

# Titles and labels
plt.xlabel('False Positive Rate', fontsize=19)
plt.ylabel('True Positive Rate', fontsize=19)
plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=24)

# Stronger grid visibility
plt.grid(True, linestyle='--', linewidth=0.7, alpha=0.8, color='gray')

# Ensure all axes and borders are visible
plt.tick_params(axis='both', which='major', labelsize=19)
plt.gca().spines['top'].set_visible(True)
plt.gca().spines['right'].set_visible(True)
plt.gca().spines['left'].set_visible(True)
plt.gca().spines['bottom'].set_visible(True)

# Legend settings
plt.legend(loc='lower right', fontsize=18, frameon=True)

# White background
plt.gca().set_facecolor('white')
plt.gcf().set_facecolor('white')

# Show the plot
plt.show()
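
ROC AUC equals the probability that a randomly chosen positive record receives a higher score than a randomly chosen negative one, with ties counted as one half (the Mann-Whitney U statistic divided by the number of positive/negative pairs). The hypothetical `auc_rank` helper below computes it by brute-force pairwise comparison, which makes a handy sanity check for `auc(fpr, tpr)`; memory grows with (#positives x #negatives), which is acceptable for a few thousand records.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def auc_rank(y_true, y_score):
    """AUC as P(score of a random positive > score of a random negative), ties count 1/2."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Compare every positive with every negative via broadcasting
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

# Made-up scores: 7 of the 9 positive/negative pairs are ranked correctly
demo_true = [0, 0, 1, 1, 0, 1]
demo_score = [0.1, 0.4, 0.35, 0.8, 0.4, 0.9]
print(auc_rank(demo_true, demo_score), roc_auc_score(demo_true, demo_score))
```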


# Normal vs Myocardial Infarction

In [None]:
# pick Myocardial Infarction and normal
labels = Z_test.to_numpy()
conditions = (
    ((labels[:, 0] == 1) & (labels[:, 1:] == 0).all(axis=1)) |  
    (labels[:, 1] == 1) 
)
#conditions = ((labels == [1, 0, 0, 0, 0]).all(axis=1)) | ((labels == [0, 1, 0, 0, 0]).all(axis=1))
indices = np.where(conditions)[0]
Z_test_ = labels[indices]
Y_test_ = Y_test_scaled[indices]
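
Each subset cell keeps records that are NORM-only or that carry the selected superclass, so the "diseased" side may also contain other superclasses (for example MI + STTC). The hypothetical helper below expresses the same mask for any column (1 = MI, 2 = STTC, 3 = CD, 4 = HYP, following the selections in these cells):

```python
import numpy as np

# Hypothetical label matrix, columns: NORM, MI, STTC, CD, HYP
demo_labels = np.array([
    [1, 0, 0, 0, 0],  # NORM only     -> kept as label 0
    [0, 1, 0, 0, 0],  # MI            -> kept as label 1 when class_col = 1
    [0, 1, 1, 0, 0],  # MI + STTC     -> kept as label 1 for class_col = 1 or 2
    [0, 0, 1, 0, 0],  # STTC only     -> kept as label 1 when class_col = 2
    [1, 0, 0, 1, 0],  # NORM + CD     -> not NORM-only, kept only when class_col = 3
])

def norm_vs_class_indices(labels: np.ndarray, class_col: int) -> np.ndarray:
    """Row indices that are NORM-only or contain the given superclass column."""
    norm_only = (labels[:, 0] == 1) & (labels[:, 1:] == 0).all(axis=1)
    has_class = labels[:, class_col] == 1
    return np.where(norm_only | has_class)[0]

print(norm_vs_class_indices(demo_labels, 1))  # [0 1 2]
print(norm_vs_class_indices(demo_labels, 2))  # [0 2 3]
```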

In [None]:
test_dataset = PTBLoader(Y_test_, Z_test_)
test_loader = DataLoader(test_dataset, batch_size=1)
# evaluate
y_pred = []
y_true = []
threshold = 0.5
counter = 0
for idx, data in tqdm(enumerate(test_loader), total=len(test_loader)):
    signal, label = data
    with torch.no_grad():
        output, cam = model(signal)  # forward pass without building a graph
    output = output.squeeze()
    op_prob = torch.detach(output).numpy().item()
    # Convert probabilities to predicted class labels (0 or 1)
    op_label = 1 if op_prob > threshold else 0
    y_pred.append(op_label)
    label_ = label.data.cpu().numpy().item()
    y_true.append(label_) # Save Truth
    # visualize a few correctly detected abnormal ECGs together with their CAM
    if op_label == 1 and label_ == 1 and counter < sample_to_visualize:
        cam_numpy = cam.detach().numpy()
        #plot_jet_colormap(cam_numpy)
        signal_numpy = signal.detach().numpy().squeeze(0)
        plot_ecg_high(signal_numpy.T, title = 'Myocardial Infarction', data_1000=cam_numpy)
        counter += 1
# Define classes
classes = ('Norm', 'Myocardial\n Infarction')
title = f'Confusion Matrix - Norm and Myocardial Infarction\n'
plot_confusion_matrix(y_true, y_pred, classes, title)
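
The subset evaluations (this one and the three that follow) repeat the same prediction loop. A helper such as the hypothetical `collect_predictions` below could gather labels, predictions and probabilities in one call; it leaves out the CAM plotting and switches the model to evaluation mode, in which the classification output does not depend on the batch size. The CAM overlays still need `batch_size=1`, because the model normalizes the CAM over the whole batch. The usage part relies on a stand-in model, not `SignalCNN`:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def collect_predictions(model, loader, threshold=0.5):
    """Hypothetical helper: return (y_true, y_pred, y_prob) arrays for a binary sigmoid model."""
    y_true_, y_pred_, y_prob_ = [], [], []
    model.eval()
    with torch.no_grad():
        for signal, label in loader:
            output = model(signal)
            if isinstance(output, tuple):   # the CAM-enabled model returns (probabilities, cam)
                output = output[0]
            probs = output.view(-1).cpu().numpy()
            y_prob_.extend(probs.tolist())
            y_pred_.extend((probs > threshold).astype(int).tolist())
            y_true_.extend(label.view(-1).cpu().numpy().tolist())
    return np.array(y_true_), np.array(y_pred_), np.array(y_prob_)

# Usage with a stand-in model (NOT SignalCNN): sigmoid of the mean signal value, plus a dummy CAM
class DummyModel(torch.nn.Module):
    def forward(self, x):
        prob = torch.sigmoid(x.mean(dim=(1, 2))).unsqueeze(1)    # (B, 1)
        return prob, torch.zeros(x.shape[0], x.shape[2])         # (B, 1), (B, T)

demo_signals = torch.cat([-torch.ones(3, 12, 1000), torch.ones(3, 12, 1000)])
demo_targets = torch.tensor([0, 0, 0, 1, 1, 1])
demo_loader = DataLoader(TensorDataset(demo_signals, demo_targets), batch_size=4)
demo_t, demo_p, demo_prob = collect_predictions(DummyModel(), demo_loader)
print(demo_t, demo_p)
```

With the real model, `y_true, y_pred, y_prob = collect_predictions(model, test_loader)` followed by `plot_confusion_matrix(y_true, y_pred, classes, title)` would reproduce the confusion matrices without the per-sample loop.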

# Normal vs ST/T Change

In [None]:
# pick ST/T Change and normal
labels = Z_test.to_numpy()
#conditions = ((labels == [1, 0, 0, 0, 0]).all(axis=1)) | ((labels == [0, 0, 1, 0, 0]).all(axis=1))
conditions = (
    ((labels[:, 0] == 1) & (labels[:, 1:] == 0).all(axis=1)) |  
    (labels[:, 2] == 1) 
)
indices = np.where(conditions)[0]
Z_test_ = labels[indices]
Y_test_ = Y_test_scaled[indices]

In [None]:
test_dataset = PTBLoader(Y_test_, Z_test_)
test_loader = DataLoader(test_dataset, batch_size=1)
# evaluate
y_pred = []
y_true = []
threshold = 0.5
counter = 0
for idx, data in tqdm(enumerate(test_loader), total=len(test_loader)):
    signal, label = data
    with torch.no_grad():
        output, cam = model(signal)  # forward pass without building a graph
    output = output.squeeze()
    op_prob = torch.detach(output).numpy().item()
    # Convert probabilities to predicted class labels (0 or 1)
    op_label = 1 if op_prob > threshold else 0
    y_pred.append(op_label)
    label_ = label.data.cpu().numpy().item()
    y_true.append(label_) # Save Truth
    # visualize a few correctly detected abnormal ECGs together with their CAM
    if op_label == 1 and label_ == 1 and counter < sample_to_visualize:
        cam_numpy = cam.detach().numpy()
        #plot_jet_colormap(cam_numpy)
        signal_numpy = signal.detach().numpy().squeeze(0)
        plot_ecg_high(signal_numpy.T, title = 'ST/T Change', data_1000=cam_numpy)
        counter += 1
# Define classes
classes = ('Norm', 'ST/T Change')
title = f'Confusion Matrix - Norm and ST/T Change\n'
plot_confusion_matrix(y_true, y_pred, classes, title)

# Normal vs Conduction Disturbance

In [None]:
# pick Conduction Disturbance and normal
labels = Z_test.to_numpy()
## Check if the first and fourth positions match [1, 0, 0, 0, 0] or [0, 0, 0, 1, 0]
# conditions = ((labels == [1, 0, 0, 0, 0]).all(axis=1)) | ((labels == [0, 0, 0, 1, 0]).all(axis=1))
conditions = (
    ((labels[:, 0] == 1) & (labels[:, 1:] == 0).all(axis=1)) |  
    (labels[:, 3] == 1) 
)
indices = np.where(conditions)[0]
Z_test_ = labels[indices]
Y_test_ = Y_test_scaled[indices]

In [None]:
test_dataset = PTBLoader(Y_test_, Z_test_)
test_loader = DataLoader(test_dataset, batch_size=1)
# evaluate
y_pred = []
y_true = []
threshold = 0.5
counter = 0
for idx, data in tqdm(enumerate(test_loader), total=len(test_loader)):
    signal, label = data
    with torch.no_grad():
        output, cam = model(signal)  # forward pass without building a graph
    output = output.squeeze()
    op_prob = torch.detach(output).numpy().item()
    # Convert probabilities to predicted class labels (0 or 1)
    op_label = 1 if op_prob > threshold else 0
    y_pred.append(op_label)
    label_ = label.data.cpu().numpy().item()
    y_true.append(label_) # Save Truth
    # visualize a few correctly detected abnormal ECGs together with their CAM
    if op_label == 1 and label_ == 1 and counter < sample_to_visualize:
        cam_numpy = cam.detach().numpy()
        #plot_jet_colormap(cam_numpy)
        signal_numpy = signal.detach().numpy().squeeze(0)
        plot_ecg_high(signal_numpy.T, title = 'Conduction Disturbance', data_1000=cam_numpy)
        counter += 1
# Define classes
classes = ('Norm', 'Conduction\n Disturbance')
title = f'Confusion Matrix - Norm and Conduction Disturbance\n'
plot_confusion_matrix(y_true, y_pred, classes, title)

# Normal vs Hypertrophy

In [None]:
# pick Hypertrophy and normal
labels = Z_test.to_numpy()
#conditions = ((labels == [1, 0, 0, 0, 0]).all(axis=1)) | ((labels == [0, 0, 0, 0, 1]).all(axis=1))
conditions = (
    ((labels[:, 0] == 1) & (labels[:, 1:] == 0).all(axis=1)) |  
    (labels[:, 4] == 1) 
)
indices = np.where(conditions)[0]
Z_test_ = labels[indices]
Y_test_ = Y_test_scaled[indices]

In [None]:
test_dataset = PTBLoader(Y_test_, Z_test_)
test_loader = DataLoader(test_dataset, batch_size=1)
# evaluate
y_pred = []
y_true = []
threshold = 0.5
counter = 0
for idx, data in tqdm(enumerate(test_loader), total=len(test_loader)):
    signal, label = data
    with torch.no_grad():
        output, cam = model(signal)  # forward pass without building a graph
    output = output.squeeze()
    op_prob = torch.detach(output).numpy().item()
    # Convert probabilities to predicted class labels (0 or 1)
    op_label = 1 if op_prob > threshold else 0
    y_pred.append(op_label)
    label_ = label.data.cpu().numpy().item()
    y_true.append(label_) # Save Truth
    # visualize a few correctly detected abnormal ECGs together with their CAM
    if op_label == 1 and label_ == 1 and counter < sample_to_visualize:
        cam_numpy = cam.detach().numpy()
        #plot_jet_colormap(cam_numpy)
        signal_numpy = signal.detach().numpy().squeeze(0)
        plot_ecg_high(signal_numpy.T, title = 'Hypertrophy', data_1000=cam_numpy)
        counter += 1
# Define classes
classes = ('Norm', 'Hypertrophy')
title = f'Confusion Matrix - Norm and Hypertrophy\n'
plot_confusion_matrix(y_true, y_pred, classes, title)

# Jet colormap

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Create a modified 'jet' colormap with alpha=0.4
jet = plt.get_cmap('jet', 256)  # pyplot.get_cmap; matplotlib.cm.get_cmap is deprecated/removed in recent releases
jet_colors = jet(np.linspace(0, 1, 256))
jet_colors[:, -1] = 0.4  # Set alpha to 0.4
jet_alpha = ListedColormap(jet_colors)

# Create a figure
fig, ax = plt.subplots(figsize=(12, 3))  # Adjust height for better proportions

# Create a dummy image to generate the colorbar
dummy_data = np.linspace(0, 1, 100).reshape(1, -1)
img = ax.imshow(dummy_data, cmap=jet_alpha, aspect="auto", visible=False)

# Create the colorbar
cbar = fig.colorbar(img, ax=ax, orientation="horizontal", aspect=30, pad=0.2)
cbar.ax.xaxis.set_ticks_position('top')  # Move ticks to the top
cbar.ax.tick_params(labelsize=14)  # Increase tick size

# Set custom tick labels at 0, 0.5, and 1
cbar.set_ticks([0, 0.5, 1])
cbar.set_ticklabels(["0", "0.5", "1"])

# Manually add "Intensity" label to the right
cbar.ax.text(1.02, 0.5, "Intensity", fontsize=16, fontweight="bold",
             transform=cbar.ax.transAxes, va="center")

# Remove the main axis
ax.remove()

plt.show()


# Model FLOPs comparison (both models in evaluation mode)

In [None]:
from fvcore.nn import FlopCountAnalysis
import torch
from torchsummary import summary

input_data = torch.randn(1, 12, 1000)
ecg_leads_count = 12
classification_model = SignalCNN(ecg_leads_count, conv_filters=16, number_of_classes=2, add_cam=False)
classification_model.eval()

# Get parameter count and FLOPs for model without CAM
classification_model_params = sum(p.numel() for p in classification_model.parameters())
print(f'Classification model parameters: {classification_model_params}')

classification_flops = FlopCountAnalysis(classification_model, input_data)
classification_model_flops = classification_flops.total()
print(f'Classification model FLOPs: {classification_model_flops}')

# Model with CAM
classification_cam_model = SignalCNN(ecg_leads_count, conv_filters=16, number_of_classes=2, add_cam=True)
classification_cam_model.eval()

classification_cam_model_params = sum(p.numel() for p in classification_cam_model.parameters())
print(f'Classification model with CAM parameters: {classification_cam_model_params}')

classification_cam_flops = FlopCountAnalysis(classification_cam_model, input_data)
classification_cam_model_flops = classification_cam_flops.total()
print(f'Classification model with CAM FLOPs: {classification_cam_model_flops}')
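
As a cross-check, the multiply-accumulate (MAC) count of the convolutional and linear layers can be derived by hand from the architecture ('same' padding, `MaxPool1d(2)` after the initial convolution and after each downscaling block). According to its documentation, fvcore counts one fused multiply-add as one flop, so its totals should be close to this estimate, but not necessarily equal: BatchNorm, pooling, activations and interpolation are counted or skipped depending on fvcore's operator handlers. The helper names below are hypothetical; the defaults mirror `SignalCNN(12, conv_filters=16, number_of_classes=2)` with 1000-sample inputs.

```python
def conv1d_macs(c_in: int, c_out: int, kernel: int, length_out: int) -> int:
    """Multiply-accumulate operations of a Conv1d layer (bias ignored)."""
    return c_out * length_out * c_in * kernel

def signal_cnn_macs(leads=12, length=1000, filters=16, blocks=2, k0=11, k=7, classes=1):
    """Conv/linear MACs of the SignalCNN layout; also returns the CAM length and channel count."""
    total = conv1d_macs(leads, filters, k0, length)     # initial convolution at full length
    length //= 2                                         # max-pooling halves the sequence
    c_in = filters
    for i in range(blocks):
        c_out = 2 ** i * filters
        total += conv1d_macs(c_in, c_out, k, length)
        length //= 2
        c_in = c_out
    c_last = 2 ** blocks * filters
    total += conv1d_macs(c_in, c_last, k, length)       # last convolution (CAM feature maps)
    total += c_last * classes                            # linear head after global average pooling
    return total, length, c_last

demo_macs, demo_cam_length, demo_cam_channels = signal_cnn_macs()
demo_cam_extra = demo_cam_length * demo_cam_channels    # channel-weighted sum that builds the CAM
print(demo_macs, demo_cam_extra)  # 5696064 8000
```

The weighted sum that forms the CAM adds only 8,000 MACs on top of roughly 5.7 million; upsampling and normalization add work proportional to the 1000-sample output.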


In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Data for the bar plot
models = ['Classification', 'Classification +\nCAM Explainability']
flops = [classification_model_flops, classification_cam_model_flops]

# Create the horizontal bar plot with bars having minimal gap
plt.figure(figsize=(6, 2.5), facecolor='white')  # Slightly smaller figure to reduce space
plasma = cm.plasma
bars = plt.barh(models, flops, color=[plasma(0.1, alpha=0.8), plasma(0.5, alpha=0.8)], 
                height=0.6, edgecolor='black', linewidth=1)  # Height adjusted for minimal gap
plt.gca().set_facecolor('white')  # Ensures the axes background is white

# Remove y-axis labels (since we put them inside bars)
plt.yticks([])

# Add labels and title
plt.xlabel('FLOPs', fontsize=18)  # floating-point operation count (not per second)
#plt.title('Comparison of Floating-Point\nOperations (FLOPs)', fontsize=20)  # Increased font size for title
plt.xticks(fontsize=14)  # Increased font size for x-axis ticks
plt.xlim(0, max(flops) * 1.2)  # Adjust x-limit for better visualization
plt.grid(axis='x', linestyle='--', color='gray', linewidth=0.8, alpha=1.0)

# Display model names inside the bars
for i, bar in enumerate(bars):
    plt.text(bar.get_width() * 0.05,  # Position text slightly inside the bar from left
             bar.get_y() + bar.get_height() / 2,  
             models[i],  # Model name only
             va='center', ha='left',  
             fontsize=17, fontweight='bold', color='white')

# Display rotated FLOPS values on top of the bars
for i, bar in enumerate(bars):
    plt.text(bar.get_width() + (max(flops) * 0.02),  # Position slightly to the right of the bar
             bar.get_y() + bar.get_height() / 2,  
             f"{flops[i]}",  # FLOPS value
             va='center', ha='left',  
             fontsize=19, fontweight='bold', color='black',
             rotation=-90)  # Rotate the text for better readability

# Remove extra space between the bars
plt.subplots_adjust(left=0.2, right=0.85, top=0.85, bottom=0.15)

# Show the plot
plt.show()
