# ECG Arrhythmia Classification Pipeline

**Complete pipeline for ECG analysis:**
1. Feature extraction from ECG signals
2. Preprocessing and filtering
3. Class balancing (SMOTE)
4. Random Forest classification
5. Performance evaluation
6. SHAP interpretability analysis

---

## 1. Setup and Imports

In [None]:
#!/usr/bin/env python3
import os
import sys
import time
import warnings
import gc
from collections import Counter
import numpy as np
import pandas as pd
import scipy.io as sio
import scipy.signal as signal
import pywt
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    confusion_matrix, classification_report, accuracy_score,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
from sklearn.preprocessing import LabelBinarizer
from imblearn.over_sampling import SMOTE
import pycatch22
import shap

print("All imports successful!")

## 2. Configuration Parameters

`DATA_DIR` points to the directory that contains the WFDB records (`.mat` signal files with matching `.hea` header files). To populate it, download the PhysioNet ECG arrhythmia database using one of the following methods:

1. Download the ZIP file (2.3 GB) from https://physionet.org/content/ecg-arrhythmia/1.0.0/ 
2. Download the files using your terminal: wget -r -N -c -np https://physionet.org/files/ecg-arrhythmia/1.0.0/
3. Download the files using AWS command line tools: aws s3 sync --no-sign-request s3://physionet-open/ecg-arrhythmia/1.0.0/ DESTINATION

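With the default configuration, place the download in a directory named `./physionet_ecg_arrhythmia_data`, so that the records are found under `./physionet_ecg_arrhythmia_data/WFDBRecords`. The `ConditionNames_SNOMED-CT.csv` file from the same download is also required: `map_snomed_to_aami` uses it to translate diagnosis codes, and without it every record is labelled `Q`.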

In [None]:
# --- Paths -----------------------------------------------------------------
# Root of the PhysioNet download. DATA_DIR is derived from it so the two
# paths cannot drift apart when only one of them is edited.
BASE_DIR = "./physionet_ecg_arrhythmia_data"
DATA_DIR = f"{BASE_DIR}/WFDBRecords"
OUTPUT_DIR = "./ecg_analysis_results"

# Preprocessing options
APPLY_BALANCING = True  # oversample minority classes with SMOTE

# Signal processing parameters (frequencies in Hz; the filters normalise
# them by the Nyquist frequency 0.5 * SAMPLING_RATE)
SAMPLING_RATE = 500
BASELINE_CUTOFF = 0.5   # high-pass cut-off for baseline-wander removal
LOWPASS_CUTOFF = 40     # low-pass cut-off
# NOTE(review): the recordings may come from a 50 Hz mains region; confirm the
# powerline frequency (the 40 Hz low-pass already attenuates either one).
NOTCH_FREQ = 60
WAVELET_TYPE = 'db6'
WAVELET_LEVEL = 3

# Machine learning parameters
N_ESTIMATORS = 100
TEST_SIZE = 0.3         # fraction of records held out for testing
RANDOM_STATE = 42

# Visualization parameters
SHAP_SAMPLE_SIZE = 300  # max test rows explained by SHAP
TOP_N_FEATURES = 20     # number of top features shown in ranking plots
FIGURE_DPI = 150
FONT_SIZE = 12

os.makedirs(OUTPUT_DIR, exist_ok=True)

print("="*70)
print("ECG ANALYSIS PIPELINE")
print("="*70)
print(f"\nData directory: {DATA_DIR}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"SMOTE balancing: {'Enabled' if APPLY_BALANCING else 'Disabled'}")
print(f"Random Forest estimators: {N_ESTIMATORS}")

## 3. Preprocessing Functions

Functions for ECG signal preprocessing and SNOMED-CT to AAMI class mapping.

In [None]:
def preprocess_ecg_signal(ecg_signal, fs=SAMPLING_RATE):
    """Filter, wavelet-denoise and z-score normalise one ECG lead.

    Pipeline (every IIR filter is applied forward-backward with ``filtfilt``):

    1. 1st-order Butterworth high-pass at ``BASELINE_CUTOFF`` Hz to remove
       baseline wander.
    2. 4th-order Butterworth low-pass at ``LOWPASS_CUTOFF`` Hz to suppress
       high-frequency noise.
    3. IIR notch at ``NOTCH_FREQ`` Hz (quality factor 30) for powerline
       interference.
    4. Wavelet denoising: a ``WAVELET_LEVEL``-level ``WAVELET_TYPE``
       decomposition whose detail bands are soft-thresholded with the
       universal threshold ``sigma * sqrt(2 ln N)``. ``sigma`` is the
       median-absolute-deviation noise estimate from the finest detail band.

    Args:
        ecg_signal: 1-D sequence of samples for a single lead.
        fs: Sampling frequency in Hz.

    Returns:
        1-D float array with the same length as ``ecg_signal``, shifted to zero
        mean and scaled to (approximately) unit standard deviation. If any
        filtering step raises (for example, the signal is too short for
        ``filtfilt``'s edge padding), the raw signal is normalised instead.
    """
    ecg_signal = np.asarray(ecg_signal, dtype=float)
    nyquist = 0.5 * fs

    def _zscore(x):
        # The epsilon keeps a flat signal from dividing by zero.
        return (x - np.mean(x)) / (np.std(x) + 1e-10)

    try:
        # Step 1: baseline wander is very low frequency (typically < 0.5 Hz), so a
        # high-pass filter removes the drift while keeping cardiac activity.
        b, a = signal.butter(1, BASELINE_CUTOFF / nyquist, btype='highpass')
        processed = signal.filtfilt(b, a, ecg_signal)

        # Step 2: the diagnostically relevant ECG content lies mainly below ~40 Hz.
        b, a = signal.butter(4, LOWPASS_CUTOFF / nyquist, btype='low')
        processed = signal.filtfilt(b, a, processed)

        # Step 3: notch out mains interference. Without an ``fs`` argument,
        # iirnotch expects w0 normalised so that 1.0 is the Nyquist frequency.
        b, a = signal.iirnotch(NOTCH_FREQ / nyquist, 30)
        processed = signal.filtfilt(b, a, processed)

        # Step 4: wavelet denoising.
        coeffs = pywt.wavedec(processed, WAVELET_TYPE, level=WAVELET_LEVEL)
        sigma = np.median(np.abs(coeffs[-1])) / 0.6745
        uthresh = sigma * np.sqrt(2 * np.log(len(processed)))
        # coeffs[0] is the approximation band (roughly 0 .. fs / 2**(level+1) Hz),
        # which holds most of the waveform. Soft-thresholding it would shrink
        # every coefficient by uthresh and distort the ECG, so only the detail
        # bands are thresholded.
        denoised_coeffs = [coeffs[0]] + [
            pywt.threshold(detail, uthresh, mode='soft') for detail in coeffs[1:]
        ]
        # waverec returns one extra sample for odd-length input; trim it so the
        # output always matches the input length.
        processed = pywt.waverec(denoised_coeffs, WAVELET_TYPE)[:len(ecg_signal)]

        return _zscore(processed)
    except Exception:
        # Best-effort fallback: keep the lead usable even if filtering fails.
        return _zscore(ecg_signal)

print("Preprocessing function defined")

In [None]:
# Record-level grouping of the acronyms listed in ConditionNames_SNOMED-CT.csv
# into AAMI-style classes. Acronyms without an entry fall back to 'Q'.
# NOTE(review): ANSI/AAMI EC57 (a beat-level standard) groups bundle-branch-block
# beats with N, whereas this table places RBBB/LBBB/IVB under V; confirm
# that this grouping is intended.
_AAMI_MAP = {
    'SR': 'N', 'SB': 'N', 'ST': 'N', 'SA': 'N', 'NSR': 'N',
    'AFIB': 'S', 'AF': 'S', 'AT': 'S', 'AVNRT': 'S', 'AVRT': 'S',
    'SVT': 'S', 'APB': 'S', 'ABI': 'S', 'SAAWR': 'S', 'JPT': 'S',
    'JEB': 'S', '1AVB': 'S', '2AVB': 'S', '2AVB1': 'S', '2AVB2': 'S', 'AVB': 'S',
    'VPB': 'V', 'VEB': 'V', 'VB': 'V', 'VET': 'V', 'VPE': 'V',
    '3AVB': 'V', 'RBBB': 'V', 'LBBB': 'V', 'LFBBB': 'V', 'LBBBB': 'V',
    'IVB': 'V', 'IDC': 'V',
    'VFW': 'F',
}

# Order in which a record's per-code classes win; 'N' only if none is present.
_AAMI_PRIORITY = ('V', 'S', 'F', 'Q')

_SNOMED_CSV_NAME = "ConditionNames_SNOMED-CT.csv"

# SNOMED-CT code -> acronym lookups keyed by db_path, so the CSV is parsed once
# per path instead of once per record. Call _SNOMED_LOOKUP_CACHE.clear() after
# adding or moving the CSV during a running session.
_SNOMED_LOOKUP_CACHE = {}


def _load_snomed_lookup(db_path):
    """Return ``{SNOMED-CT code: acronym}`` read from the conditions CSV.

    ``db_path`` is searched first, then ``BASE_DIR`` (where the CSV sits when
    the download is extracted directly into BASE_DIR next to WFDBRecords).
    Returns an empty dict, and prints a warning, when no readable CSV is
    found. In that case every code maps to 'Q'.
    """
    if db_path in _SNOMED_LOOKUP_CACHE:
        return _SNOMED_LOOKUP_CACHE[db_path]

    lookup = {}
    searched = []
    for directory in dict.fromkeys((db_path, BASE_DIR)):  # ordered, de-duplicated
        csv_path = os.path.join(directory, _SNOMED_CSV_NAME)
        searched.append(csv_path)
        if not os.path.exists(csv_path):
            continue
        try:
            # Read codes as strings: if pandas inferred a float column (e.g. because
            # of a blank row), str() would yield "123.0" and no lookup would match.
            snomed_df = pd.read_csv(csv_path, dtype={'Snomed_CT': str})
            lookup = {
                str(code).strip(): str(acronym).strip()
                for code, acronym in zip(snomed_df['Snomed_CT'], snomed_df['Acronym Name'])
            }
            break
        except (OSError, KeyError, ValueError) as exc:
            print(f"Warning: could not read {csv_path} ({exc})")

    if not lookup:
        print(f"Warning: no usable {_SNOMED_CSV_NAME} found (searched {searched}); "
              "all records will be labelled 'Q'.")
    _SNOMED_LOOKUP_CACHE[db_path] = lookup
    return lookup


def map_snomed_to_aami(snomed_codes, db_path=BASE_DIR + "/ecg-arrhythmia-1.0.0"):
    """Map a record's SNOMED-CT diagnostic codes to a single AAMI class.

    Each code is translated to an acronym via the conditions CSV, then to a
    class via ``_AAMI_MAP``. Codes or acronyms without an entry become 'Q'.
    The record's label is the highest-priority class present, in the order
    V > S > F > Q > N. Consequently 'N' is returned only when every code maps
    to N, or when ``snomed_codes`` is empty.

    Args:
        snomed_codes: Iterable of SNOMED-CT codes (str or int).
        db_path: Directory searched first for ConditionNames_SNOMED-CT.csv.

    Returns:
        One of 'V', 'S', 'F', 'Q', 'N'.
    """
    snomed_dict = _load_snomed_lookup(db_path)
    aami_classes = {
        _AAMI_MAP.get(snomed_dict.get(str(code).strip()), 'Q')
        for code in snomed_codes
    }
    for aami_class in _AAMI_PRIORITY:
        if aami_class in aami_classes:
            return aami_class
    return 'N'

print("SNOMED-CT mapping function defined")

In [None]:
def _find_signal_matrix(mat_data):
    """Return the signal array from a ``scipy.io.loadmat`` dict, or None.

    Prefers the keys 'val', 'data' and 'signal' (in that order). Otherwise it
    falls back to the first non-empty ndarray whose key does not start with
    '__' (loadmat's metadata entries).
    """
    for key in ('val', 'data', 'signal'):
        if key in mat_data:
            return mat_data[key]
    for key, value in mat_data.items():
        if isinstance(value, np.ndarray) and value.size > 0 and not key.startswith('__'):
            return value
    return None


def extract_features_from_ecg(file_path):
    """Extract per-lead catch22 features and the record's AAMI label.

    The diagnosis codes are read from the ``#Dx:`` line of the WFDB header
    that shares the record's stem (``<stem>.hea``) and mapped with
    ``map_snomed_to_aami``. The label is 'Q' when no codes are found. Every
    usable lead of the ``.mat`` signal matrix is run through
    ``preprocess_ecg_signal`` and then ``pycatch22.catch22_all``. Leads that
    are shorter than 10 samples, flat, or mostly zeros are skipped.

    Args:
        file_path: Path to a ``.mat`` record.

    Returns:
        A list with one dict per (lead, feature) pair, using the keys
        'record_id', 'lead', 'feature_name', 'feature_value' (float; NaN when
        missing or non-finite) and 'class'. The list is empty if every lead was
        skipped. Returns None if the file could not be read or held no array.
    """
    features = []
    try:
        stem = os.path.splitext(file_path)[0]
        hea_path = stem + '.hea'
        file_id = os.path.basename(stem)

        aami_class = 'Q'
        dx_codes = []
        if os.path.exists(hea_path):
            with open(hea_path, 'r', encoding='utf-8', errors='replace') as f:
                for line in f:
                    if line.startswith('#Dx:'):
                        dx_str = line.strip().replace('#Dx:', '')
                        # Ignore empty entries (e.g. a trailing comma), which would
                        # otherwise map to 'Q' and override the real diagnosis.
                        dx_codes = [code.strip() for code in dx_str.split(',') if code.strip()]
                        break

        if dx_codes:
            aami_class = map_snomed_to_aami(dx_codes)

        signal_data = _find_signal_matrix(sio.loadmat(file_path))
        if signal_data is None:
            return None

        # Treat a 1-D array as a single lead, then orient as (leads, samples).
        signal_data = np.atleast_2d(signal_data)
        if signal_data.shape[0] > signal_data.shape[1]:
            signal_data = signal_data.T

        for lead_idx in range(signal_data.shape[0]):
            lead_signal = signal_data[lead_idx, :].astype(float)

            # Skip leads too short or flat for meaningful features.
            if len(lead_signal) < 10 or np.std(lead_signal) < 1e-10:
                continue
            # Skip leads that are mostly zeros (presumably missing/disconnected).
            if np.count_nonzero(lead_signal) < 0.5 * len(lead_signal):
                continue

            try:
                lead_signal_processed = preprocess_ecg_signal(lead_signal)
                catch22_result = pycatch22.catch22_all(lead_signal_processed)
            except Exception as exc:
                # Best-effort: one bad lead should not discard the whole record.
                print(f"Warning: skipped lead {lead_idx} of {file_id} ({exc})")
                continue

            for feat_name, feat_val in zip(catch22_result['names'], catch22_result['values']):
                value = np.nan if feat_val is None else float(feat_val)
                if not np.isfinite(value):
                    value = np.nan
                features.append({
                    'record_id': file_id,
                    'lead': lead_idx,
                    'feature_name': str(feat_name),
                    'feature_value': value,
                    'class': aami_class,
                })

    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None

    return features

print("Feature extraction function defined")

## 4. Feature Extraction from ECG Files

Process all ECG files and extract catch22 time-series features.

In [None]:
print("STEP 1: FEATURE EXTRACTION")

# os.walk() silently yields nothing for a missing directory, so the problem would
# only surface later as an empty feature table. Fail early with a clear message.
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(
        f"DATA_DIR not found: {DATA_DIR!r}. Download the PhysioNet records "
        "(see Section 2) or update DATA_DIR."
    )

# Find all .mat files that have a matching .hea header in the same directory.
mat_files = []
for root, _, files in os.walk(DATA_DIR):
    for file in files:
        stem, ext = os.path.splitext(file)
        if ext != '.mat' or file.startswith('.'):
            continue
        if os.path.exists(os.path.join(root, stem + '.hea')):
            mat_files.append(os.path.join(root, file))

# os.walk order depends on the file system; sort for reproducible runs.
mat_files.sort()

print(f"Found {len(mat_files)} ECG files")

In [None]:
# Extract features from all files
all_features = []
class_counts = Counter()  # number of records per AAMI class
n_files = len(mat_files)

for i, file_path in enumerate(mat_files, start=1):
    if i % 10 == 0 or i == n_files:
        print(f"Processing {i}/{n_files}: {os.path.basename(file_path)}")

    features = extract_features_from_ecg(file_path)
    if features:
        all_features.extend(features)
        # Every row of one record carries the same label, so the first row suffices.
        class_counts[features[0]['class']] += 1

features_df = pd.DataFrame(all_features)
if features_df.empty:
    # The pivot in the next step needs the 'record_id'/'feature_name' columns, which
    # an empty frame lacks. Stop here with an actionable message instead.
    raise RuntimeError(
        f"No features were extracted from {n_files} ECG file(s); "
        "check DATA_DIR and the input files."
    )

print(f"\nExtracted {len(features_df)} feature values")
print(f"Class distribution: {dict(class_counts)}")

## 5. Feature Preparation

Transform features into wide format and prepare for machine learning.

In [None]:
print("STEP 2: FEATURE PREPARATION")

# Pivot to wide format with one column per (lead, catch22 feature) pair.
# Pivoting on feature_name alone with aggfunc='first' would keep only the first
# surviving lead of each record. That would discard the other leads' features,
# and the same column would refer to different leads in different records
# (leads can be skipped during extraction). Lead indices follow the row order
# of the .mat signal matrix; leads skipped for a record become NaN and are
# imputed later.
features_wide = features_df.pivot_table(
    index='record_id',
    columns=['lead', 'feature_name'],
    values='feature_value',
    aggfunc='first',
)
features_wide.columns = [f"lead{lead}_{name}" for lead, name in features_wide.columns]
features_wide = features_wide.reset_index()

class_mapping = features_df.groupby('record_id')['class'].first()
features_wide['class'] = features_wide['record_id'].map(class_mapping)

print(f"Feature matrix shape: {features_wide.shape}")

# Display first few rows
display(features_wide.head())

In [None]:
# Separate features and target
non_feature_cols = {'record_id', 'class'}
feature_cols = [col for col in features_wide.columns if col not in non_feature_cols]

# SimpleImputer discards columns that are entirely missing. That would leave X
# narrower than feature_cols, which later cells use as column labels for the
# importance and SHAP tables. Drop such columns explicitly so the two stay aligned.
all_missing = {col for col in feature_cols if features_wide[col].isna().all()}
if all_missing:
    print(f"Dropping {len(all_missing)} feature column(s) with no observed values: {sorted(all_missing)}")
    feature_cols = [col for col in feature_cols if col not in all_missing]

X = features_wide[feature_cols].values
y = features_wide['class'].values

# Impute missing values with the column mean.
# NOTE(review): the imputer is fit on all records before the train/test split,
# so test records contribute to the means; consider fitting on training data only.
imputer = SimpleImputer(strategy='mean')
X = imputer.fit_transform(X)

print(f"Features: {X.shape}")
print(f"Original class counts: {Counter(y)}")

## 6. Class Balancing with SMOTE

Apply Synthetic Minority Over-sampling Technique to balance classes.

In [None]:
print("STEP 3: CLASS BALANCING")

# NOTE: resampling the full dataset before the train/test split lets synthetic
# samples interpolated from training records end up in the test set, which
# inflates test metrics. For an unbiased estimate, fit SMOTE on the training
# split only.
SMOTE_MAX_NEIGHBORS = 5  # SMOTE's default k_neighbors

if APPLY_BALANCING:
    try:
        label_counts = Counter(y)
        min_samples = min(label_counts.values())

        if min_samples >= 2:
            # k_neighbors must be smaller than the rarest class, so cap the default
            # there. (min(1, ...) would always give 1 and make every synthetic point
            # an interpolation toward a single neighbour.)
            k_neighbors = min(SMOTE_MAX_NEIGHBORS, min_samples - 1)

            smote = SMOTE(random_state=RANDOM_STATE, k_neighbors=k_neighbors)
            X_balanced, y_balanced = smote.fit_resample(X, y)

            print(f"Applied SMOTE balancing (k_neighbors={k_neighbors})")
            print(f"Balanced class counts: {Counter(y_balanced)}")
        else:
            print("Not enough samples for SMOTE, using original data")
            X_balanced, y_balanced = X, y
    except Exception as e:
        # Deliberate best-effort: fall back to the unbalanced data.
        print(f"SMOTE failed ({e}), using original data")
        X_balanced, y_balanced = X, y
else:
    print("Skipping balancing")
    X_balanced, y_balanced = X, y

## 7. Train Random Forest Classifier

In [None]:
print("STEP 4: TRAIN RANDOM FOREST CLASSIFIER")

# Split the ORIGINAL (imputed, not oversampled) data so the test set contains
# only real records. Oversampling before the split would let synthetic samples
# interpolated from training records leak into the test set and inflate every
# test metric. SMOTE is therefore fit on the training portion alone below.
class_counts = Counter(y)
min_class_count = min(class_counts.values())

if min_class_count >= 2:
    stratify_target = y
else:
    print("Too few samples for stratified split, using random split")
    stratify_target = None

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    stratify=stratify_target,
)

if APPLY_BALANCING:
    train_min = min(Counter(y_train).values())
    if train_min >= 2:
        try:
            # k_neighbors must be smaller than the rarest training class; 5 is SMOTE's default.
            train_smote = SMOTE(random_state=RANDOM_STATE, k_neighbors=min(5, train_min - 1))
            X_train, y_train = train_smote.fit_resample(X_train, y_train)
            print(f"Applied SMOTE to the training split: {Counter(y_train)}")
        except Exception as e:
            # Deliberate best-effort: train on the unbalanced split instead.
            print(f"SMOTE failed ({e}), training on the unbalanced split")
    else:
        print("Not enough training samples for SMOTE, training on the unbalanced split")

print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

In [None]:
# Standardise using statistics learned from the training split only, then apply
# the same transform to the held-out split. A random forest does not need
# standardised inputs; the scaled matrices exist because later cells use them.
scaler = StandardScaler().fit(X_train)
X_train_scaled, X_test_scaled = (scaler.transform(part) for part in (X_train, X_test))

# Fit the forest (n_jobs=-1 uses every available core).
forest_params = dict(n_estimators=N_ESTIMATORS, random_state=RANDOM_STATE, n_jobs=-1)
print(f"\nTraining Random Forest (n_estimators={forest_params['n_estimators']})...")
rf_classifier = RandomForestClassifier(**forest_params).fit(X_train_scaled, y_train)

# Hard labels for the metrics; class probabilities for the ROC curves.
y_pred = rf_classifier.predict(X_test_scaled)
y_pred_proba = rf_classifier.predict_proba(X_test_scaled)

accuracy = accuracy_score(y_test, y_pred)
print(f"\nTest Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")

## 8. Evaluation Metrics

Calculate detailed performance metrics including sensitivity, specificity, PPV, and F1-score.

In [None]:
print("STEP 5: EVALUATION METRICS")

# Pin the label order so the confusion-matrix rows/columns line up with class_labels.
class_labels = sorted(np.unique(y_balanced))
cm = confusion_matrix(y_test, y_pred, labels=class_labels)
report_text = classification_report(y_test, y_pred)

for heading, content in (("\nConfusion Matrix:", cm), ("\nClassification Report:", report_text)):
    print(heading)
    print(content)

In [None]:
# Calculate detailed one-vs-rest metrics for each class from the confusion matrix.
def _ratio_or_zero(numerator, denominator):
    """Return numerator / denominator, or the integer 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


cm_total = np.sum(cm)
metrics_list = []
for idx, label in enumerate(class_labels):
    true_pos = cm[idx, idx]
    false_neg = np.sum(cm[idx, :]) - true_pos   # rest of the row: missed
    false_pos = np.sum(cm[:, idx]) - true_pos   # rest of the column: wrongly assigned
    true_neg = cm_total - (true_pos + false_neg + false_pos)

    sensitivity = _ratio_or_zero(true_pos, true_pos + false_neg)
    specificity = _ratio_or_zero(true_neg, true_neg + false_pos)
    ppv = _ratio_or_zero(true_pos, true_pos + false_pos)
    f1 = _ratio_or_zero(2 * (ppv * sensitivity), ppv + sensitivity)

    metrics_list.append({
        'Class': label,
        'Sensitivity': sensitivity,
        'Specificity': specificity,
        'PPV': ppv,
        'F1_Score': f1,
    })

metrics_df = pd.DataFrame(metrics_list)
print("\nDetailed Metrics:")
display(metrics_df.round(4))

# Save results
metrics_path = os.path.join(OUTPUT_DIR, 'performance_metrics.csv')
metrics_df.to_csv(metrics_path, index=False)
print(f"\nSaved metrics to: {OUTPUT_DIR}/performance_metrics.csv")

## 9. Visualizations

Generate confusion matrices, ROC curves, and feature importance plots.

In [None]:
print("STEP 6: GENERATING VISUALIZATIONS")

# Apply the configured base font size to every figure created from here on.
plt.rcParams['font.size'] = FONT_SIZE

### 9.1 Confusion Matrices

In [None]:
# Row-normalise the counts so each row shows the fraction of that true class.
# A class with no test rows would otherwise divide by zero (NaN cells and a
# RuntimeWarning); such rows are shown as zeros instead.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

panels = (
    (axes[0], cm, 'd', 'Blues', 'Confusion Matrix (Counts)'),
    (axes[1], cm_norm, '.2f', 'Greens', 'Confusion Matrix (Normalized)'),
)
for ax, matrix, fmt, cmap, title in panels:
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap=cmap,
                xticklabels=class_labels, yticklabels=class_labels,
                square=True, ax=ax)
    ax.set_xlabel('Predicted', fontweight='bold')
    ax.set_ylabel('True', fontweight='bold')
    ax.set_title(title, fontweight='bold')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrices.png'), dpi=FIGURE_DPI)
print("Saved: confusion_matrices.png")
plt.show()

### 9.2 ROC Curves

In [None]:
if len(class_labels) > 1:
    colors = ['blue', 'red', 'green', 'purple', 'orange']
    y_test_arr = np.asarray(y_test)

    plt.figure(figsize=(10, 8))

    # One-vs-rest ROC for each class. The columns of y_pred_proba follow
    # rf_classifier.classes_, so each class is scored against its own column.
    # LabelBinarizer is avoided here for two reasons. It returns a single column
    # for two classes, which would pair class 0's probabilities with class 1's
    # labels. It also orders its columns by the classes present in y_test,
    # which may differ from the model's classes.
    for col_idx, label in enumerate(rf_classifier.classes_):
        y_true_bin = (y_test_arr == label).astype(int)
        if y_true_bin.min() == y_true_bin.max():
            # ROC is undefined when the test split has only positives or only negatives.
            continue
        fpr, tpr, _ = roc_curve(y_true_bin, y_pred_proba[:, col_idx])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=colors[col_idx % len(colors)], lw=2,
                 label=f'Class {label} (AUC = {roc_auc:.3f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontweight='bold')
    plt.ylabel('True Positive Rate', fontweight='bold')
    plt.title('ROC Curves', fontweight='bold')
    plt.legend(loc="lower right")
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'roc_curves.png'), dpi=FIGURE_DPI)
    print("Saved: roc_curves.png")
    plt.show()

### 9.3 Feature Importance

In [None]:
# Impurity-based importances from the fitted forest, highest first.
feature_importance = rf_classifier.feature_importances_
importance_df = pd.DataFrame({
    'Feature': feature_cols,
    'Importance': feature_importance
}).sort_values('Importance', ascending=False)

# Honour the configured TOP_N_FEATURES, matching the SHAP plots.
top_n = min(TOP_N_FEATURES, len(importance_df))
top_features = importance_df.head(top_n)
positions = range(top_n)

plt.figure(figsize=(10, 8))
# barh draws bottom-up, so reverse the order to put the most important feature on top.
plt.barh(positions, top_features['Importance'].values[::-1])
plt.yticks(positions, top_features['Feature'].values[::-1])
plt.xlabel('Importance', fontweight='bold')
plt.ylabel('Feature', fontweight='bold')
plt.title(f'Top {top_n} Feature Importance', fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'feature_importance.png'), dpi=FIGURE_DPI)
print("Saved: feature_importance.png")
plt.show()

importance_df.to_csv(os.path.join(OUTPUT_DIR, 'feature_importance.csv'), index=False)

print("\nTop 10 Most Important Features:")
display(importance_df.head(10))

## 10. SHAP Analysis (Interpretability)

Use SHAP (SHapley Additive exPlanations) to understand feature contributions to predictions.

In [None]:
print("STEP 7: SHAP ANALYSIS")

# Explain the first SHAP_SAMPLE_SIZE test rows (train_test_split already shuffled them).
sample_size = min(SHAP_SAMPLE_SIZE, len(X_test_scaled))
X_shap_sample = X_test_scaled[:sample_size]
y_shap_sample = y_test[:sample_size]

print(f"\nCalculating SHAP values for {sample_size} samples...")
print("This may take several minutes...")

explainer = shap.TreeExplainer(rf_classifier)
shap_values = explainer.shap_values(X_shap_sample)

# Older SHAP releases return a list with one (samples, features) array per class.
# Newer ones return a single (samples, features, classes) array. The plotting
# cells below index shap_values[i] as "class i", so normalise to the list layout.
# Classes follow rf_classifier.classes_.
if isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
    shap_values = [shap_values[:, :, k] for k in range(shap_values.shape[-1])]

X_shap_df = pd.DataFrame(X_shap_sample, columns=feature_cols)

print("SHAP values calculated!")

### 10.1 SHAP Summary Plot (All Classes)

In [None]:
print("\nGenerating SHAP summary plot...")
summary_png = 'shap_summary_all_classes.png'

# Combined plot across all classes, limited to the TOP_N_FEATURES strongest features.
plt.figure(figsize=(14, 12))
shap.summary_plot(
    shap_values,
    X_shap_df,
    class_names=class_labels,
    max_display=TOP_N_FEATURES,
    show=False,
)
plt.title('SHAP Feature Importance - All Classes', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, summary_png), dpi=FIGURE_DPI)
print(f"Saved: {summary_png}")
plt.show()

### 10.2 SHAP Class-Specific Plots

In [None]:
print("\nGenerating class-specific SHAP plots...")
# One beeswarm plot per class; shap_values[class_idx] holds that class's attributions.
for class_idx, class_name in enumerate(class_labels):
    out_file = os.path.join(OUTPUT_DIR, f'shap_summary_class_{class_name}.png')

    plt.figure(figsize=(12, 10))
    per_class_values = shap_values[class_idx]
    shap.summary_plot(per_class_values, X_shap_df, max_display=TOP_N_FEATURES, show=False)
    plt.title(f'SHAP Feature Importance - Class {class_name}', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig(out_file, dpi=FIGURE_DPI)
    plt.show()
print(f"Saved class-specific plots for: {class_labels}")

### 10.3 SHAP Bar Plot (Mean Impact)

In [None]:
bar_png = 'shap_bar_plot.png'

# Mean |SHAP| per feature, stacked by class.
plt.figure(figsize=(12, 10))
shap.summary_plot(
    shap_values,
    X_shap_df,
    plot_type="bar",
    class_names=class_labels,
    max_display=TOP_N_FEATURES,
    show=False,
)
plt.title('SHAP Mean Impact on Output', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, bar_png), dpi=FIGURE_DPI)
print(f"Saved: {bar_png}")
plt.show()

print("\nSHAP analysis complete!")

## 11. Pipeline Summary

Final summary of results and saved outputs.

In [None]:
print("\n" + "="*70)
print("PIPELINE COMPLETE - SUMMARY")
print("="*70)
print(f"\nProcessed {len(mat_files)} ECG records")
print(f"Extracted {len(feature_cols)} features per record")
# preprocess_ecg_signal runs unconditionally inside extract_features_from_ecg;
# there is no APPLY_PREPROCESSING switch, and referencing one raised NameError.
print(f"Preprocessing: Enabled (high-pass {BASELINE_CUTOFF} Hz, low-pass {LOWPASS_CUTOFF} Hz, "
      f"notch {NOTCH_FREQ} Hz, {WAVELET_TYPE} wavelet denoising)")
print(f"Balancing (SMOTE): {'Applied' if APPLY_BALANCING else 'Skipped'}")
print(f"Model: Random Forest ({N_ESTIMATORS} trees)")
print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"\nResults saved to: {OUTPUT_DIR}/")
print("  - performance_metrics.csv")
print("  - confusion_matrices.png")
print("  - roc_curves.png")
print("  - feature_importance.png")
print("  - feature_importance.csv")
print("  - shap_summary_all_classes.png")
print("  - shap_summary_class_*.png (per class)")
print("  - shap_bar_plot.png")
print("\n" + "="*70)