# Deep FECG Research: All-in-One Experiment Notebook for Google Colab

This notebook is optimized for Python 3.12+ and modern libraries in a Google Colab environment. It contains all the code for data preprocessing, feature extraction, and model training using a self-contained `gcForest` class.

## 1. Setup Environment

This cell installs all necessary libraries. Run it first.

In [None]:
!pip install -q wfdb librosa pywavelets ssqueezepy imbalanced-learn shap matplotlib

: 

## 2. Mount Google Drive & Define Paths

This section mounts your Google Drive to make your dataset accessible. You will need to authorize Colab to access your Drive.

**IMPORTANT:** Before running the second cell below, you **must** update the `PROJECT_PATH` variable so that it points to the location of your project folder on Google Drive.

In [None]:
# Mount Google Drive at /content/drive; the path constants defined in the
# next cell are built underneath this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# TODO: Update this path to your project directory on Google Drive
PROJECT_PATH = '/content/drive/MyDrive/deep_fecg_research'

# --- You should not need to edit below this line ---
# Both derived locations live underneath PROJECT_PATH.
DATA_PATH = os.path.join(PROJECT_PATH, 'data/mit-bih-arrhythmia-database-1.0.0')
OUTPUT_PATH = os.path.join(PROJECT_PATH, 'colab_outputs')

# Plots are written here; creating it up front is harmless if it already exists.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Echo the resolved locations so a wrong PROJECT_PATH is spotted early.
print("\n".join([
    f"Project path set to: {PROJECT_PATH}",
    f"Data path set to: {DATA_PATH}",
    f"Output path set to: {OUTPUT_PATH}",
]))

## 3. All-in-One Experiment Code

The following cells contain all the necessary code for the experiment pipeline.

In [None]:
import itertools
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

class gcForest(object):
    """Classifier made of optional multi-grained scanning plus a cascade of random forests.

    Each stage trains pairs of scikit-learn ``RandomForestClassifier`` models:
    one with ``max_features='sqrt'`` ("prf") and one with ``max_features=1``
    ("crf"). Fitted forests are stored as instance attributes named
    ``_mgsprf_<window>``, ``_mgscrf_<window>``, ``_casprf<layer>_<i>`` and
    ``_cascrf<layer>_<i>``.

    The class exposes ``get_params``/``set_params`` so that utilities that
    clone an estimator from its constructor arguments (for example
    ``sklearn.model_selection.GridSearchCV``) can work with it.
    NOTE(review): recent scikit-learn releases may additionally expect
    ``__sklearn_tags__`` (supplied by ``sklearn.base.BaseEstimator``) --
    confirm against the installed version.

    Parameters
    ----------
    shape_1X : int or sequence of two ints, optional
        Shape of one sample. An int ``n`` is treated as ``[1, n]`` (a
        sequence); a first entry greater than 1 selects 2-D window slicing.
        Required when ``use_mg_scanning`` is True.
    n_mgsRFtree : int
        Trees per forest during multi-grained scanning.
    window : int or list of int, optional
        Window size(s) for scanning. Defaults to the full sequence length.
    stride : int
        Step between consecutive windows.
    cascade_test_size : float or int
        ``test_size`` passed to ``train_test_split`` when growing the cascade.
    n_cascadeRF : int
        Number of prf/crf pairs per cascade layer.
    n_cascadeRFtree : int
        Trees per forest in the cascade.
    cascade_layer : int or float
        Layer budget; layers keep being added while ``n_layer <= cascade_layer``
        and the held-out accuracy improves by more than ``tolerance``.
    min_samples_mgs, min_samples_cascade : int or float
        ``min_samples_split`` for scanning / cascade forests.
    tolerance : float
        Minimum accuracy gain required to add another layer.
    n_jobs : int
        Forwarded to every ``RandomForestClassifier``.
    use_mg_scanning : bool
        When False, ``fit``/``predict_proba`` feed ``X`` directly to the cascade.
    """

    # Constructor arguments in signature order; drives get_params/set_params.
    _PARAM_NAMES = (
        'shape_1X', 'n_mgsRFtree', 'window', 'stride', 'cascade_test_size',
        'n_cascadeRF', 'n_cascadeRFtree', 'cascade_layer', 'min_samples_mgs',
        'min_samples_cascade', 'tolerance', 'n_jobs', 'use_mg_scanning',
    )
    # Parameters that are always stored as plain ints.
    _INT_PARAMS = ('n_cascadeRF', 'n_mgsRFtree', 'n_cascadeRFtree')

    def __init__(self, shape_1X=None, n_mgsRFtree=30, window=None, stride=1,
                 cascade_test_size=0.2, n_cascadeRF=2, n_cascadeRFtree=101, cascade_layer=np.inf,
                 min_samples_mgs=0.1, min_samples_cascade=0.05, tolerance=0.0, n_jobs=1, use_mg_scanning=True):
        self.shape_1X = shape_1X
        self.n_layer = 0
        self._n_samples = 0
        self.n_cascadeRF = self._coerce_param('n_cascadeRF', n_cascadeRF)
        self.window = self._coerce_param('window', window)
        self.stride = stride
        self.cascade_test_size = cascade_test_size
        self.n_mgsRFtree = self._coerce_param('n_mgsRFtree', n_mgsRFtree)
        self.n_cascadeRFtree = self._coerce_param('n_cascadeRFtree', n_cascadeRFtree)
        self.cascade_layer = cascade_layer
        self.min_samples_mgs = min_samples_mgs
        self.min_samples_cascade = min_samples_cascade
        self.tolerance = tolerance
        self.n_jobs = n_jobs
        self.use_mg_scanning = use_mg_scanning

    @classmethod
    def _coerce_param(cls, name, value):
        """Normalise a hyper-parameter the same way for ``__init__`` and ``set_params``."""
        if name == 'window' and isinstance(value, (int, np.integer)):
            return [value]
        if name in cls._INT_PARAMS:
            return int(value)
        return value

    def get_params(self, deep=True):
        """Return the constructor arguments as a ``{name: value}`` dict.

        ``deep`` is accepted for API compatibility; there are no nested
        estimators among the parameters, so it has no effect.
        """
        return {name: getattr(self, name) for name in self._PARAM_NAMES}

    def set_params(self, **params):
        """Update hyper-parameters in place and return ``self``.

        Raises
        ------
        ValueError
            If a keyword is not one of the constructor arguments.
        """
        for name, value in params.items():
            if name not in self._PARAM_NAMES:
                raise ValueError(f'Invalid parameter {name!r} for gcForest.')
            setattr(self, name, self._coerce_param(name, value))
        return self

    def fit(self, X, y):
        """Train the scanning forests (if enabled) and grow the cascade.

        Returns ``self`` so calls can be chained.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` hold a different number of samples.
        """
        if X.shape[0] != len(y):
            raise ValueError('Sizes of y and X do not match.')
        if self.use_mg_scanning:
            X = self.mg_scanning(X, y)
        self.cascade_forest(X, y)
        return self

    def predict_proba(self, X):
        """Return class probabilities averaged over the last cascade layer's forests."""
        if self.use_mg_scanning:
            X = self.mg_scanning(X)
        cascade_all_pred_prob = self.cascade_forest(X)
        return np.mean(cascade_all_pred_prob, axis=0)

    def predict(self, X):
        """Return the most probable class label for every row of ``X``."""
        pred_proba = self.predict_proba(X=X)
        return self._proba_to_labels(pred_proba)

    def _proba_to_labels(self, proba):
        """Map arg-max column indices to the class labels seen during training.

        Probability columns follow the forests' ``classes_`` order, so the raw
        column index equals the label only when labels are exactly 0..k-1.
        """
        best_column = np.argmax(proba, axis=1)
        classes = getattr(self, 'classes_', None)
        if classes is None:
            return best_column
        return np.asarray(classes)[best_column]

    def _new_forest(self, n_trees, max_features, min_samples_split):
        """Create an unfitted forest with out-of-bag scoring enabled."""
        return RandomForestClassifier(n_estimators=n_trees, max_features=max_features,
                                      min_samples_split=min_samples_split, oob_score=True,
                                      n_jobs=self.n_jobs)

    def mg_scanning(self, X, y=None):
        """Run multi-grained scanning and return the concatenated window predictions.

        With ``y`` the scanning forests are trained; without it the stored
        forests are reused.

        Raises
        ------
        ValueError
            If ``shape_1X`` was never set.
        """
        self._n_samples = X.shape[0]
        shape_1X = self.shape_1X
        if shape_1X is None:
            raise ValueError('shape_1X must be set when use_mg_scanning is True.')
        if isinstance(shape_1X, (int, np.integer)):
            shape_1X = [1, shape_1X]
        if not self.window:
            # No window requested: scan with a single full-width window.
            self.window = [shape_1X[1]]
        mgs_pred_prob = []
        for wdw_size in self.window:
            wdw_pred_prob = self._window_slicing_pred_prob(X, wdw_size, shape_1X, y=y)
            mgs_pred_prob.append(wdw_pred_prob)
        return np.concatenate(mgs_pred_prob, axis=1)

    def _window_slicing_pred_prob(self, X, window, shape_1X, y=None):
        """Slice ``X`` with one window size and return per-sample forest probabilities."""
        if shape_1X[0] > 1:
            sliced_X, sliced_y = self._window_slicing_img(X, window, shape_1X, y=y, stride=self.stride)
        else:
            sliced_X, sliced_y = self._window_slicing_sequence(X, window, shape_1X, y=y, stride=self.stride)
        if y is not None:
            prf = self._new_forest(self.n_mgsRFtree, 'sqrt', self.min_samples_mgs)
            crf = self._new_forest(self.n_mgsRFtree, 1, self.min_samples_mgs)
            prf.fit(sliced_X, sliced_y)
            crf.fit(sliced_X, sliced_y)
            setattr(self, f'_mgsprf_{window}', prf)
            setattr(self, f'_mgscrf_{window}', crf)
            # Out-of-bag estimates are used for the training rows.
            pred_prob_prf = prf.oob_decision_function_
            pred_prob_crf = crf.oob_decision_function_
        else:
            prf = getattr(self, f'_mgsprf_{window}')
            crf = getattr(self, f'_mgscrf_{window}')
            pred_prob_prf = prf.predict_proba(sliced_X)
            pred_prob_crf = crf.predict_proba(sliced_X)
        pred_prob = np.c_[pred_prob_prf, pred_prob_crf]
        # Fold all windows of one sample back into a single feature row.
        return pred_prob.reshape([self._n_samples, -1])

    def _window_slicing_img(self, X, window, shape_1X, y=None, stride=1):
        """Cut square ``window x window`` patches out of flattened 2-D samples.

        ``X`` rows are treated as row-major images of shape ``shape_1X``.
        Returns ``(patches, targets)``; ``targets`` repeats each label once per
        patch, or is ``None`` when ``y`` is ``None``.

        Raises
        ------
        ValueError
            If ``window`` exceeds either image dimension.
        """
        if any(dim < window for dim in shape_1X):
            raise ValueError('window must be smaller than both dimensions for an image')
        n_rows, n_cols = shape_1X[0], shape_1X[1]
        len_iter_x = (n_cols - window) // stride + 1
        len_iter_y = (n_rows - window) // stride + 1
        iterx_array = np.arange(0, stride * len_iter_x, stride)
        itery_array = np.arange(0, stride * len_iter_y, stride)
        # Flat indices of the top-left patch; other patches are offsets of it.
        ref_row = np.arange(0, window)
        ref_ind = np.ravel([ref_row + n_cols * i for i in range(window)])
        inds_to_take = [ref_ind + ix + n_cols * iy
                        for ix, iy in itertools.product(iterx_array, itery_array)]
        sliced_X = np.take(X, inds_to_take, axis=1).reshape(-1, window ** 2)
        if y is not None:
            sliced_y = np.repeat(y, len_iter_x * len_iter_y)
        else:
            sliced_y = None
        return sliced_X, sliced_y

    def _window_slicing_sequence(self, X, window, shape_1X, y=None, stride=1):
        """Cut length-``window`` slices out of 1-D samples.

        Returns ``(slices, targets)``; ``targets`` repeats each label once per
        slice, or is ``None`` when ``y`` is ``None``.

        Raises
        ------
        ValueError
            If ``window`` is larger than the sequence length.
        """
        if shape_1X[1] < window:
            raise ValueError('window must be smaller than the sequence dimension')
        len_iter = (shape_1X[1] - window) // stride + 1
        iter_array = np.arange(0, stride * len_iter, stride)
        inds_to_take = [np.arange(i, i + window) for i in iter_array]
        sliced_X = np.take(X, inds_to_take, axis=1).reshape(-1, window)
        if y is not None:
            sliced_y = np.repeat(y, len_iter)
        else:
            sliced_y = None
        return sliced_X, sliced_y

    def cascade_forest(self, X, y=None):
        """Grow the cascade (``y`` given) or push ``X`` through the stored layers.

        Training holds out ``cascade_test_size`` of the data, adds layers while
        the held-out accuracy improves by more than ``tolerance``, and drops the
        last layer if it made accuracy worse. Returns the list of per-forest
        probability arrays produced by the final layer that was evaluated.
        """
        if y is not None:
            self.n_layer = 0
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.cascade_test_size)
            self.n_layer += 1
            prf_crf_pred_ref = self._cascade_layer(X_train, y_train)
            accuracy_ref = self._cascade_evaluation(X_test, y_test)
            feat_arr = self._create_feat_arr(X_train, prf_crf_pred_ref)
            self.n_layer += 1
            prf_crf_pred_layer = self._cascade_layer(feat_arr, y_train)
            accuracy_layer = self._cascade_evaluation(X_test, y_test)
            while accuracy_layer > (accuracy_ref + self.tolerance) and self.n_layer <= self.cascade_layer:
                accuracy_ref = accuracy_layer
                prf_crf_pred_ref = prf_crf_pred_layer
                feat_arr = self._create_feat_arr(X_train, prf_crf_pred_ref)
                self.n_layer += 1
                prf_crf_pred_layer = self._cascade_layer(feat_arr, y_train)
                accuracy_layer = self._cascade_evaluation(X_test, y_test)
            if accuracy_layer < accuracy_ref:
                # The newest layer hurt accuracy: forget its forests.
                for irf in range(self.n_cascadeRF):
                    delattr(self, f'_casprf{self.n_layer}_{irf}')
                    delattr(self, f'_cascrf{self.n_layer}_{irf}')
                self.n_layer -= 1
        else:
            at_layer = 1
            prf_crf_pred_ref = self._cascade_layer(X, layer=at_layer)
            while at_layer < self.n_layer:
                at_layer += 1
                feat_arr = self._create_feat_arr(X, prf_crf_pred_ref)
                prf_crf_pred_ref = self._cascade_layer(feat_arr, layer=at_layer)
        return prf_crf_pred_ref

    def _cascade_layer(self, X, y=None, layer=0):
        """Train layer ``self.n_layer`` (``y`` given) or apply stored layer ``layer``.

        Returns a list with two probability arrays (prf, crf) per forest pair.
        """
        prf_crf_pred = []
        if y is not None:
            for irf in range(self.n_cascadeRF):
                # A fresh pair per slot: re-fitting one shared object would make
                # every stored attribute point at the same (last) forest.
                prf = self._new_forest(self.n_cascadeRFtree, 'sqrt', self.min_samples_cascade)
                crf = self._new_forest(self.n_cascadeRFtree, 1, self.min_samples_cascade)
                prf.fit(X, y)
                crf.fit(X, y)
                setattr(self, f'_casprf{self.n_layer}_{irf}', prf)
                setattr(self, f'_cascrf{self.n_layer}_{irf}', crf)
                # Column order of every cascade probability array.
                self.classes_ = prf.classes_
                prf_crf_pred.append(prf.oob_decision_function_)
                prf_crf_pred.append(crf.oob_decision_function_)
        else:
            for irf in range(self.n_cascadeRF):
                prf = getattr(self, f'_casprf{layer}_{irf}')
                crf = getattr(self, f'_cascrf{layer}_{irf}')
                prf_crf_pred.append(prf.predict_proba(X))
                prf_crf_pred.append(crf.predict_proba(X))
        return prf_crf_pred

    def _cascade_evaluation(self, X_test, y_test):
        """Return the accuracy of the cascade built so far on held-out data."""
        casc_pred_prob = np.mean(self.cascade_forest(X_test), axis=0)
        casc_pred = self._proba_to_labels(casc_pred_prob)
        return accuracy_score(y_true=y_test, y_pred=casc_pred)

    def _create_feat_arr(self, X, prf_crf_pred):
        """Prepend the forests' class probabilities to ``X`` as extra feature columns."""
        swap_pred = np.swapaxes(np.asarray(prf_crf_pred), 0, 1)
        add_feat = swap_pred.reshape([X.shape[0], -1])
        return np.concatenate([add_feat, X], axis=1)


In [None]:
import wfdb
from scipy.signal import butter, filtfilt
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

# Beat-annotation symbols grouped by the integer class they map to.
_AAMI_GROUPS = (
    ('NLRej', 0),
    ('AaJS', 1),
    ('VE', 2),
    ('F', 3),
    ('/fQ', 4),
)

# Flat symbol -> class-index lookup built from the groups above.
AAMI_CLASSES = {
    symbol: class_index
    for symbols, class_index in _AAMI_GROUPS
    for symbol in symbols
}


def get_aami_class(symbol):
    """Return the class index for an annotation symbol, or None if it is not mapped."""
    try:
        return AAMI_CLASSES[symbol]
    except KeyError:
        return None

def apply_bandpass_filter(signal, fs=360):
    """Band-pass filter ``signal`` between 0.5 Hz and 45 Hz.

    Designs a ``butter(2, ..., btype='band')`` filter with the cut-offs
    normalised by the Nyquist frequency (``fs / 2``) and applies it with
    ``filtfilt``.

    Args:
        signal: Samples to filter.
        fs: Sampling frequency in Hz.

    Returns:
        The filtered signal as returned by ``scipy.signal.filtfilt``.
    """
    nyquist = 0.5 * fs
    band_edges_hz = (0.5, 45.0)
    normalized_band = [edge / nyquist for edge in band_edges_hz]
    numerator, denominator = butter(2, normalized_band, btype='band')
    return filtfilt(numerator, denominator, signal)

def segment_heartbeats(signal, annotations, fs=360, window_size=360):
    """Cut fixed-length windows around every annotated beat with a known class.

    Args:
        signal: 1-D sequence of samples.
        annotations: Object exposing parallel ``symbol`` and ``sample``
            sequences (beat symbol and its sample index).
        fs: Sampling frequency; accepted for interface compatibility but not
            used by the segmentation itself.
        window_size: Number of samples per extracted window. The window spans
            ``window_size // 2`` samples before the annotated sample and the
            remainder from it onwards.

    Returns:
        ``(heartbeats, labels)``: an array of shape ``(n_beats, window_size)``
        and the matching class indices. When no beat qualifies the arrays are
        empty but keep those shapes, so results from several records can still
        be concatenated.
    """
    window_before = window_size // 2
    window_after = window_size - window_before
    n_samples = len(signal)
    heartbeats, labels = [], []
    for symbol, peak_sample in zip(annotations.symbol, annotations.sample):
        aami_class = get_aami_class(symbol)
        if aami_class is None:
            continue  # symbol is not one of the mapped beat classes
        start, end = peak_sample - window_before, peak_sample + window_after
        # ``end`` is an exclusive slice bound, so a window may finish exactly
        # at the last sample (end == n_samples).
        if start >= 0 and end <= n_samples:
            heartbeats.append(signal[start:end])
            labels.append(aami_class)
    if not heartbeats:
        return np.empty((0, window_size)), np.empty((0,), dtype=int)
    return np.array(heartbeats), np.array(labels)

def preprocess_data(data_path, window_size=360, max_records=None):
    """Load records, extract labelled heartbeats, split 80/20 and balance the training set.

    Every ``*.hea`` file in ``data_path`` names one record. For each record the
    ``MLII`` channel (or channel 0 if absent) is band-pass filtered and
    segmented. Records that fail to load are reported and skipped.

    Args:
        data_path: Directory holding the record files.
        window_size: Samples per heartbeat window.
        max_records: If truthy, only the first ``max_records`` records (sorted
            by name) are processed.

    Returns:
        ``(X_train_resampled, X_test, y_train_resampled, y_test)``. The
        training part is oversampled with SMOTE; the test part is untouched.

    Raises:
        ValueError: If no heartbeat could be extracted from any record.
    """
    print(f"Starting data preprocessing...")
    record_names = sorted(os.path.splitext(f)[0] for f in os.listdir(data_path) if f.endswith('.hea'))
    if max_records:
        record_names = record_names[:max_records]
    all_heartbeats, all_labels = [], []
    for record_name in record_names:
        try:
            record_path = os.path.join(data_path, record_name)
            record = wfdb.rdrecord(record_path)
            annotations = wfdb.rdann(record_path, 'atr')
            channel = record.sig_name.index('MLII') if 'MLII' in record.sig_name else 0
            filtered_signal = apply_bandpass_filter(record.p_signal[:, channel], fs=record.fs)
            heartbeats, labels = segment_heartbeats(filtered_signal, annotations, fs=record.fs, window_size=window_size)
        except Exception as e:
            # Best effort: one unreadable record must not abort the whole run.
            print(f"Could not process record {record_name}: {e}")
            continue
        if len(labels) == 0:
            # An empty result cannot be concatenated with 2-D heartbeat arrays.
            print(f"No usable heartbeats in record {record_name}; skipping.")
            continue
        all_heartbeats.append(heartbeats)
        all_labels.append(labels)
    if not all_heartbeats:
        raise ValueError("No heartbeats processed. Check data path and file integrity.")
    X, y = np.concatenate(all_heartbeats), np.concatenate(all_labels)

    # A stratified split needs at least two members per class.
    classes, counts = np.unique(y, return_counts=True)
    singleton_classes = classes[counts < 2]
    if singleton_classes.size:
        print(f"Dropping classes with fewer than 2 samples: {singleton_classes.tolist()}")
        keep = ~np.isin(y, singleton_classes)
        X, y = X[keep], y[keep]

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

    print("Applying SMOTE to balance the training data...")
    smallest_class = int(np.unique(y_train, return_counts=True)[1].min())
    if smallest_class < 2:
        # SMOTE interpolates between neighbours of the same class; a single
        # sample has none, so the training set is left as it is.
        print("Skipping SMOTE: a training class has fewer than 2 samples.")
        X_train_resampled, y_train_resampled = X_train, y_train
    else:
        # SMOTE's default of 5 neighbours fails when a class has 5 or fewer
        # samples, which happens when only a few records are loaded.
        smote = SMOTE(random_state=42, k_neighbors=min(5, smallest_class - 1))
        X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
    print(f"Original training samples: {len(y_train)}, Resampled training samples: {len(y_train_resampled)}")
    print("Data preprocessing complete.")
    return X_train_resampled, X_test, y_train_resampled, y_test


In [None]:
import librosa
import pywt

def extract_features(train_data, test_data, method='MFCC'):
    """Turn raw heartbeat windows into feature matrices for both splits.

    Args:
        train_data: Training heartbeats.
        test_data: Test heartbeats.
        method: ``'MFCC'`` or ``'DWT'``.

    Returns:
        ``(train_features, test_features)`` produced by the chosen extractor.

    Raises:
        ValueError: If ``method`` is neither ``'MFCC'`` nor ``'DWT'``.
    """
    print(f"Extracting features using {method} method...")
    if method not in ('MFCC', 'DWT'):
        raise ValueError(f"Unknown feature extraction method: {method}")
    # The training split is processed first, then the test split.
    extractor = _extract_mfcc if method == 'MFCC' else _extract_dwt
    train_features = extractor(train_data)
    test_features = extractor(test_data)
    print("Feature extraction complete.")
    return train_features, test_features

def _extract_mfcc(data, sr=360, n_mfcc=13):
    """Return one averaged MFCC vector per heartbeat.

    Each heartbeat is cast to float and passed to ``librosa.feature.mfcc``
    with ``n_fft=2048``; the result is transposed and averaged over axis 0,
    leaving ``n_mfcc`` values per heartbeat.
    NOTE(review): ``n_fft=2048`` is larger than the default 360-sample
    heartbeat window used elsewhere in this file -- confirm this is intended.
    """
    averaged = []
    for heartbeat in data:
        coefficients = librosa.feature.mfcc(y=heartbeat.astype(float), sr=sr, n_mfcc=n_mfcc, n_fft=2048)
        averaged.append(np.mean(coefficients.T, axis=0))
    return np.array(averaged)

def _extract_dwt(data, wavelet='db4', level=4):
    """Return the flattened wavelet-decomposition coefficients of every heartbeat.

    Each heartbeat is decomposed with ``pywt.wavedec``; the coefficient arrays
    are flattened and joined into one vector. Vectors shorter than the longest
    one are right-padded (``np.pad`` default mode) so the result is rectangular.
    """
    flat_features = []
    for heartbeat in data:
        bands = pywt.wavedec(heartbeat, wavelet, level=level)
        flat_features.append(np.concatenate([band.flatten() for band in bands]))
    max_len = max(map(len, flat_features))
    padded_rows = []
    for vector in flat_features:
        missing = max_len - len(vector)
        padded_rows.append(np.pad(vector, (0, missing)))
    return np.array(padded_rows)


In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import GridSearchCV, StratifiedKFold

def train_and_evaluate(train_features, train_labels, test_features, test_labels, model_type='gcForest'):
    """Grid-search a gcForest variant on the training data and report test metrics.

    Args:
        train_features: Training feature matrix.
        train_labels: Training labels.
        test_features: Test feature matrix.
        test_labels: Test labels.
        model_type: ``'CascadeForest'`` (cascade only, no multi-grained
            scanning) or ``'gcForest'`` (scanning plus cascade).

    Returns:
        The best estimator found by the grid search, refitted on the full
        training data.

    Raises:
        ValueError: If ``model_type`` is not one of the two supported names.
    """
    print(f"--- Training and evaluating {model_type} model ---")
    if model_type == 'CascadeForest':
        param_grid = {
            'n_cascadeRFtree': [101, 151], 'n_cascadeRF': [2],
            'min_samples_cascade': [0.05, 0.1], 'cascade_layer': [15, 25], 'tolerance': [0.005]
        }
        model_base = gcForest(use_mg_scanning=False, n_jobs=-1)
    elif model_type == 'gcForest':
        feature_dim = train_features.shape[1]
        # Windows of 20% and 30% of the feature width; never narrower than one
        # feature, and listed once if both fractions round to the same size.
        window_sizes = sorted({max(1, int(feature_dim * 0.2)), max(1, int(feature_dim * 0.3))})
        param_grid = {
            'window': [[size] for size in window_sizes], 'n_mgsRFtree': [30],
            'n_cascadeRFtree': [101], 'n_cascadeRF': [2], 'cascade_layer': [15], 'tolerance': [0.005]
        }
        model_base = gcForest(shape_1X=feature_dim, n_jobs=-1)
    else:
        raise ValueError(f"Invalid model type: {model_type}")
    cv = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
    grid_search = GridSearchCV(estimator=model_base, param_grid=param_grid, cv=cv, n_jobs=-1, verbose=1, scoring='accuracy')
    grid_search.fit(train_features, train_labels)
    print(f"Best parameters for {model_type}: {grid_search.best_params_}")
    model = grid_search.best_estimator_
    print("Evaluating the best model on the test set...")
    predictions = model.predict(test_features)
    probas = model.predict_proba(test_features)
    print(f"Accuracy: {accuracy_score(test_labels, predictions):.4f}")
    print(f"F1-score: {f1_score(test_labels, predictions, average='weighted'):.4f}")
    print(f"Precision: {precision_score(test_labels, predictions, average='weighted'):.4f}")
    print(f"Recall: {recall_score(test_labels, predictions, average='weighted'):.4f}")
    try:
        roc_auc = roc_auc_score(test_labels, probas, multi_class='ovr', average='weighted')
        print(f"ROC AUC Score: {roc_auc:.4f}")
    except ValueError as e:
        # ROC AUC is undefined for some label/probability combinations; the
        # remaining metrics are still reported.
        print(f"Could not compute ROC AUC Score: {e}")
    # The newline must be an escape sequence: a literal line break inside the
    # quotes is a syntax error.
    print("\nConfusion Matrix:")
    print(confusion_matrix(test_labels, predictions))
    return model


In [None]:
import shap
import matplotlib.pyplot as plt

def explain_model(model, test_features, feature_names, output_path):
    """Compute SHAP values for ``model`` and save a bar summary plot.

    A ``shap.KernelExplainer`` is built on ``model.predict_proba`` with a
    100-row background sample of ``test_features``. The plot is written to
    ``<output_path>/shap_summary_plot.png``. Any failure is reported on stdout
    instead of being raised, so an explanation error never aborts an experiment.

    Args:
        model: Fitted estimator exposing ``predict_proba``.
        test_features: Feature matrix to explain.
        feature_names: One display name per feature column.
        output_path: Directory that receives the PNG file.
    """
    print("Calculating SHAP values...")
    fig = None
    try:
        background_data = shap.sample(test_features, 100)
        explainer = shap.KernelExplainer(model.predict_proba, background_data)
        shap_values = explainer.shap_values(test_features)
        if isinstance(shap_values, list):
            # One array per output: plot the first output, as before.
            values_to_plot = shap_values[0]
        elif getattr(shap_values, 'ndim', 2) == 3:
            # NOTE(review): newer SHAP releases appear to return a single
            # array with the outputs on the last axis (samples, features,
            # outputs) -- confirm against the installed version. The first
            # output is selected to match the list branch above.
            values_to_plot = shap_values[:, :, 0]
        else:
            values_to_plot = shap_values
        print("Generating SHAP summary plot...")
        fig = plt.figure()
        shap.summary_plot(values_to_plot, test_features, feature_names=feature_names, plot_type="bar", show=False)
        plt.title("SHAP Feature Importance")
        plot_file = os.path.join(output_path, 'shap_summary_plot.png')
        plt.savefig(plot_file)
        print(f"SHAP summary plot saved to {plot_file}")
    except Exception as e:
        print(f"Could not generate SHAP plot: {e}")
    finally:
        # Close the figure on failure too, so repeated runs do not pile up
        # open figures.
        if fig is not None:
            plt.close(fig)


In [None]:
import argparse

def run_experiment(args):
    """Run one experiment end to end: preprocess, extract features, train, optionally explain.

    Args:
        args: Namespace-like object with the attributes ``data_path``,
            ``output_path``, ``feature_extractor``, ``model``, ``explain`` and
            ``max_records``.
    """
    print(f"====================--- Starting Experiment: Model={args.model}, Features={args.feature_extractor} ---")

    X_train, X_test, y_train, y_test = preprocess_data(args.data_path, max_records=args.max_records)
    train_features, test_features = extract_features(X_train, X_test, method=args.feature_extractor)
    model = train_and_evaluate(train_features, y_train, test_features, y_test, model_type=args.model)

    if args.explain:
        # Feature columns are named after the extractor; anything other than
        # MFCC is labelled as DWT.
        prefix = 'MFCC' if args.feature_extractor == 'MFCC' else 'DWT'
        n_features = train_features.shape[1]
        feature_names = [f'{prefix}_{i}' for i in range(n_features)]
        explain_model(model, test_features, feature_names, args.output_path)

    print(f"--- Experiment Finished: Model={args.model} ---")


## 4. Run Experiments

Now we can run the experiments for both model types. We use a small number of records (`max_records=4`) for a quick test run.

In [None]:
import argparse

# Quick run: cascade-only model on MFCC features, limited to 4 records.
args_cascade = argparse.Namespace(**{
    'data_path': DATA_PATH,
    'output_path': OUTPUT_PATH,
    'feature_extractor': 'MFCC',
    'model': 'CascadeForest',
    'explain': True,
    'max_records': 4,
})
run_experiment(args_cascade)

In [None]:
import argparse

# Quick run: full gcForest (scanning + cascade) on DWT features, limited to 4 records.
args_gc = argparse.Namespace(**{
    'data_path': DATA_PATH,
    'output_path': OUTPUT_PATH,
    'feature_extractor': 'DWT',
    'model': 'gcForest',
    'explain': True,
    'max_records': 4,
})
run_experiment(args_gc)

## Conclusion

If the cells above executed without errors, your environment is correctly set up and the self-contained experiment notebook is working. You can now adjust the parameters (e.g., `max_records`, `feature_extractor`, and the `param_grid` in the `train_and_evaluate` function) to run your full research experiments.