In [None]:
# Mount Google Drive so the dataset archive and the extraction folder, both
# under /content/drive/MyDrive, are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install wfdb
!pip install scikit-learn
!pip install vmdpy
!pip install pywavelets
!pip install pyunpack patool

In [None]:
import os
import numpy as np
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.layers import Input, Dense, Dropout, Conv1D, MaxPooling1D, LSTM, Bidirectional, Reshape, Concatenate, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from scipy.signal import find_peaks, butter, filtfilt, welch
from scipy.stats import entropy
import pywt
from tensorflow.keras import backend as K
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from tensorflow.keras.models import load_model
from scipy.signal import lfilter
from tensorflow.keras.layers import Layer
from tensorflow.keras.saving import register_keras_serializable
from pyunpack import Archive
from tensorflow.keras.utils import plot_model
from sklearn.utils import class_weight

# Code

In [None]:
base_path = "/content/drive/MyDrive/Colab Notebooks/An open-access arrhythmia database of wearable dynamic electrocardiogram.rar"
extract_path = "/content/drive/MyDrive/Colab Notebooks/Extracted/"

# Unpack the .rar archive only once. A non-empty extraction folder is taken to
# mean a previous run already extracted it, so the slow step is skipped.
if os.path.exists(extract_path) and os.listdir(extract_path):
    print(f"Data already exists in the directory: {extract_path}")
else:
    os.makedirs(extract_path, exist_ok=True)
    Archive(base_path).extractall(extract_path)
    print(f"Data extracted to {extract_path}")



In [None]:
# Function to load the database
def load_ecg_database(base_path, label_depth=4):
    """Load every ``.mat`` recording under ``base_path``, grouped by class folder.

    The class of a file is read from the path component ``label_depth``
    positions from the end of its full path. The default of 4 means the folder
    three levels above the file, i.e. ``<class>/<d1>/<d2>/<file>.mat``. Only
    the classes "A", "N" and "V" are kept; other files are ignored.

    Directories and files are visited in sorted order. ``os.walk`` order
    otherwise depends on the filesystem, which would make the sample order,
    and therefore the seeded train/test split downstream, non-reproducible.

    Args:
        base_path: Root directory to walk (here: the extracted archive).
        label_depth: Which path component, counted from the end, names the class.

    Returns:
        dict mapping "A"/"N"/"V" to a list of ``(file_name, mat_dict)`` tuples,
        where ``mat_dict`` is the result of ``scipy.io.loadmat``. Files that
        fail to load are reported and skipped.
    """
    database = {"A": [], "N": [], "V": []}
    for root, dirs, files in os.walk(base_path):
        dirs.sort()  # in-place sort makes os.walk descend in a deterministic order
        for file in sorted(files):
            if not file.endswith('.mat'):  # Process only `.mat` files
                continue
            file_path = os.path.join(root, file)
            # Split on the OS separator rather than a hard-coded '/'.
            parts = os.path.normpath(file_path).split(os.sep)
            if len(parts) < label_depth:
                continue
            main_folder = parts[-label_depth]  # e.g. "A", "N", "V"
            if main_folder not in database:
                continue
            try:
                mat_data = loadmat(file_path)
                database[main_folder].append((file, mat_data))
            except Exception as e:
                print(f"Error loading file {file_path}: {e}")
    return database

# Low-pass filter function
def butter_lowpass_filter(data, cutoff, fs, order=4):
    """Zero-phase Butterworth low-pass filter.

    Args:
        data: Signal to filter.
        cutoff: Cut-off frequency in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth filter order.

    Returns:
        The signal filtered forward and backward with ``filtfilt``, so no phase shift.
    """
    normalized_cutoff = cutoff / (fs / 2)  # butter() expects a fraction of Nyquist
    numerator, denominator = butter(order, normalized_cutoff, btype='low')
    return filtfilt(numerator, denominator, data)

# High-pass filter function
def butter_highpass_filter(data, cutoff, fs, order=4):
    """Zero-phase Butterworth high-pass filter (used to remove baseline wander).

    Args:
        data: Signal to filter.
        cutoff: Cut-off frequency in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth filter order.

    Returns:
        The signal filtered forward and backward with ``filtfilt``.
    """
    normalized_cutoff = cutoff / (fs / 2)  # fraction of the Nyquist frequency
    numerator, denominator = butter(order, normalized_cutoff, btype='high')
    return filtfilt(numerator, denominator, data)

# Function for wavelet-based filtering
def wavelet_filter(signal, wavelet='db6', level=5):
    """Denoise a signal by soft-thresholding its wavelet detail coefficients.

    The signal is decomposed with ``pywt.wavedec``. The approximation band
    (index 0) is kept unchanged. Every detail band is soft-thresholded at half
    of that band's standard deviation, which shrinks the coefficients rather
    than zeroing them. The reconstruction is cut back to the input length,
    since ``waverec`` may return a slightly longer array.
    """
    approximation, *details = pywt.wavedec(signal, wavelet=wavelet, level=level)
    shrunk_details = [
        pywt.threshold(band, value=0.5 * np.std(band), mode='soft')
        for band in details
    ]
    reconstructed = pywt.waverec([approximation] + shrunk_details, wavelet)
    return reconstructed[:len(signal)]

# Function for segmentation by time
def segment_signal_by_time(signal, sampling_rate, segment_duration=5):
    """Split a signal into consecutive, non-overlapping fixed-length windows.

    Trailing samples that do not fill a whole window are dropped.

    Args:
        signal: Array whose first axis is time (a 1-D ECG trace here).
        sampling_rate: Samples per second.
        segment_duration: Window length in seconds; may be fractional.

    Returns:
        Array of shape ``(num_segments, segment_length, *signal.shape[1:])``.
        If the signal is shorter than one window, the result is empty but still
        has the window axis, e.g. ``(0, segment_length)``.

    Raises:
        ValueError: if the window would be shorter than one sample.
    """
    signal = np.asarray(signal)
    # Slice bounds must be ints; round() guards against float products such as
    # 0.29 * 100 == 28.999999999999996.
    segment_length = int(round(segment_duration * sampling_rate))
    if segment_length <= 0:
        raise ValueError("segment_duration * sampling_rate must be at least one sample")
    num_segments = len(signal) // segment_length
    usable = signal[:num_segments * segment_length]
    return usable.reshape((num_segments, segment_length) + signal.shape[1:])

# Function for feature extraction
def extract_features(segment, fs):
    """Compute 18 hand-crafted features for one ECG window.

    Peaks are detected with ``find_peaks(segment, height=0.5)``, a fixed
    absolute amplitude threshold with no minimum spacing. The gaps between
    consecutive peaks, in seconds, are treated as the RR-interval series.

    The output order must not change: the StandardScaler and the model's
    feature input are fit on lists produced by this function.

    Args:
        segment: 1-D ECG window (filtered by the caller in this notebook).
        fs: Sampling rate of ``segment`` in Hz.

    Returns:
        list of 18 numbers, in this order:
        rMSSD, PRR50, PRR20, SDSD, SDRR, low_freq, mid_freq, high_freq,
        sd1, sd2, min_hr, max_hr, avg_hr, avg_peak_amplitude,
        pq_delay, qr_delay, rs_delay, st_delay.
        Features that need more peaks than were found are 0.
    """
    peaks, properties = find_peaks(segment, height=0.5)
    rr_intervals = np.diff(peaks) / fs  # seconds between consecutive detected peaks

    # Time-domain features
    rMSSD = np.sqrt(np.mean(np.square(np.diff(rr_intervals)))) if len(rr_intervals) > 1 else 0
    # NOTE(review): the numerator counts over len(rr)-1 successive differences but
    # the denominator is len(rr); the usual pNN50/pNN20 divides by the number of
    # differences. Left as-is because the trained scaler/model were fit on it.
    PRR50 = np.sum(np.abs(np.diff(rr_intervals)) > 0.05) / len(rr_intervals) if len(rr_intervals) > 1 else 0
    PRR20 = np.sum(np.abs(np.diff(rr_intervals)) > 0.02) / len(rr_intervals) if len(rr_intervals) > 1 else 0
    SDSD = np.std(np.diff(rr_intervals)) if len(rr_intervals) > 1 else 0
    SDRR = np.std(rr_intervals) if len(rr_intervals) > 0 else 0

    # Frequency-domain features: summed Welch PSD over three bands
    # NOTE(review): with SciPy's default welch segment length (nperseg=256) the
    # bin spacing at fs=400 Hz is ~1.56 Hz, so the 0.1-0.5 Hz band presumably
    # contains no bins and low_freq is always 0. The signal is also high-passed
    # at 0.5 Hz before this point. Verify before relying on this feature.
    freqs, power = welch(segment, fs=fs)
    low_freq = np.sum(power[(freqs >= 0.1) & (freqs < 0.5)])
    mid_freq = np.sum(power[(freqs >= 0.5) & (freqs < 15)])
    high_freq = np.sum(power[(freqs >= 15) & (freqs < 40)])

    # Poincaré metrics
    # sd1 equals SDSD / sqrt(2), so it duplicates SDSD up to a constant factor.
    # NOTE(review): the standard SD2 is sqrt(2*SDRR**2 - sd1**2); std(rr)/sqrt(2)
    # is a different quantity (SDRR scaled). Kept for model compatibility.
    if len(rr_intervals) > 1:
        sd1 = np.std(np.diff(rr_intervals)) / np.sqrt(2)
        sd2 = np.std(rr_intervals) / np.sqrt(2)
    else:
        sd1, sd2 = 0, 0

    # Additional features: Min/Max/Average heart rate in beats per minute
    # (peaks are strictly increasing, so every RR interval is > 0).
    if len(rr_intervals) > 0:
        heart_rates = 60 / rr_intervals
        min_hr = np.min(heart_rates)
        max_hr = np.max(heart_rates)
        avg_hr = np.mean(heart_rates)
    else:
        min_hr, max_hr, avg_hr = 0, 0, 0

    # Additional features: mean height of all detected peaks. These are every
    # local maximum above 0.5, not separately identified P/Q/R/S/T waves.
    peak_amplitudes = properties["peak_heights"] if "peak_heights" in properties else []
    avg_peak_amplitude = np.mean(peak_amplitudes) if len(peak_amplitudes) > 0 else 0

    # Additional features: gaps (in seconds) between the first five detected
    # peaks. Despite the PQ/QR/RS/ST names these are not wave-delineation
    # intervals, and np.mean of a single difference is just that difference.
    if len(peaks) >= 5:
        pq_delay = np.mean(peaks[1] - peaks[0]) / fs
        qr_delay = np.mean(peaks[2] - peaks[1]) / fs
        rs_delay = np.mean(peaks[3] - peaks[2]) / fs
        st_delay = np.mean(peaks[4] - peaks[3]) / fs
    else:
        pq_delay, qr_delay, rs_delay, st_delay = 0, 0, 0, 0

    # Combine features (order is part of the model contract)
    features = [
        rMSSD, PRR50, PRR20, SDSD, SDRR,
        low_freq, mid_freq, high_freq,
        sd1, sd2, min_hr, max_hr, avg_hr,
        avg_peak_amplitude, pq_delay, qr_delay, rs_delay, st_delay
    ]

    return features

# Updated preprocess_data Function
def preprocess_data(all_data, sampling_rate=400, segment_duration=5):
    X_segments, X_features, y = [], [], []
    label_mapping = {"A": 1, "N": 0, "V": 2}

    for folder, files in all_data.items():
        if folder not in label_mapping:
            print(f"Ignoring unknown folder: {folder}")
            continue

        label = label_mapping[folder]
        print(f"Processing folder: {folder}, number of files: {len(files)}")
        for mat_file_name, mat_data in files:
            print(f"  Processing file: {mat_file_name}")
            if 'ECG' in mat_data:
                try:
                    ecg_signal = mat_data['ECG']

                    # Debugging content
                    print(f"    Data type: {type(ecg_signal)}, Size: {np.shape(ecg_signal)}")

                    # Unpacking and validating
                    if isinstance(ecg_signal, np.ndarray):
                        # Handle arrays of size (5, 1)
                        if ecg_signal.ndim == 2 and ecg_signal.shape[1] == 1:
                            ecg_signal = ecg_signal.flatten()

                        # Handle nested structures
                        if ecg_signal.dtype == 'O':
                            ecg_signal = ecg_signal[0].flatten()

                    # Validation: ensure data is numeric
                    if not isinstance(ecg_signal, np.ndarray) or not np.issubdtype(ecg_signal.dtype, np.number):
                        print(f"    Invalid data in file: {mat_file_name}")
                        continue

                    # Filtering
                    # 1. Apply wavelet denoising
                    ecg_signal = wavelet_filter(ecg_signal, wavelet='db6', level=5)

                    # 2. Apply high-pass filter to remove baseline wander
                    ecg_signal = butter_highpass_filter(ecg_signal, cutoff=0.5, fs=sampling_rate, order=4)

                    # 3. Apply low-pass filter to remove high-frequency noise
                    ecg_signal = butter_lowpass_filter(ecg_signal, cutoff=50, fs=sampling_rate, order=4)

                    # Augmentation: Only for class 0 ("N")
                    augmented_signals = [ecg_signal]
                    if label == 0:
                        augmented_signals.append(-ecg_signal)

                    for aug_signal in augmented_signals:
                        segments = segment_signal_by_time(aug_signal, sampling_rate, segment_duration)

                        for segment in segments:
                            features = extract_features(segment, fs=sampling_rate)
                            X_segments.append(segment)
                            X_features.append(features)
                            y.append(label)

                except Exception as e:
                    print(f"    Error processing file {mat_file_name}: {e}")
                    continue

    return np.array(X_segments), np.array(X_features), np.array(y)


In [None]:
# --- Importing Libraries ---
from tensorflow.keras.layers import (
    Conv1D, MaxPooling1D, Dense, Dropout, Input, Flatten, Concatenate, Reshape, BatchNormalization,
    Bidirectional, LSTM, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.regularizers import l2
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import tensorflow.keras.backend as K
import numpy as np

# --- Attention Layer Class ---
class Attention(Layer):
    """Additive (tanh) attention pooling over the time axis.

    Each time step gets a score ``tanh(x . W + b)``. The scores are
    softmax-normalised over axis 1, and the layer returns the score-weighted
    sum of ``x`` over that axis. In this notebook the input is the
    Bidirectional LSTM output (``return_sequences=True``), assumed to be
    ``(batch, steps, features)``, so the output is ``(batch, features)``.

    ``b`` has one entry per time step (``input_shape[1]``), which ties the
    layer to the sequence length it was built with.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        num_steps = input_shape[1]
        # One scoring weight per feature -> one scalar score per step.
        self.W = self.add_weight(name="att_weight", shape=(feature_dim, 1),
                                 initializer="normal")
        # Position-dependent bias, one per step.
        self.b = self.add_weight(name="att_bias", shape=(num_steps, 1),
                                 initializer="zeros")
        super().build(input_shape)

    def call(self, x):
        scores = K.tanh(K.dot(x, self.W) + self.b)
        weights = K.softmax(scores, axis=1)  # normalise over the time axis
        return K.sum(x * weights, axis=1)


# --- Loading and Processing Data ---
all_data = load_ecg_database(extract_path)
X_segments, X_features, y = preprocess_data(all_data)

# NOTE(review): the split below is per 5-second window, not per recording, and
# preprocess_data adds sign-flipped copies of "N" windows before this point.
# Windows (and mirrored copies) of the same recording can therefore land in
# both train and test, which probably inflates the reported test metrics.
# A grouped split (e.g. sklearn GroupShuffleSplit on recording ids) would need
# preprocess_data to also return the source file of each window.

# Seed Python, NumPy and TensorFlow so weight initialisation, dropout and
# batch shuffling are repeatable between runs.
from tensorflow.keras.utils import set_random_seed
set_random_seed(42)

# --- Data Splitting ---
# 70 / 15 / 15 split. Stratifying keeps each class's share in every split, so
# a rare class cannot disappear from val/test by chance. The ROC curves and
# classification report below expect every class in the test set.
X_segments_train, X_segments_temp, X_features_train, X_features_temp, y_train, y_temp = train_test_split(
    X_segments, X_features, y, test_size=0.3, random_state=42, stratify=y
)
X_segments_val, X_segments_test, X_features_val, X_features_test, y_val, y_test = train_test_split(
    X_segments_temp, X_features_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

# Feature Scaling: fit on the training split only, so val/test statistics do not leak in
scaler = StandardScaler()
X_features_train = scaler.fit_transform(X_features_train)
X_features_val = scaler.transform(X_features_val)
X_features_test = scaler.transform(X_features_test)

# Add the explicit channel axis expected by the Conv1D input: (n, length) -> (n, length, 1)
X_segments_train = X_segments_train[..., np.newaxis]
X_segments_val = X_segments_val[..., np.newaxis]
X_segments_test = X_segments_test[..., np.newaxis]

# Model Architecture
# Branch 1: raw ECG window -> 3 x (Conv1D + BatchNorm + MaxPool) -> BiLSTM -> attention pooling
input_shape = (X_segments_train.shape[1], 1)

input_layer = Input(shape=input_shape)
conv1 = Conv1D(32, 5, activation="relu", kernel_regularizer=l2(0.01))(input_layer)
conv1 = BatchNormalization()(conv1)
pool1 = MaxPooling1D(2)(conv1)

conv2 = Conv1D(64, 5, activation="relu", kernel_regularizer=l2(0.01))(pool1)
conv2 = BatchNormalization()(conv2)
pool2 = MaxPooling1D(2)(conv2)

conv3 = Conv1D(128, 3, activation="relu", kernel_regularizer=l2(0.01))(pool2)
conv3 = BatchNormalization()(conv3)
pool3 = MaxPooling1D(2)(conv3)

# pool3 already has 128 channels, so this Reshape is effectively an identity.
reshape = Reshape((-1, 128))(pool3)
bi_lstm = Bidirectional(LSTM(64, return_sequences=True, dropout=0.4))(reshape)
attention = Attention()(bi_lstm)

# Branch 2: scaled hand-crafted features -> Dense
input_features = Input(shape=(X_features_train.shape[1],))
dense_features = Dense(64, activation="relu", kernel_regularizer=l2(0.01))(input_features)
dense_features = BatchNormalization()(dense_features)
dense_features = Dropout(0.4)(dense_features)

# Fusion head
merged = Concatenate()([attention, dense_features])
dense1 = Dense(128, activation="relu", kernel_regularizer=l2(0.01))(merged)
dense1 = BatchNormalization()(dense1)
dropout1 = Dropout(0.5)(dense1)

output_layer = Dense(len(np.unique(y)), activation="softmax")(dropout1)

# Compiling the Model (integer labels -> sparse categorical cross-entropy)
model = Model(inputs=[input_layer, input_features], outputs=output_layer)
model.compile(optimizer=Adam(learning_rate=0.0003), loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Callbacks: halve the LR on plateaus, stop early and keep the best weights
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, min_delta=0.001)

# Training
history = model.fit(
    [X_segments_train, X_features_train],
    y_train,
    validation_data=([X_segments_val, X_features_val], y_val),
    epochs=50,
    batch_size=32,
    callbacks=[lr_scheduler, early_stopping]
)

# Evaluation
test_loss, test_acc = model.evaluate([X_segments_test, X_features_test], y_test)
print(f"Test Accuracy: {test_acc:.4f}")


# Results

In [None]:
from sklearn.metrics import roc_curve, auc, RocCurveDisplay
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np

from sklearn.metrics import (
    ConfusionMatrixDisplay,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
)

# Display name for each integer label; position i must match the ids assigned
# in preprocess_data (N=0, A=1, V=2).
class_labels = ["N", "A", "V"]
class_ids = list(range(len(class_labels)))

# Run inference once: probabilities feed the ROC curves, argmax gives hard labels.
y_pred_prob = model.predict([X_segments_test, X_features_test])
y_pred = y_pred_prob.argmax(axis=1)

# --- ROC curves (one-vs-rest) ---
# Binarise against the fixed id list so column i is always class_labels[i],
# whatever subset of classes happens to be present in y_train / y_test.
y_test_bin = label_binarize(y_test, classes=class_ids)

fig, ax = plt.subplots(figsize=(10, 6))
for i, class_name in enumerate(class_labels):
    if not y_test_bin[:, i].any():
        # A ROC curve is undefined without positive samples; skip instead of plotting NaN.
        print(f"Skipping ROC for class {class_name}: no samples in the test set")
        continue
    fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_pred_prob[:, i])
    ax.plot(fpr, tpr, label=f"Class {class_name} (AUC = {auc(fpr, tpr):.2f})")
ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
ax.set(title="ROC Curve", xlabel="False Positive Rate", ylabel="True Positive Rate")
ax.legend(loc="lower right")
ax.grid()
plt.show()

# --- Confusion matrix ---
cm = confusion_matrix(y_test, y_pred, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()

# --- Learning curves ---
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 6))

ax_loss.plot(history.history['loss'], label='Train Loss')
ax_loss.plot(history.history['val_loss'], label='Validation Loss')
ax_loss.set(title="Model Loss", xlabel="Epoch", ylabel="Loss")
ax_loss.legend()
ax_loss.grid()

ax_acc.plot(history.history['accuracy'], label='Train Accuracy')
ax_acc.plot(history.history['val_accuracy'], label='Validation Accuracy')
ax_acc.set(title="Model Accuracy", xlabel="Epoch", ylabel="Accuracy")
ax_acc.legend()
ax_acc.grid()

fig.tight_layout()
plt.show()

# --- Classification report ---
# Explicit labels keep target_names aligned even if a class is never predicted.
report = classification_report(y_test, y_pred, labels=class_ids,
                               target_names=class_labels, zero_division=0)
print("Classification Report:\n")
print(report)

# Per-class precision, recall and F1-score, computed in a single call
precision, recall, f1, _ = precision_recall_fscore_support(
    y_test, y_pred, labels=class_ids, zero_division=0
)
for class_name, p, r, f in zip(class_labels, precision, recall, f1):
    print(f"Class {class_name}: Precision = {p:.2f}, Recall = {r:.2f}, F1-Score = {f:.2f}")


# Real-time inference

In [None]:
# Path to save the scaler
scaler_path = "saved_model/scaler.pkl"

# Create the directory if it doesn't exist
os.makedirs(os.path.dirname(scaler_path), exist_ok=True)

# Save the scaler
joblib.dump(scaler, scaler_path)
print(f"Scaler saved at {scaler_path}")


In [None]:
# Path to save the model
save_path = "saved_model/arrhythmia_model.keras"

# Save the model
model.save(save_path)
print(f"Model saved at {save_path}")

In [None]:
import numpy as np
import pywt
from scipy.signal import butter, filtfilt, find_peaks, welch
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import load_model
import joblib
from scipy.io import loadmat
import matplotlib.pyplot as plt


# Filtering
def butter_lowpass_filter(data, cutoff, fs, order=4):
    """Zero-phase Butterworth low-pass filter (inference-cell copy).

    This redefinition replaces the training-cell one in the shared kernel, so
    it must stay numerically identical to it.
    """
    b, a = butter(order, cutoff / (fs / 2), btype='low')  # cutoff as a fraction of Nyquist
    return filtfilt(b, a, data)

def butter_highpass_filter(data, cutoff, fs, order=4):
    """Zero-phase Butterworth high-pass filter (inference-cell copy).

    This redefinition replaces the training-cell one in the shared kernel, so
    it must stay numerically identical to it.
    """
    b, a = butter(order, cutoff / (fs / 2), btype='high')  # cutoff as a fraction of Nyquist
    return filtfilt(b, a, data)

def wavelet_filter(signal, wavelet='db6', level=5):
    """Soft-threshold wavelet denoising (inference-cell copy).

    The approximation band is kept unchanged. Each detail band is
    soft-thresholded at 0.5 x its own standard deviation (shrunk, not zeroed).
    The reconstruction is trimmed to ``len(signal)``.
    """
    bands = pywt.wavedec(signal, wavelet=wavelet, level=level)
    denoised = bands[:1] + [
        pywt.threshold(detail, value=0.5 * np.std(detail), mode='soft')
        for detail in bands[1:]
    ]
    return pywt.waverec(denoised, wavelet)[:len(signal)]


# Feature extraction
def extract_features(segment, fs):
    """Compute the 18 hand-crafted features for one ECG window (inference copy).

    This redefinition replaces the training-cell one in the shared kernel, and
    its output is fed to the scaler and model fit during training. It must
    therefore stay value-for-value identical to the training version.

    Peaks come from ``find_peaks(segment, height=0.5)`` (fixed absolute
    threshold); their spacing in seconds is used as the RR-interval series.

    Args:
        segment: 1-D ECG window.
        fs: Sampling rate of ``segment`` in Hz.

    Returns:
        list of 18 numbers in the fixed order rMSSD, PRR50, PRR20, SDSD, SDRR,
        low_freq, mid_freq, high_freq, sd1, sd2, min_hr, max_hr, avg_hr,
        avg_peak_amplitude, pq_delay, qr_delay, rs_delay, st_delay; 0 where
        too few peaks were found.
    """
    peaks, properties = find_peaks(segment, height=0.5)
    rr_intervals = np.diff(peaks) / fs  # seconds between consecutive detected peaks

    # Time-domain features
    # NOTE(review): PRR50/PRR20 divide by len(rr) although there are len(rr)-1
    # successive differences; kept identical to the training definition.
    rMSSD = np.sqrt(np.mean(np.square(np.diff(rr_intervals)))) if len(rr_intervals) > 1 else 0
    PRR50 = np.sum(np.abs(np.diff(rr_intervals)) > 0.05) / len(rr_intervals) if len(rr_intervals) > 1 else 0
    PRR20 = np.sum(np.abs(np.diff(rr_intervals)) > 0.02) / len(rr_intervals) if len(rr_intervals) > 1 else 0
    SDSD = np.std(np.diff(rr_intervals)) if len(rr_intervals) > 1 else 0
    SDRR = np.std(rr_intervals) if len(rr_intervals) > 0 else 0

    # Frequency-domain features: summed Welch PSD over three bands
    # NOTE(review): with SciPy's default nperseg the 0.1-0.5 Hz band presumably
    # has no bins at fs=400 Hz, making low_freq always 0 — verify.
    freqs, power = welch(segment, fs=fs)
    low_freq = np.sum(power[(freqs >= 0.1) & (freqs < 0.5)])
    mid_freq = np.sum(power[(freqs >= 0.5) & (freqs < 15)])
    high_freq = np.sum(power[(freqs >= 15) & (freqs < 40)])

    # Poincaré metrics
    # NOTE(review): sd2 here is std(rr)/sqrt(2), not the standard
    # sqrt(2*SDRR**2 - sd1**2); kept for compatibility with the trained model.
    if len(rr_intervals) > 1:
        sd1 = np.std(np.diff(rr_intervals)) / np.sqrt(2)
        sd2 = np.std(rr_intervals) / np.sqrt(2)
    else:
        sd1, sd2 = 0, 0

    # Additional features: Min/Max/Average heart rate in beats per minute
    if len(rr_intervals) > 0:
        heart_rates = 60 / rr_intervals
        min_hr = np.min(heart_rates)
        max_hr = np.max(heart_rates)
        avg_hr = np.mean(heart_rates)
    else:
        min_hr, max_hr, avg_hr = 0, 0, 0

    # Additional features: mean height of all detected peaks (every local
    # maximum above 0.5, not individually identified P/Q/R/S/T waves)
    peak_amplitudes = properties["peak_heights"] if "peak_heights" in properties else []
    avg_peak_amplitude = np.mean(peak_amplitudes) if len(peak_amplitudes) > 0 else 0

    # Additional features: gaps (seconds) between the first five detected peaks;
    # the PQ/QR/RS/ST names do not correspond to wave delineation.
    if len(peaks) >= 5:
        pq_delay = np.mean(peaks[1] - peaks[0]) / fs
        qr_delay = np.mean(peaks[2] - peaks[1]) / fs
        rs_delay = np.mean(peaks[3] - peaks[2]) / fs
        st_delay = np.mean(peaks[4] - peaks[3]) / fs
    else:
        pq_delay, qr_delay, rs_delay, st_delay = 0, 0, 0, 0

    # Combine features (order is part of the model contract)
    return [
        rMSSD, PRR50, PRR20, SDSD, SDRR,
        low_freq, mid_freq, high_freq,
        sd1, sd2, min_hr, max_hr, avg_hr,
        avg_peak_amplitude, pq_delay, qr_delay, rs_delay, st_delay
    ]

# Segmenting the signal into 5-second fragments
def segment_signal_by_time(signal, sampling_rate, segment_duration=5):
    """Split a signal into fixed-length windows, zero-padding the final partial one.

    Unlike a drop-the-remainder split, the trailing ``len(signal) % segment_length``
    samples are kept so the whole recording gets classified. They are
    right-padded with zeros (not by repeating the last sample) up to a full
    window.

    Args:
        signal: 1-D array of samples.
        sampling_rate: Samples per second.
        segment_duration: Window length in seconds; may be fractional.

    Returns:
        2-D array ``(num_windows, segment_length)``; ``(0, segment_length)``
        for an empty signal.

    Raises:
        ValueError: if the window would be shorter than one sample.
    """
    signal = np.asarray(signal)
    # Slice bounds must be ints; round() absorbs float error in the product.
    segment_length = int(round(segment_duration * sampling_rate))
    if segment_length <= 0:
        raise ValueError("segment_duration * sampling_rate must be at least one sample")

    num_segments = len(signal) // segment_length
    segments = [
        signal[i * segment_length:(i + 1) * segment_length]
        for i in range(num_segments)
    ]

    remaining_samples = len(signal) % segment_length
    if remaining_samples > 0:
        # Pad to the window length derived from the arguments. The previous
        # hard-coded 2000 produced ragged arrays for any other window length.
        last_segment = np.pad(signal[-remaining_samples:],
                              (0, segment_length - remaining_samples),
                              mode='constant', constant_values=0)
        segments.append(last_segment)

    if not segments:
        return np.empty((0, segment_length), dtype=signal.dtype)
    return np.array(segments)


# Real-time signal processing
def process_real_time_signal(signal, model, scaler, sampling_rate=400, segment_duration=5):
    """Classify an ECG signal window by window and plot the predictions.

    The signal goes through the same filter chain as ``preprocess_data`` in the
    training cell, in the same order: wavelet denoise, 0.5 Hz high-pass, then
    50 Hz low-pass. It is then cut into ``segment_duration``-second windows
    (the last one zero-padded). Each window is featurised with
    ``extract_features``, the features are scaled with ``scaler``, and
    ``model`` classifies all windows in one batch. The filtered signal is
    plotted with each window's predicted class.

    Args:
        signal: 1-D ECG samples.
        model: Trained two-input Keras model ``[window, features] -> class probabilities``.
        scaler: Scaler fitted on the training features.
        sampling_rate: Sampling rate in Hz; should match training.
        segment_duration: Window length in seconds; should match training.

    Returns:
        list of ``(window_index, predicted_class)`` tuples. Returns ``None`` if
        processing failed; the error is printed rather than raised.
    """
    try:
        # Match the training order exactly. Wavelet soft-thresholding is
        # non-linear, so the previous LP -> HP -> wavelet order fed the model
        # differently-filtered input than it was trained on.
        ecg_signal = wavelet_filter(signal, wavelet='db6', level=5)
        ecg_signal = butter_highpass_filter(ecg_signal, cutoff=0.5, fs=sampling_rate, order=4)
        ecg_signal = butter_lowpass_filter(ecg_signal, cutoff=50, fs=sampling_rate, order=4)

        segment_length = int(round(segment_duration * sampling_rate))
        segments = segment_signal_by_time(ecg_signal, sampling_rate, segment_duration)

        class_mapping = {0: "N", 1: "A", 2: "V"}
        # Window length the network was built for; other lengths cannot be fed to it.
        expected_length = model.inputs[0].shape[1]

        kept_indices, kept_segments, kept_features = [], [], []
        for i, segment in enumerate(segments):
            if expected_length is not None and len(segment) != expected_length:
                continue
            kept_indices.append(i)
            kept_segments.append(segment)
            kept_features.append(extract_features(segment, fs=sampling_rate))

        # One batched predict call instead of one call per window.
        predictions = []
        if kept_indices:
            segment_batch = np.asarray(kept_segments)[..., np.newaxis]  # (n, length, 1)
            feature_batch = scaler.transform(np.asarray(kept_features))
            probabilities = model.predict([segment_batch, feature_batch], verbose=0)
            predictions = list(zip(kept_indices, probabilities.argmax(axis=1).tolist()))

        # Display ECG signal plot with labeled segments
        plt.figure(figsize=(30, 10))
        plt.plot(np.arange(len(ecg_signal)) / sampling_rate, ecg_signal, label='ECG Signal', color='red')

        label_height = 0.7 * np.max(ecg_signal)
        for i, predicted_class in predictions:
            start_idx = i * segment_length
            end_idx = (i + 1) * segment_length
            # Dashed lines at the window boundaries
            plt.axvline(x=start_idx / sampling_rate, color='k', linestyle='--', linewidth=1)
            plt.axvline(x=end_idx / sampling_rate, color='k', linestyle='--', linewidth=1)
            # Class label centred above the window
            plt.text((start_idx + end_idx) / (2 * sampling_rate), label_height, class_mapping[predicted_class],
                     ha='center', va='bottom', fontsize=12, color='black')

        plt.title("Real-Time ECG Signal with Predicted Arrhythmias", fontsize=23)
        plt.xlabel('Time (s)', fontsize=20)
        plt.ylabel('Amplitude', fontsize=20)
        plt.grid(True)
        plt.show()

        return predictions

    except Exception as e:
        print(f"Error in signal processing: {e}")


# Loading model and scaler
model_path = "saved_model/arrhythmia_model.keras"
model = load_model(model_path, custom_objects={"Attention": Attention})

scaler_path = "saved_model/scaler.pkl"
scaler = joblib.load(scaler_path)

# Load multiple .mat files
file_paths = [...]  # List of file paths

# Process the combined signal in real time
process_real_time_signal(np.array(ecg_signal_combined), model, scaler)
