# Preprocessing the ECG data: converting the .atr and .hea records to 112×112-pixel grayscale images

Import required libraries

In [4]:
import os
import wfdb
import numpy as np
import matplotlib.pyplot as plt
import pywt
from scipy.signal import butter, filtfilt, find_peaks
from PIL import Image

Remove baseline drift from the ECG signal
This function removes low-frequency components ("baseline drift") from an ECG signal using wavelet decomposition.

In [5]:
def remove_baseline_drift(signal, wavelet="db6", level=9):
    """Remove baseline drift by zeroing the coarsest wavelet approximation.

    The signal is decomposed with ``pywt.wavedec``, the approximation
    coefficients (first entry of the coefficient list) are replaced by
    zeros, and the signal is rebuilt with ``pywt.waverec``.

    Args:
        signal: 1-D sequence of samples.
        wavelet: Wavelet name passed to PyWavelets.
        level: Decomposition depth passed to ``pywt.wavedec``.

    Returns:
        The reconstructed signal, with the same number of samples as
        ``signal``.
    """
    coeff = pywt.wavedec(signal, wavelet, level=level)
    coeff[0] = np.zeros_like(coeff[0])  # remove low-frequency drift
    reconstructed = pywt.waverec(coeff, wavelet)

    # The inverse transform can yield one extra trailing sample when the input
    # length is odd; trim so downstream sample indices stay aligned with the
    # original recording.
    return reconstructed[..., :np.shape(signal)[-1]]

5th-order bandpass filter

In [7]:
def bandpass_filter(signal, low_freq, high_freq, fs, order=5):
    """Apply a zero-phase Butterworth band-pass filter to ``signal``.

    The filter is designed as second-order sections and applied forwards and
    backwards with ``sosfiltfilt``. A transfer-function (b, a) design of this
    order with a pass-band edge very close to 0 Hz is prone to round-off
    problems; the cascaded second-order form avoids that.

    Args:
        signal: 1-D sequence of samples.
        low_freq: Lower pass-band edge, in the same unit as ``fs``.
        high_freq: Upper pass-band edge, in the same unit as ``fs``.
        fs: Sampling frequency.
        order: Butterworth order handed to ``scipy.signal.butter``.

    Returns:
        The filtered signal as an array with the same shape as ``signal``.
    """
    from scipy.signal import sosfiltfilt

    nyquist = 0.5 * fs
    band = [low_freq / nyquist, high_freq / nyquist]
    sos = butter(order, band, btype="band", output="sos")
    return sosfiltfilt(sos, signal)

Detect R-peaks

In [8]:
def detect_r_peaks(ecg_signal, fs=500):
    """Detect R-peaks while rejecting lower, smoother peaks such as T-waves.

    Candidate peaks are found with ``scipy.signal.find_peaks`` (at least
    200 ms apart and above ``mean + 0.5 * std``). A candidate is kept only if
    the steepest rising step in a small window around it exceeds
    ``0.5 * std`` and its amplitude exceeds half of the global maximum.

    Args:
        ecg_signal: 1-D sequence of ECG samples.
        fs: Sampling frequency in Hz, used to convert 200 ms into samples.

    Returns:
        Integer array of sample indices of the accepted peaks (possibly
        empty).
    """
    ecg_signal = np.asarray(ecg_signal)
    signal_std = np.std(ecg_signal)

    # 1. Initial peak detection. find_peaks requires distance >= 1, so clamp
    #    for very low sampling rates.
    distance = max(1, int(0.2 * fs))  # at least 200 ms apart
    raw_peaks, _ = find_peaks(
        ecg_signal,
        distance=distance,
        height=np.mean(ecg_signal) + 0.5 * signal_std,
    )
    if raw_peaks.size == 0:
        return np.array([], dtype=int)

    # Thresholds depend only on the whole signal, so compute them once rather
    # than once per candidate.
    slope_threshold = 0.5 * signal_std
    amp_threshold = 0.5 * np.max(ecg_signal)
    last_index = len(ecg_signal) - 1

    # 2. Refine: keep peaks with a sharp slope (R-peaks are steep).
    r_peaks = []
    for idx in raw_peaks:
        left = max(0, idx - 5)
        right = min(last_index, idx + 5)
        slope = np.max(np.diff(ecg_signal[left:right]))
        if slope > slope_threshold and ecg_signal[idx] > amp_threshold:
            r_peaks.append(idx)

    # Explicit integer dtype keeps the result usable as indices even when no
    # candidate survives the refinement.
    return np.array(r_peaks, dtype=int)

Load ECG Data

In [9]:
def load_ecg_data(record_path):
    """Load a WFDB record, clean its first channel and locate R-peaks.

    The sampling frequency is read from the record header (``record.fs``)
    instead of being assumed, so filtering and peak spacing stay correct for
    recordings that are not sampled at 500 Hz.

    Args:
        record_path: Path of the record without file extension, as expected
            by ``wfdb.rdrecord``.

    Returns:
        Tuple ``(ecg_signal, r_peaks)``: the filtered first channel and the
        array of detected R-peak sample indices.
    """
    record = wfdb.rdrecord(record_path)
    fs = record.fs
    ecg_signal = record.p_signal[:, 0]  # first channel

    # filtering
    ecg_signal = remove_baseline_drift(ecg_signal)
    ecg_signal = bandpass_filter(ecg_signal, 0.5, 40, fs=fs)

    # detect R-peaks
    r_peaks = detect_r_peaks(ecg_signal, fs=fs)
    return ecg_signal, r_peaks

Image Generation

In [10]:

def _render_segment_image(segment, out_path, size=(112, 112)):
    """Plot ``segment`` as a black line and save it as a grayscale PNG of ``size``."""
    import io

    buffer = io.BytesIO()
    fig = plt.figure(figsize=(2, 2))  # square figure
    try:
        plt.plot(segment, color="black", linewidth=1)
        plt.axis("off")
        # Render into memory: no dependency on a pre-existing temp directory
        # and no temp file to clean up.
        plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)

    buffer.seek(0)
    with Image.open(buffer) as rendered:
        rendered.convert("L").resize(size).save(out_path)


def save_beat_images(ecg_signal, r_peaks, record_name, person_id, output_dir, fs=500):
    """Write single-, two- and three-beat window images for one record.

    For each usable R-peak ``i`` three windows are cut, all starting 200 ms
    before peak ``i`` and ending 200 ms after peak ``i``, ``i + 1`` and
    ``i + 2`` respectively (clamped to the signal bounds). Each window is
    plotted and stored as a 112x112 grayscale PNG under
    ``<output_dir>/<Single_Beat|Two_Beats|Three_Beats>/<person_id>/``.

    Args:
        ecg_signal: 1-D array of filtered ECG samples.
        r_peaks: Sample indices of the R-peaks.
        record_name: Record identifier used as file-name prefix.
        person_id: Sub-folder name for the subject.
        output_dir: Root folder of the image databases.
        fs: Sampling frequency in Hz, used to convert 200 ms into samples.
    """
    margin = int(0.2 * fs)  # 200 ms = 100 samples at 500 Hz
    n_samples = len(ecg_signal)
    peaks = np.asarray(r_peaks, dtype=int)

    # (file-name tag, database folder, index offset of the peak closing the window)
    variants = (
        ("single", "Single_Beat", 0),
        ("two", "Two_Beats", 1),
        ("three", "Three_Beats", 2),
    )

    target_dirs = {}
    for tag, folder, _ in variants:
        target_dirs[tag] = os.path.join(output_dir, folder, person_id)
        os.makedirs(target_dirs[tag], exist_ok=True)

    # NOTE(review): the loop bound is kept as in the original, which stops one
    # peak earlier than the i + 2 lookahead strictly requires — confirm
    # whether the last window is skipped on purpose.
    for i in range(len(peaks) - 3):  # ensure enough future beats
        start = max(peaks[i] - margin, 0)
        for tag, _, offset in variants:
            end = min(peaks[i + offset] + margin, n_samples)
            out_path = os.path.join(target_dirs[tag], f"{record_name}_{tag}_{i}.png")
            _render_segment_image(ecg_signal[start:end], out_path)





# ------------------ Runner ------------------

# Input: one "Person_XX" folder per subject, each holding WFDB records.
# Output: root folder that receives the three image databases.
base_directory = "C:\\ecg_new\\ecg-id-database-1.0.0"
output_directory = "C:\\ecg_new\\ecg_grayscale_images"

for p in range(1, 91):
    person_id = f"Person_{p:02d}"
    person_path = os.path.join(base_directory, person_id)

    if os.path.exists(person_path):
        print(f"Processing {person_id}...")

        # Each record is identified by its .dat file; the extension is
        # stripped because the loader expects the bare record name.
        dat_files = [name for name in os.listdir(person_path) if name.endswith(".dat")]
        for dat_file in dat_files:
            record_base = dat_file[:-4]
            record_path = os.path.join(person_path, record_base)

            # A failing record is reported and skipped so the batch continues.
            try:
                ecg_signal, r_peaks = load_ecg_data(record_path)
                save_beat_images(ecg_signal, r_peaks, record_base, person_id, output_directory)
            except Exception as e:
                print(f"Error processing {record_path}: {e}")
    else:
        print(f"Skipping {person_id} (folder not found)")

print("✅ Image databases created: Single_Beat, Two_Beats, Three_Beats")



Skipping Person_01 (folder not found)
Skipping Person_02 (folder not found)
Skipping Person_03 (folder not found)
Skipping Person_04 (folder not found)
Skipping Person_05 (folder not found)
Skipping Person_06 (folder not found)
Skipping Person_07 (folder not found)
Skipping Person_08 (folder not found)
Skipping Person_09 (folder not found)
Skipping Person_10 (folder not found)
Skipping Person_11 (folder not found)
Skipping Person_12 (folder not found)
Skipping Person_13 (folder not found)
Skipping Person_14 (folder not found)
Skipping Person_15 (folder not found)
Skipping Person_16 (folder not found)
Skipping Person_17 (folder not found)
Skipping Person_18 (folder not found)
Skipping Person_19 (folder not found)
Skipping Person_20 (folder not found)
Skipping Person_21 (folder not found)
Skipping Person_22 (folder not found)
Skipping Person_23 (folder not found)
Skipping Person_24 (folder not found)
Skipping Person_25 (folder not found)
Skipping Person_26 (folder not found)
Skipping Per

In [13]:
# Mount Google Drive at /content/drive (only works inside a Google Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

# Database Handling and Model Training


Import required libraries

In [None]:
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt

Activation functions

In [None]:
def swish_activation(x):
    """Swish activation: the input scaled by its own sigmoid, ``x * sigmoid(x)``."""
    gate = tf.nn.sigmoid(x)
    return x * gate

# Register the activation in Keras' custom-object registry under the name
# 'swish_activation'.
tf.keras.utils.get_custom_objects().update({'swish_activation': swish_activation})

In [None]:
def custom_activation(x):
    """Base-10 logistic activation: ``1 / (1 + 10 ** (-x))``."""
    decay = K.pow(10.0, -x)
    denominator = 1.0 + decay
    return 1.0 / denominator

# Register the activation in Keras' custom-object registry under the name
# 'custom_activation'.
tf.keras.utils.get_custom_objects().update({'custom_activation': custom_activation})

Data Preparation

In [None]:
def make_balanced_pairs(data_dir, img_size=(112,112), pairs_per_person=200):
    """Build a shuffled, class-balanced set of image pairs for Siamese training.

    Every sub-folder of ``data_dir`` is treated as one person. Images are
    loaded as grayscale, scaled to [0, 1] and perturbed with a small amount of
    Gaussian noise. Positive pairs (label 1) combine two images of the same
    person, with at most ``pairs_per_person`` pairs per person so that no
    single person dominates. The same number of negative pairs (label 0) is
    then drawn from two different, randomly chosen persons.

    Args:
        data_dir: Folder containing one sub-folder of images per person.
        img_size: Target ``(height, width)`` passed to ``load_img``.
        pairs_per_person: Upper bound on positive pairs taken from one person.

    Returns:
        Tuple ``(pairs, labels)`` of NumPy arrays; ``pairs[k]`` holds the two
        images of pair ``k`` and ``labels[k]`` is 1 for same-person pairs and
        0 otherwise.

    Raises:
        ValueError: If fewer than two persons with at least two readable
            images each are found (no negative pairs can be formed).
    """
    from itertools import combinations

    images = {}

    # Load and preprocess images
    for entry in os.listdir(data_dir):
        person_dir = os.path.join(data_dir, entry)
        if not os.path.isdir(person_dir):
            continue

        imgs = []
        for file in os.listdir(person_dir):
            if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            img_path = os.path.join(person_dir, file)
            try:
                img = load_img(img_path, target_size=img_size, color_mode='grayscale')
                img = img_to_array(img) / 255.0
                # Add slight noise for data augmentation
                img = img + np.random.normal(0, 0.01, img.shape)
                img = np.clip(img, 0, 1)
                imgs.append(img)
            except Exception as e:
                # Best effort: an unreadable file is reported and skipped.
                print(f"Error loading {img_path}: {e}")
                continue

        if len(imgs) >= 2:  # Only include persons with at least 2 images
            images[entry] = imgs

    person_ids = list(images.keys())
    print(f"Found {len(person_ids)} persons with sufficient images")

    if len(person_ids) < 2:
        raise ValueError(
            f"Need at least 2 persons with 2 or more images each in {data_dir!r}, "
            f"found {len(person_ids)}"
        )

    pairs, labels = [], []

    # Create positive pairs, capped per person. When a person offers more
    # combinations than the cap, a random subset is drawn so that the pairs
    # are not biased towards the first images in the folder.
    positive_pairs = 0
    for person_id in person_ids:
        img_list = images[person_id]
        index_pairs = list(combinations(range(len(img_list)), 2))
        if len(index_pairs) > pairs_per_person:
            index_pairs = random.sample(index_pairs, pairs_per_person)
        for i, j in index_pairs:
            pairs.append([img_list[i], img_list[j]])
            labels.append(1)
        positive_pairs += len(index_pairs)

    print(f"Created {positive_pairs} positive pairs")

    # Create equal number of negative pairs
    negative_pairs = 0
    while negative_pairs < positive_pairs:
        p1, p2 = random.sample(person_ids, 2)
        pairs.append([random.choice(images[p1]), random.choice(images[p2])])
        labels.append(0)
        negative_pairs += 1

    print(f"Created {negative_pairs} negative pairs")

    # Shuffle the data
    combined = list(zip(pairs, labels))
    random.shuffle(combined)
    pairs, labels = zip(*combined)

    return np.array(pairs), np.array(labels)


Siamese Network

In [2]:
def build_improved_base_network(input_shape=(112,112,1)):
    """Build the shared embedding branch of the Siamese network.

    Four convolutional stages (Conv2D -> BatchNorm -> custom activation ->
    pooling -> Dropout) are followed by a 256-unit dense stage and a 128-unit
    embedding that is L2-normalised along axis 1.

    Args:
        input_shape: Shape of one input image.

    Returns:
        A Keras ``Model`` mapping an image to its normalised embedding.
    """
    def max_pool():
        return layers.MaxPooling2D((2,2))

    def global_pool():
        return layers.GlobalAveragePooling2D()

    # (filters, kernel size, stride, dropout rate, pooling-layer factory)
    stage_specs = [
        (64, (7,7), 2, 0.1, max_pool),
        (128, (5,5), 1, 0.1, max_pool),
        (256, (3,3), 1, 0.2, max_pool),
        (512, (3,3), 1, 0.3, global_pool),
    ]

    network_input = layers.Input(shape=input_shape)
    features = network_input

    # Convolutional stages; the pooling layer is instantiated inside the loop
    # so layers are created in the same order as they are applied.
    for filters, kernel, stride, drop_rate, make_pool in stage_specs:
        features = layers.Conv2D(filters, kernel, strides=stride, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.Activation(custom_activation)(features)
        features = make_pool()(features)
        features = layers.Dropout(drop_rate)(features)

    # Dense stage
    features = layers.Dense(256)(features)
    features = layers.BatchNormalization()(features)
    features = layers.Activation(custom_activation)(features)
    features = layers.Dropout(0.4)(features)

    # Final embedding, L2-normalised
    embedding = layers.Dense(128)(features)
    embedding = layers.Lambda(lambda t: K.l2_normalize(t, axis=1))(embedding)

    return models.Model(network_input, embedding)

Loss Function

In [None]:
def improved_contrastive_loss(y_true, y_pred, margin=1.0):
    """Contrastive loss on a predicted pair distance.

    Pairs labelled 1 are penalised by the squared distance; pairs labelled 0
    are penalised by the squared shortfall of the distance below ``margin``.

    Args:
        y_true: Pair labels (1 or 0).
        y_pred: Predicted distances.
        margin: Distance beyond which label-0 pairs incur no loss.

    Returns:
        Scalar tensor: mean per-pair loss plus a constant 1e-6 offset.
    """
    square_pred = K.square(y_pred)
    margin_square = K.square(K.maximum(margin - y_pred, 0))

    # Labels may arrive as integers; cast them so the arithmetic below does
    # not fail on mixed integer/float tensors when the loss is called directly.
    y_true = K.cast(y_true, square_pred.dtype)

    # The offset is kept for parity with previously reported loss values; as
    # a constant it shifts the loss but does not affect the gradients.
    epsilon = 1e-6
    loss = y_true * square_pred + (1 - y_true) * margin_square
    return K.mean(loss + epsilon)

def triplet_loss(y_true, y_pred, margin=0.2):
    """Triplet loss over stacked (anchor, positive, negative) embeddings.

    ``y_pred`` is indexed along its second axis: position 0 is the anchor,
    1 the positive and 2 the negative. ``y_true`` is not used.

    Returns:
        Mean of ``max(0, d(anchor, positive) - d(anchor, negative) + margin)``
        where ``d`` is the squared Euclidean distance.
    """
    anchor = y_pred[:, 0]
    positive = y_pred[:, 1]
    negative = y_pred[:, 2]

    dist_to_positive = K.sum(K.square(anchor - positive), axis=1)
    dist_to_negative = K.sum(K.square(anchor - negative), axis=1)

    hinge = K.maximum(0.0, dist_to_positive - dist_to_negative + margin)
    return K.mean(hinge)

Learning rate scheduler

In [None]:
def create_lr_scheduler():
    """Create a ``LearningRateScheduler`` callback with a staged decay.

    The rate is left unchanged for epochs 0-4, multiplied by 0.95 per epoch
    for epochs 5-14 and by 0.9 per epoch from epoch 15 onwards.
    """
    def scheduler(epoch, lr):
        # Guard clauses, checked from the latest stage backwards.
        if epoch >= 15:
            return lr * 0.9
        if epoch >= 5:
            return lr * 0.95
        return lr

    return tf.keras.callbacks.LearningRateScheduler(scheduler)

Training Function

In [None]:
def _best_f1_threshold(y_true, distances, candidates, default=0.5):
    """Return the candidate distance threshold with the highest F1 score."""
    best_f1 = 0
    best_threshold = default
    for threshold in candidates:
        predicted = (distances < threshold).astype("int32")
        score = f1_score(y_true, predicted)
        if score > best_f1:
            best_f1 = score
            best_threshold = threshold
    return best_threshold


def _plot_training_results(save_name, history, y_test, y_pred_dist, best_threshold, cm, metric_values):
    """Draw the loss, distance, confusion-matrix, metric and LR panels and save them as a PNG."""
    plt.figure(figsize=(15,10))

    # Loss curve
    plt.subplot(2,3,1)
    plt.plot(history.history['loss'], label='Train Loss', linewidth=2)
    plt.plot(history.history['val_loss'], label='Val Loss', linewidth=2)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(f'{save_name} Loss Curve')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Distance distribution
    plt.subplot(2,3,2)
    pos_distances = y_pred_dist[y_test == 1]
    neg_distances = y_pred_dist[y_test == 0]
    plt.hist(pos_distances, bins=30, alpha=0.7, label='Same Person', color='green')
    plt.hist(neg_distances, bins=30, alpha=0.7, label='Different Person', color='red')
    plt.axvline(best_threshold, color='black', linestyle='--', label=f'Threshold: {best_threshold:.2f}')
    plt.xlabel('Distance')
    plt.ylabel('Frequency')
    plt.title('Distance Distribution')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Confusion Matrix
    plt.subplot(2,3,3)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, ['Different', 'Same'])
    plt.yticks(tick_marks, ['Different', 'Same'])

    # Add text annotations; white text on dark cells for readability
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                    horizontalalignment="center",
                    color="white" if cm[i, j] > cm.max() / 2. else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')

    # Metrics bar chart
    plt.subplot(2,3,4)
    metrics = ['Accuracy', 'Sensitivity', 'Specificity', 'Precision', 'F1 Score']
    values = [value * 100 for value in metric_values]
    colors = ['blue', 'green', 'orange', 'red', 'purple']

    bars = plt.bar(metrics, values, color=colors, alpha=0.7)
    plt.ylabel('Percentage')
    plt.title('Performance Metrics')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for bar, value in zip(bars, values):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{value:.1f}%', ha='center', va='bottom')

    # Learning rate curve. The history key is 'lr' on older Keras releases
    # and 'learning_rate' on newer ones, so accept either.
    plt.subplot(2,3,5)
    lr_key = next((key for key in ('lr', 'learning_rate') if key in history.history), None)
    if lr_key is not None:
        plt.plot(history.history[lr_key], linewidth=2, color='purple')
        plt.xlabel('Epochs')
        plt.ylabel('Learning Rate')
        plt.title('Learning Rate Schedule')
        plt.grid(True, alpha=0.3)
        plt.yscale('log')

    plt.tight_layout()
    plt.savefig(f'{save_name}_improved_results.png', dpi=300, bbox_inches='tight')
    plt.show()


def train_and_evaluate_improved(dataset_path, save_name):
    """Train a Siamese network on one image database and report test metrics.

    Pairs are split into train / validation / test subsets (64 % / 16 % /
    20 %, stratified by label). The validation subset drives early stopping,
    learning-rate reduction and the choice of the distance threshold; the
    test subset is used only for the reported metrics, so those are not
    biased by the threshold search.

    Side effects: saves the model to ``{save_name}_improved.h5``, saves a
    result figure to ``{save_name}_improved_results.png`` and shows it.

    Args:
        dataset_path: Folder with one sub-folder of images per person.
        save_name: Prefix for the saved model and figure files.

    Returns:
        Tuple ``(siamese_model, history, best_threshold)``.
    """
    from sklearn.model_selection import train_test_split

    print(f"\nTraining improved model on dataset: {dataset_path}")

    # Create balanced pairs
    pairs, labels = make_balanced_pairs(dataset_path, pairs_per_person=200)

    print(f"Total pairs: {len(pairs)}")
    print(f"Positive pairs: {np.sum(labels)}")
    print(f"Negative pairs: {len(labels) - np.sum(labels)}")

    X1, X2, y = pairs[:,0], pairs[:,1], labels

    # Stratified splits to maintain class balance: hold out the test set
    # first, then carve a validation set out of the remainder.
    X1_fit, X1_test, X2_fit, X2_test, y_fit, y_test = train_test_split(
        X1, X2, y, test_size=0.2, random_state=42, stratify=y
    )
    X1_train, X1_val, X2_train, X2_val, y_train, y_val = train_test_split(
        X1_fit, X2_fit, y_fit, test_size=0.2, random_state=42, stratify=y_fit
    )

    print(f"Training set - Positive: {np.sum(y_train)}, Negative: {len(y_train) - np.sum(y_train)}")
    print(f"Validation set - Positive: {np.sum(y_val)}, Negative: {len(y_val) - np.sum(y_val)}")
    print(f"Test set - Positive: {np.sum(y_test)}, Negative: {len(y_test) - np.sum(y_test)}")

    # Build improved model: both inputs share one embedding network
    base_network = build_improved_base_network()

    input_a = layers.Input(shape=(112,112,1))
    input_b = layers.Input(shape=(112,112,1))
    feat_a = base_network(input_a)
    feat_b = base_network(input_b)

    # Euclidean distance. The squared sum is clamped to a small positive
    # value before the square root because the gradient of sqrt is unbounded
    # at 0, which yields NaNs when two embeddings coincide.
    distance = layers.Lambda(
        lambda tensors: K.sqrt(K.maximum(
            K.sum(K.square(tensors[0] - tensors[1]), axis=1, keepdims=True),
            K.epsilon()))
    )([feat_a, feat_b])

    siamese_model = models.Model([input_a, input_b], distance)

    # Compile with improved settings
    optimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999)
    siamese_model.compile(loss=improved_contrastive_loss, optimizer=optimizer)

    # Callbacks
    callbacks = [
        create_lr_scheduler(),
        tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True, monitor='val_loss'),
        tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-7)
    ]

    # Calculate class weights to handle any remaining imbalance
    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weight_dict = {0: class_weights[0], 1: class_weights[1]}

    print("Training model...")
    history = siamese_model.fit(
        [X1_train, X2_train], y_train,
        batch_size=16,  # Smaller batch size for better gradient updates
        epochs=50,
        validation_data=([X1_val, X2_val], y_val),
        callbacks=callbacks,
        class_weight=class_weight_dict,
        verbose=1
    )

    # Save model
    siamese_model.save(f"{save_name}_improved.h5")

    # Find the optimal threshold on validation data only
    val_dist = siamese_model.predict([X1_val, X2_val]).ravel()
    best_threshold = _best_f1_threshold(y_val, val_dist, np.arange(0.1, 2.0, 0.1))
    print(f"Optimal threshold: {best_threshold:.2f}")

    # Final predictions on the untouched test set: a pair is "same person"
    # when its distance falls below the threshold.
    y_pred_dist = siamese_model.predict([X1_test, X2_test]).ravel()
    y_pred_class = (y_pred_dist < best_threshold).astype("int32")

    # Metrics. Passing labels explicitly guarantees a 2x2 matrix even when
    # only one class is predicted.
    acc = accuracy_score(y_test, y_pred_class)
    cm = confusion_matrix(y_test, y_pred_class, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp+fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn+fp) > 0 else 0
    precision = tp / (tp + fp) if (tp+fp) > 0 else 0
    f1 = f1_score(y_test, y_pred_class)

    print(f"\n=== RESULTS ===")
    print(f"Accuracy: {acc*100:.2f}%")
    print(f"Sensitivity (Recall): {sensitivity*100:.2f}%")
    print(f"Specificity: {specificity*100:.2f}%")
    print(f"Precision: {precision*100:.2f}%")
    print(f"F1 Score: {f1*100:.2f}%")
    print(f"Confusion Matrix:\n{cm}")

    _plot_training_results(
        save_name, history, y_test, y_pred_dist, best_threshold, cm,
        [acc, sensitivity, specificity, precision, f1],
    )

    return siamese_model, history, best_threshold

Main function

In [None]:
if __name__ == "__main__":
    # Seed every random number generator in use for reproducibility
    random.seed(42)
    np.random.seed(42)
    tf.random.set_seed(42)

    base_path = "C:\\ecg_new\\ecg_grayscale_images"  # <-- change to your dataset base folder

    # Train one model per image database; a missing database is reported and skipped.
    for folder in ("Single_Beat", "Two_Beats", "Three_Beats"):
        dataset_path = os.path.join(base_path, folder)
        if not os.path.exists(dataset_path):
            print(f"Dataset path not found: {dataset_path}")
            continue

        model, history, threshold = train_and_evaluate_improved(dataset_path, f"siamese_{folder}")
        print(f"\nCompleted training for {folder}")
        print("="*50)