# Deep FECG Research: All-in-One Setup and Test Notebook

This notebook contains all the commands for setting up your environment and running a test of the `deep-fecg-research` project. Please read the instructions carefully, especially regarding environment activation.

**IMPORTANT:** While all commands are listed here, the `pyenv activate` command *must* be run in your terminal *before* you launch Jupyter Notebook or JupyterLab. Running it within a notebook cell will not correctly activate the environment for the Jupyter kernel.

## 1. Activate Python Environment (Run this in your Terminal FIRST!)

**Do NOT run this cell in Jupyter.** Copy and paste this command into your terminal and execute it there. Then, from the *same terminal*, launch Jupyter Notebook or JupyterLab.

```bash
pyenv activate deepforest
```

Once you have activated the environment and launched Jupyter, you can proceed with the cells below.

In [None]:
# This cell is for demonstration purposes only. 
# It will NOT activate the environment for subsequent cells in Jupyter.
# You MUST run 'pyenv activate deepforest' in your terminal before starting Jupyter.
# !pyenv activate deepforest # Uncomment and run in terminal, not here.

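Once Jupyter is running, one way to confirm that the kernel picked up the activated environment is to inspect the interpreter path from a cell. The sketch below uses only the standard library. The helper name `kernel_environment` is hypothetical, and expecting the printed paths to contain `deepforest` assumes a typical pyenv-virtualenv layout.

```python
import os
import sys

def kernel_environment():
    """Return the interpreter path and environment details for the running kernel."""
    return {
        "executable": sys.executable,                  # Python binary backing this kernel
        "prefix": sys.prefix,                          # root of the environment in use
        "virtual_env": os.environ.get("VIRTUAL_ENV"),  # set by most activation scripts, else None
    }

for name, value in kernel_environment().items():
    print(f"{name}: {value}")
```

If the paths point at a system-wide Python instead, the kernel was probably started outside the activated terminal session.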
## 2. Install `uv` (if not already installed)

`uv` is a fast Python package installer and resolver. We'll use it to manage project dependencies. Run this cell to install `uv` into your active environment.

*(The `!` prefix runs the command in the shell from within Jupyter.)*

In [None]:
!pip install uv

## 3. Install/Reinstall Project Dependencies with `uv`

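The `numpy.dtype` size error usually indicates a binary incompatibility: a compiled extension was built against a different NumPy version than the one currently installed. To resolve this, we'll force a reinstallation of all project dependencies using `uv`. `uv pip install` targets the currently active virtual environment by default, so no extra flag is needed as long as Jupyter was launched from the terminal in which you activated the `pyenv` environment.

*(This step might take a few moments as it downloads and reinstalls packages.)*

In [None]:
!uv pip install --force-reinstall -r requirements.txt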

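After reinstalling, it can help to check that the compiled packages import cleanly, since binary-incompatibility errors surface at import time. The following is a minimal sketch using only `importlib`. The helper name `try_imports` is hypothetical, and the module list is simply taken from the imports used later in this notebook.

```python
import importlib

def try_imports(module_names):
    """Import each module and map its name to None on success or the error text on failure."""
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # incompatibility errors are not always ImportError
            results[name] = f"{type(exc).__name__}: {exc}"
    return results

# Module names taken from the imports used later in this notebook.
for name, error in try_imports(["numpy", "scipy", "sklearn", "wfdb"]).items():
    print(f"{name}: {'ok' if error is None else error}")
```

Any entry that does not print `ok` points at a package that still needs to be reinstalled or rebuilt.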
## 4. Project Code

In [None]:
import yaml
import os

def load_config(config_path='config.yaml'):
    """
    Loads a YAML configuration file.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found at {config_path}")
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)

def save_results(results, output_path='results.txt'):
    """
    Saves the experiment results to a text file.
    """
    with open(output_path, 'w') as file:
        for key, value in results.items():
            file.write(f'{key}: {value}\n')

def check_and_create_dir(directory):
    """
    Checks if a directory exists, and if not, creates it.
    """
    if not os.path.exists(directory):
        print(f"Creating directory: {directory}")
        os.makedirs(directory)

In [None]:
import numpy as np
import wfdb
from scipy.signal import butter, filtfilt
from sklearn.model_selection import train_test_split
import os

# AAMI-compliant class mappings
AAMI_CLASSES = {
    'N': 0, 'L': 0, 'R': 0, 'e': 0, 'j': 0,  # Non-ectopic
    'A': 1, 'a': 1, 'J': 1, 'S': 1,         # Supraventricular ectopic
    'V': 2, 'E': 2,                         # Ventricular ectopic
    'F': 3,                                 # Fusion
    '/': 4, 'f': 4, 'Q': 4,                 # Paced/Unknown
}

def get_aami_class(symbol):
    """Maps an annotation symbol to its AAMI class."""
    return AAMI_CLASSES.get(symbol)

def apply_bandpass_filter(signal, fs=360):
    """Applies a band-pass filter to the signal."""
    lowcut = 0.5
    highcut = 45.0
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(2, [low, high], btype='band')
    return filtfilt(b, a, signal)

def segment_heartbeats(signal, annotations, fs=360, window_size=360):
    """
    Segments the signal into individual heartbeats.
    """
    heartbeats = []
    labels = []
    window_before = window_size // 2
    window_after = window_size - window_before

    for i, symbol in enumerate(annotations.symbol):
        aami_class = get_aami_class(symbol)
        if aami_class is not None:
            peak_sample = annotations.sample[i]
            start = peak_sample - window_before
            end = peak_sample + window_after
            if start >= 0 and end < len(signal):
                heartbeats.append(signal[start:end])
                labels.append(aami_class)

    return np.array(heartbeats), np.array(labels)

def preprocess_data(data_path, window_size=360, max_records=None):
    """
    Loads and preprocesses the ECG data from the MIT-BIH Arrhythmia Database.
    """
    print(f"Loading data from {data_path}...")
    
    # Get a list of all record names by listing .hea files
    record_names = [f.split('.')[0] for f in os.listdir(data_path) if f.endswith('.hea')]
    record_names.sort() # Ensure consistent order

    all_heartbeats = []
    all_labels = []

    for i, record_name in enumerate(record_names):
        if max_records and i >= max_records:
            print(f"Reached max_records limit of {max_records}. Stopping data loading.")
            break
        print(f"Processing record: {record_name}")
        record_full_path = os.path.join(data_path, record_name)
        try:
            record = wfdb.rdrecord(record_full_path)
            annotations = wfdb.rdann(record_full_path, 'atr')

            # Use the MLII lead if available; otherwise fall back to the first channel
            if 'MLII' in record.sig_name:
                signal_index = record.sig_name.index('MLII')
            else:
                signal_index = 0 # Default to first channel
            signal = record.p_signal[:, signal_index]

            # Apply band-pass filter
            filtered_signal = apply_bandpass_filter(signal, fs=record.fs)

            # Segment heartbeats
            heartbeats, labels = segment_heartbeats(
                filtered_signal, annotations, fs=record.fs, window_size=window_size
            )

            all_heartbeats.append(heartbeats)
            all_labels.append(labels)

        except Exception as e:
            print(f"Error processing record {record_name}: {e}")

    if not all_heartbeats:
        raise ValueError("No heartbeats processed. Check data_path and file integrity.")

    X = np.concatenate(all_heartbeats)
    y = np.concatenate(all_labels)

    print("Splitting data into training and testing sets...")
    # Stratified 80/20 split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    return X_train, X_test, y_train, y_test

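To make the arithmetic in `apply_bandpass_filter` and `segment_heartbeats` concrete, the sketch below reproduces just the cutoff normalization and the window bounds in plain Python. The helper names `normalized_band` and `beat_window` are hypothetical, and the record length of 650000 samples is an illustrative value.

```python
def normalized_band(lowcut, highcut, fs):
    """Return the band edges as fractions of the Nyquist frequency."""
    nyquist = 0.5 * fs
    return lowcut / nyquist, highcut / nyquist

def beat_window(peak_sample, n_samples, window_size=360):
    """Return (start, end) for a window around peak_sample, or None if it does not fit."""
    before = window_size // 2
    after = window_size - before
    start, end = peak_sample - before, peak_sample + after
    # Same bounds test as segment_heartbeats: windows that run past either edge are dropped.
    if start >= 0 and end < n_samples:
        return start, end
    return None

low, high = normalized_band(0.5, 45.0, fs=360)
print(round(low, 5), high)         # 0.00278 0.25
print(beat_window(1000, 650000))   # (820, 1180)
print(beat_window(100, 650000))    # None: the window would start before sample 0
```

With the default `window_size=360` at 360 Hz, each retained heartbeat is a one-second segment centered on its annotation, and beats too close to the start or end of a record are skipped.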
In [None]:
import numpy as np
import librosa
import pywt
from ssqueezepy import ssq_cwt

def extract_features(train_data, test_data, method='MFCC'):
    """
    Extracts features from the preprocessed ECG data.

    Args:
        train_data (np.ndarray): Training data (heartbeats).
        test_data (np.ndarray): Testing data (heartbeats).
        method (str): Feature extraction method (MFCC, DWT, HHT, SSCWT).

    Returns:
        tuple: A tuple containing train_features, test_features.
    """
    print(f"Extracting features using {method} method...")

    if method == 'MFCC':
        train_features = _extract_mfcc(train_data)
        test_features = _extract_mfcc(test_data)
    elif method == 'DWT':
        train_features = _extract_dwt(train_data)
        test_features = _extract_dwt(test_data)
    elif method == 'HHT':
        train_features = _extract_hht(train_data)
        test_features = _extract_hht(test_data)
    elif method == 'SSCWT':
        train_features = _extract_sscwt(train_data)
        test_features = _extract_sscwt(test_data)
    else:
        raise ValueError(f"Unknown feature extraction method: {method}")

    return train_features, test_features

def _extract_mfcc(data, sr=360, n_mfcc=13):
    """
    Extracts Mel-frequency cepstral coefficients (MFCCs).
    """
    mfccs = []
    for heartbeat in data:
        # Ensure heartbeat is float type for librosa
        heartbeat = heartbeat.astype(float)
        mfcc = librosa.feature.mfcc(y=heartbeat, sr=sr, n_mfcc=n_mfcc)
        mfccs.append(np.mean(mfcc.T, axis=0)) # Take mean across time frames
    return np.array(mfccs)

def _extract_dwt(data, wavelet='db4', level=4):
    """
    Extracts Discrete Wavelet Transform (DWT) features.
    """
    dwt_features = []
    for heartbeat in data:
        coeffs = pywt.wavedec(heartbeat, wavelet, level=level)
        # Flatten coefficients and concatenate them
        features = np.concatenate([np.array(c).flatten() for c in coeffs])
        dwt_features.append(features)
    # Pad features to the maximum length if they are not uniform
    max_len = max(len(f) for f in dwt_features)
    padded_features = np.array([np.pad(f, (0, max_len - len(f)), 'constant') for f in dwt_features])
    return padded_features

def _extract_hht(data):
    """
    Extracts Hilbert-Huang Transform (HHT) features.
    Note: HHT implementation is complex and often requires external libraries
    or a custom implementation of EMD. This is a placeholder.
    For a full implementation, consider libraries like `emd`.
    """
    print("Warning: HHT feature extraction is a placeholder and returns dummy data.")
    # Dummy implementation: return mean and std of the signal as basic features
    hht_features = []
    for heartbeat in data:
        hht_features.append([np.mean(heartbeat), np.std(heartbeat)])
    return np.array(hht_features)

def _extract_sscwt(data, fs=360):
    """
    Extracts Synchrosqueezed Continuous Wavelet Transform (SSCWT) features.
    Note: SSCWT can produce high-dimensional output. This is a placeholder
    and returns a simplified representation.
    """
    print("Warning: SSCWT feature extraction is a placeholder and returns dummy data.")
    sscwt_features = []
    for heartbeat in data:
        # ssq_cwt returns a tuple whose first element is the synchrosqueezed transform Tx.
        # The number and order of the remaining elements differ between ssqueezepy versions,
        # so index the result instead of unpacking a fixed number of values.
        Tx = ssq_cwt(heartbeat, 'morlet', fs=fs)[0]
        # Use the mean absolute value of the transform as a simple feature
        sscwt_features.append(np.mean(np.abs(Tx)))
    return np.array(sscwt_features).reshape(-1, 1)

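The size of the DWT feature vector produced by `_extract_dwt` can be worked out without running the transform. The sketch below assumes PyWavelets' documented coefficient-length rule for non-periodization modes, `floor((n + filter_len - 1) / 2)` per level, and that `db4` has 8 filter taps. The helper name `dwt_coeff_lengths` is hypothetical.

```python
def dwt_coeff_lengths(n_samples, filter_len=8, level=4):
    """Lengths of [cA_level, cD_level, ..., cD_1] for a multilevel DWT.

    Assumes the rule for non-periodization modes:
    each level yields floor((n + filter_len - 1) / 2) coefficients.
    """
    detail_lengths = []
    n = n_samples
    for _ in range(level):
        n = (n + filter_len - 1) // 2
        detail_lengths.append(n)
    # The output is ordered coarsest first: final approximation, then details.
    return [detail_lengths[-1]] + detail_lengths[::-1]

lengths = dwt_coeff_lengths(360)   # 360-sample heartbeat, db4, 4 levels
print(lengths, sum(lengths))       # [29, 29, 51, 95, 183] 387
```

Because every heartbeat has the same length, every feature vector has the same length too, so the padding step in `_extract_dwt` acts only as a safeguard.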
In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from deepforest import CascadeForestClassifier
from gcforest.gcforest import GCForest

def train_and_evaluate(train_features, train_labels, test_features, test_labels, model_type='gcForest'):
    """
    Trains and evaluates the specified Deep Forest model.

    Args:
        train_features (np.ndarray): The training features.
        train_labels (np.ndarray): The training labels.
        test_features (np.ndarray): The testing features.
        test_labels (np.ndarray): The testing labels.
        model_type (str): The type of model to train ('gcForest' or 'CascadeForest').

    Returns:
        object: The trained model.
    """
    print(f"Training {model_type} model...")

    if model_type == 'gcForest':
        model = _train_gcforest(train_features, train_labels)
    elif model_type == 'CascadeForest':
        model = _train_cascade_forest(train_features, train_labels)
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    print("Evaluating the model...")
    predictions = model.predict(test_features)

    # Calculate and print performance metrics
    accuracy = accuracy_score(test_labels, predictions)
    f1 = f1_score(test_labels, predictions, average='weighted')
    precision = precision_score(test_labels, predictions, average='weighted')
    recall = recall_score(test_labels, predictions, average='weighted')

    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1-score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")

    return model

def _train_gcforest(train_features, train_labels):
    print("Training gcForest...")
    # Initialize GCForest with some default parameters
    # These parameters can be tuned further for optimal performance
    gc = GCForest(shape_1X=train_features.shape[1],
                  n_mgs=1,
                  n_estimators_as_forest=[100],
                  min_samples_leaf=1,
                  max_depth=None,
                  n_tolerant_retry=10,
                  n_jobs=-1) # Use all available cores
    gc.fit(train_features, train_labels)
    return gc

def _train_cascade_forest(train_features, train_labels):
    print("Training CascadeForestClassifier...")
    # Initialize CascadeForestClassifier with some default parameters
    # These parameters can be tuned further for optimal performance
    cf = CascadeForestClassifier(n_estimators=100,
                                 n_trees=500,
                                 use_predictor=True,
                                 min_samples_leaf=1,
                                 max_depth=None,
                                 n_jobs=-1) # Use all available cores
    cf.fit(train_features, train_labels)
    return cf

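The `average='weighted'` setting used in `train_and_evaluate` computes each metric per class and then averages the per-class values weighted by the number of true samples in each class. The pure-Python sketch below illustrates that calculation for the F1-score. The helper name `weighted_f1` is hypothetical and the labels are made-up toy values, not results from this project.

```python
from collections import Counter

def weighted_f1(y_true, y_pred):
    """Support-weighted mean of per-class F1 scores."""
    support = Counter(y_true)   # number of true samples per class
    total = len(y_true)
    score = 0.0
    for cls, n in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = n - tp
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        score += f1 * n / total  # weight each class by its share of the true labels
    return score

# Toy labels: three beats of class 0 and one of class 1.
print(round(weighted_f1([0, 0, 0, 1], [0, 0, 1, 1]), 4))  # 0.7667
```

Because the weights follow class frequency, a dominant class can mask poor performance on rare classes. Per-class or macro-averaged scores are a useful complement when the classes are imbalanced.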
In [None]:
import shap
import numpy as np
import matplotlib.pyplot as plt

def explain_model(model, test_features, test_labels):
    """
    Explains the model's predictions using SHAP.

    Args:
        model (object): The trained model (gcForest or CascadeForestClassifier).
        test_features (np.ndarray): The testing features.
        test_labels (np.ndarray): The testing labels.
    """
    print("Calculating SHAP values...")

    # TreeExplainer is efficient, but it only recognizes the tree-ensemble formats it supports
    # (e.g., scikit-learn, XGBoost, LightGBM). Deep forest models such as gcForest and
    # CascadeForestClassifier may not be recognized, so fall back to the model-agnostic
    # (but much slower) KernelExplainer if TreeExplainer raises an error.
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(test_features)
    except Exception as e:
        print(f"TreeExplainer failed: {e}. Falling back to KernelExplainer (may be slow)...")
        # KernelExplainer requires a background dataset for estimation
        # Using a subset of test_features as background for performance
        background_data = shap.sample(test_features, 100) # Sample 100 instances
        explainer = shap.KernelExplainer(model.predict_proba, background_data)
        shap_values = explainer.shap_values(test_features)

    print("Generating SHAP summary plot...")
    # If the model is multi-output (multi-class classification), shap_values will be a list of arrays.
    # For summary_plot, we often plot for one class or the absolute mean of all classes.
    if isinstance(shap_values, list):
        # For multi-class, plot the SHAP values for the first class (or average/sum them)
        shap.summary_plot(shap_values[0], test_features, plot_type="bar", show=False)
    else:
        shap.summary_plot(shap_values, test_features, plot_type="bar", show=False)

    plt.title("SHAP Feature Importance")
    plt.tight_layout()
    plt.savefig("shap_summary_plot.png")
    plt.close()
    print("SHAP summary plot saved as shap_summary_plot.png")

    # You can also implement logic to analyze misclassified instances
    # and generate local SHAP plots (e.g., shap.force_plot).

In [None]:
import argparse

def main(args):
    """
    Main function to run the experiment.
    """
    # 1. Preprocess the data
    print("Starting data preprocessing...")
    train_data, test_data, train_labels, test_labels = preprocess_data(args.data_path, max_records=args.max_records)
    print("Data preprocessing complete.")

    # 2. Extract features
    print("Extracting features...")
    train_features, test_features = extract_features(train_data, test_data, method=args.feature_extractor)
    print("Feature extraction complete.")

    # 3. Train and evaluate the model
    print("Training and evaluating the model...")
    model = train_and_evaluate(train_features, train_labels, test_features, test_labels, model_type=args.model)
    print("Model training and evaluation complete.")

    # 4. Explain the model
    if args.explain:
        print("Explaining the model...")
        explain_model(model, test_features, test_labels)
        print("Model explanation complete.")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run the Deep Forest ECG experiment.')
    parser.add_argument('--data_path', type=str, default='./data', help='Path to the dataset.')
    parser.add_argument('--feature_extractor', type=str, default='MFCC', choices=['MFCC', 'DWT', 'HHT', 'SSCWT'], help='Feature extraction method.')
    parser.add_argument('--model', type=str, default='gcForest', choices=['gcForest', 'CascadeForest'], help='Model to train.')
    parser.add_argument('--explain', action='store_true', help='Whether to run SHAP explainability.')
    parser.add_argument('--max_records', type=int, default=None, help='Maximum number of records to process for testing purposes.')
    # In a notebook, __name__ is '__main__', so this block runs when the cell executes.
    # Parse an empty argument list so argparse ignores Jupyter's own command-line arguments.
    args = parser.parse_args(args=[])
    # main(args) is not called here: the defaults point at './data' and would process every record.
    # The test run in section 5 calls main() with an explicit data path and record limit.

## 5. Run a Small Test

Now that the dependencies should be correctly installed, we can call `main()` (the notebook's inlined copy of `main.py`) on a small subset of the data to verify the setup. We'll set `max_records=10` to limit the data processed and `feature_extractor='MFCC'`, the equivalents of the `--max_records` and `--feature_extractor` command-line options.

*(This will run the full pipeline inside the notebook and print its output below the cell.)*

In [None]:
args = argparse.Namespace(data_path='./data/mit-bih-arrhythmia-database-1.0.0', feature_extractor='MFCC', model='gcForest', explain=False, max_records=10)
main(args)

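The `argparse.Namespace` above can also be produced by parsing an explicit argument list, which mirrors how the options would be passed on the command line. The sketch below recreates the notebook's options in a hypothetical helper, `build_parser`. It does not call `main()`, so it is safe to run on its own.

```python
import argparse

def build_parser():
    """Recreate the notebook's command-line options (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Run the Deep Forest ECG experiment.")
    parser.add_argument("--data_path", type=str, default="./data")
    parser.add_argument("--feature_extractor", default="MFCC", choices=["MFCC", "DWT", "HHT", "SSCWT"])
    parser.add_argument("--model", default="gcForest", choices=["gcForest", "CascadeForest"])
    parser.add_argument("--explain", action="store_true")
    parser.add_argument("--max_records", type=int, default=None)
    return parser

# Passing an explicit list keeps argparse from reading Jupyter's own sys.argv.
args = build_parser().parse_args([
    "--data_path", "./data/mit-bih-arrhythmia-database-1.0.0",
    "--max_records", "10",
])
print(args.max_records, args.feature_extractor, args.model, args.explain)  # 10 MFCC gcForest False
```

Parsing a list has the side benefit that `choices` and `type` are validated, which a hand-built `Namespace` skips.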
## Conclusion

If the last cell executed without the `numpy.dtype` error and printed the pipeline's progress messages (e.g., "Starting data preprocessing...") followed by the evaluation metrics, your environment is correctly set up for the `deep-fecg-research` project. You can now proceed with your research!