# In [None]:
import numpy as np
import pandas as pd
import os # NEW IMPORT for file system operations
from scipy import stats
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import recall_score, f1_score, precision_recall_fscore_support, auc, precision_recall_curve, confusion_matrix
from sklearn.utils import class_weight, resample # NEW IMPORT for downsampling
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, BatchNormalization, Activation, MaxPooling1D, Dropout, Dense, Flatten
from tensorflow.keras.utils import to_categorical


# --- Signal / training configuration ---------------------------------------
SIGNAL_LENGTH = 187      # samples per beat window centred on the annotated R-peak
NUM_CLASSES = 5          # AAMI super-classes: N, S, V, F, Q
KFOLD_SPLITS = 5
RANDOM_STATE = 42
MAX_EPOCHS = 25          # per-fold epoch cap (previously 50)
BATCH_SIZE = 128         # larger batches for faster epochs (previously 64)

RAW_DATA_PATH = '/content/mitbih_database/'

# Record numbers to load; each needs '<id>.csv' and '<id>annotations.txt' in RAW_DATA_PATH.
RECORD_IDS = [
    100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
    111, 112, 113, 114, 115, 116, 117, 118, 119,
    121, 122, 123, 124,
    200, 201, 202, 203, 205, 207, 208, 209, 210,
    212, 213, 214, 215, 217, 219,
    220, 221, 222, 223, 228,
    230, 231, 232, 233, 234,
]

# Beat-annotation symbol -> AAMI class index, built from one symbol group per class.
AAMI_MAPPING = {
    symbol: aami_class
    for aami_class, symbols in enumerate((
        ('N', 'L', 'R', 'e', 'j'),   # 0 = N (Normal)
        ('A', 'a', 'J', 'S'),        # 1 = S (Supraventricular ectopic)
        ('V', 'E'),                  # 2 = V (Ventricular ectopic)
        ('F',),                      # 3 = F (Fusion)
        ('/', 'f', 'Q', '?'),        # 4 = Q (Unclassifiable/Paced)
    ))
    for symbol in symbols
}

# Human-readable name and plot colour per AAMI class index (N=0, S=1, V=2, F=3, Q=4).
CLASS_LABELS = {
    0: 'N (Normal)',
    1: 'S (SVEB)',
    2: 'V (VEB)',
    3: 'F (Fusion)',
    4: 'Q (Unknown)',
}
CLASS_COLORS = {
    0: 'green',
    1: 'blue',
    2: 'orange',
    3: 'red',
    4: 'purple',
}


def _read_record(rec_id):
    """Load one record's lead signal and beat annotations.

    Returns ``(signal, r_peak_indices, beat_types)``, or ``None`` when either file is
    missing or cannot be parsed (a warning is printed in the parse-failure case).
    """
    signal_file = os.path.join(RAW_DATA_PATH, f'{rec_id}.csv')
    annotation_file = os.path.join(RAW_DATA_PATH, f'{rec_id}annotations.txt')

    if not os.path.exists(signal_file) or not os.path.exists(annotation_file):
        return None

    try:
        signal_df = pd.read_csv(signal_file, header=0, engine='python')

        # Prefer the MLII lead. Header names are normalised (surrounding quotes and
        # whitespace stripped, lower-cased) so "mlii", "MLII" and "'MLII'" all match.
        # NOTE(review): the exact header format of the CSV export is assumed — verify.
        lead_col = next(
            (col for col in signal_df.columns
             if str(col).strip().strip('\'"').strip().lower() == 'mlii'),
            None,
        )
        if lead_col is not None:
            signal = signal_df[lead_col].values
        elif signal_df.shape[1] > 1:
            # Fallback to the second column
            signal = signal_df.iloc[:, 1].values
        else:
            signal = signal_df.iloc[:, 0].values

        # Whitespace-separated table: column 1 = sample index, column 2 = symbol.
        annotations = pd.read_csv(
            annotation_file,
            sep=r'\s+',
            header=None,
            skiprows=1,
            usecols=[1, 2],
            names=['Index', 'Beat_Type'],
            on_bad_lines='skip'
        )
        r_peak_indices = annotations['Index'].values.astype(int)
        beat_types = annotations['Beat_Type'].values
    except Exception as e:
        # Skip records with I/O or parse errors instead of aborting the whole load.
        print(f"Warning: Error processing record {rec_id}. Skipping. Error: {e}")
        return None

    return signal, r_peak_indices, beat_types


def _segment_record(signal, r_peak_indices, beat_types):
    """Cut a SIGNAL_LENGTH window around each annotated beat.

    Returns ``(beats, labels)`` lists. Annotations whose symbol is not a key of
    AAMI_MAPPING (e.g. rhythm-change or noise markers) are skipped rather than
    being labelled as class 4. Non-string symbols (NaN rows from the parser) are
    skipped too. Windows that would run past either end of the signal are dropped.
    """
    half_beat = SIGNAL_LENGTH // 2
    n_samples = len(signal)
    beats, labels = [], []

    for r_idx, beat_type in zip(r_peak_indices, beat_types):
        if not isinstance(beat_type, str):
            continue
        label = AAMI_MAPPING.get(beat_type.strip())
        if label is None:
            continue

        # end = start + SIGNAL_LENGTH gives exactly SIGNAL_LENGTH samples for odd
        # or even lengths; for odd lengths it equals r_idx + half_beat + 1.
        start_idx = r_idx - half_beat
        end_idx = start_idx + SIGNAL_LENGTH
        if start_idx < 0 or end_idx > n_samples:
            continue

        beats.append(signal[start_idx:end_idx])
        labels.append(label)

    return beats, labels


def _downsample_majority(X_raw, Y_raw, max_majority=20000):
    """Cap class 0 (N) at ``max_majority`` random beats, keep all other classes, shuffle.

    Sampling and shuffling are seeded with RANDOM_STATE. Returns ``(X, Y)`` arrays.
    """
    data_df = pd.DataFrame(X_raw)
    data_df['label'] = Y_raw

    df_majority = data_df[data_df['label'] == 0]
    df_minority = data_df[data_df['label'] != 0]

    df_majority_downsampled = resample(
        df_majority,
        replace=False,
        n_samples=min(len(df_majority), max_majority),
        random_state=RANDOM_STATE
    )

    # Combine downsampled majority class with all minority classes, then shuffle.
    df_balanced = pd.concat([df_majority_downsampled, df_minority])
    df_balanced = df_balanced.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)

    X = df_balanced.drop('label', axis=1).values
    Y = df_balanced['label'].values
    return X, Y


def _zscore_rows(X):
    """Z-score each row (beat) independently using the population std (ddof=0).

    Zero-variance rows (flat segments) become all zeros instead of NaN, which would
    otherwise propagate into the training loss.
    """
    X = np.asarray(X, dtype=np.float64)
    mean = X.mean(axis=1, keepdims=True)
    std = X.std(axis=1, keepdims=True)
    return (X - mean) / np.where(std > 0, std, 1.0)


def load_and_preprocess_data():
    """Load the MIT-BIH records, segment beats, downsample class N and z-score each beat.

    Returns:
        (X, Y): ``X`` has shape (n_beats, SIGNAL_LENGTH, 1) with each beat z-scored
        on its own. ``Y`` holds integer AAMI class ids from AAMI_MAPPING. Empty arrays
        are returned when the data directory is missing or no beats were extracted.

    NOTE(review): balancing happens before any CV split, so test folds do not follow
    the natural class distribution. Beats from one record can also land in both train
    and test folds (intra-patient evaluation). Confirm this is intended.
    """
    if not os.path.isdir(RAW_DATA_PATH):
        print("CRITICAL ERROR: The data directory was NOT found.")
        print(f"Please check the path: --> {RAW_DATA_PATH}")
        return np.array([]), np.array([])

    file_list = os.listdir(RAW_DATA_PATH)
    print(f"Directory found. Total files: {len(file_list)}. Sample files: {file_list[:5]}...")

    all_beats = []
    all_labels = []
    for rec_id in RECORD_IDS:
        record = _read_record(rec_id)
        if record is None:
            continue
        beats, labels = _segment_record(*record)
        all_beats.extend(beats)
        all_labels.extend(labels)

    X_raw = np.array(all_beats)
    Y_raw = np.array(all_labels)

    if X_raw.shape[0] == 0:
        print("CRITICAL ERROR: No beats were extracted...")
        return X_raw, Y_raw

    X, Y = _downsample_majority(X_raw, Y_raw, max_majority=20000)

    # 7. Apply Z-Score Normalization (per beat)
    X = _zscore_rows(X)

    # 8. Reshape for 1D CNN: (samples, timesteps, features=1)
    X = X.reshape(X.shape[0], SIGNAL_LENGTH, 1)

    print("--- Finished Segmentation and Downsampling ---")
    print(f"Total Samples (Original): {X_raw.shape[0]}, Total Samples (Final): {X.shape[0]}")
    print(f"Final Class Distribution: {Counter(Y)}")

    return X, Y


def create_cnn_model(input_shape, num_classes):
    """Build and compile a three-block 1D-CNN classifier.

    Args:
        input_shape: Per-sample input shape (timesteps, channels), e.g. the
            ``X.shape[1:]`` of the reshaped beat array.
        num_classes: Number of units in the softmax output layer.

    Returns:
        A compiled Sequential model using Adam (lr=0.001) and categorical
        cross-entropy, so the targets must be one-hot encoded.
    """
    model = Sequential([
        # An explicit Input replaces the deprecated `input_shape=` layer argument
        # (Keras 3 warns about passing it to a layer).
        tf.keras.Input(shape=input_shape),

        # Block 1: 'same' padding keeps the length; pooling then halves it.
        Conv1D(filters=32, kernel_size=5, padding='same'),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling1D(pool_size=2),
        Dropout(0.2),

        # Block 2
        Conv1D(filters=64, kernel_size=5, padding='same'),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling1D(pool_size=2),
        Dropout(0.2),

        # Block 3
        Conv1D(filters=128, kernel_size=5, padding='same'),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling1D(pool_size=2),
        Dropout(0.3),

        # Classifier
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model


def calculate_class_weights(y_train):
    """
    Calculates class weights inversely proportional to class frequencies.
    This is essential for handling the severe class imbalance in MIT-BIH.

    Args:
        y_train: Integer class labels (array-like).

    Returns:
        dict mapping every class id from 0 to max(y_train) to a float weight.
        Weights use sklearn's 'balanced' heuristic, keyed by the actual class id.
        Ids missing from ``y_train`` get a neutral 1.0, so the keys are contiguous,
        which is the layout tf.keras ``fit(class_weight=...)`` expects.
        Returns an empty dict for empty input.
    """
    y_train = np.asarray(y_train)
    print("\nCalculated Class Weights (Higher for Rare Classes):")
    if y_train.size == 0:
        return {}

    present_classes = np.unique(y_train)
    weights = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=y_train
    )

    # Key by the real class id. enumerate() would shift weights onto the wrong class
    # whenever an intermediate id (e.g. class 3) is missing from y_train.
    class_weights_dict = {int(c): float(w) for c, w in zip(present_classes, weights)}
    for c in range(int(present_classes.max()) + 1):
        class_weights_dict.setdefault(c, 1.0)
    class_weights_dict = dict(sorted(class_weights_dict.items()))

    for k, v in class_weights_dict.items():
        print(f"  Class {k} ({CLASS_LABELS.get(k, 'Unknown')}): {v:.4f}")
    return class_weights_dict


def evaluate_metrics(y_true, y_pred_probs, class_labels):
    """Compute fold-level classification metrics from predicted class scores.

    Args:
        y_true: Integer class labels (array-like; converted to a NumPy array).
        y_pred_probs: Array of shape (n_samples, len(class_labels)). The predicted
            class is the argmax over axis 1.
        class_labels: Any sized container; its length defines the label set
            0 .. len(class_labels) - 1.

    Returns:
        dict with the following keys:
        - 'F1_Macro': macro F1 over the labels present in y_true or the predictions.
        - 'Recall_Per_Class': array with one recall per label.
        - 'PR_AUC_Mean': nan-aware mean of the per-class PR-AUCs.
        - 'PR_AUC_Per_Class': list of one-vs-rest trapezoidal PR-AUCs; NaN for a
          class absent from y_true.
        - 'Confusion_Matrix': always len(class_labels) x len(class_labels).
    """
    # Without asarray, a list y_true would make `y_true == i` a scalar False and
    # silently turn every PR-AUC into NaN.
    y_true = np.asarray(y_true)
    y_pred_probs = np.asarray(y_pred_probs)
    y_pred_classes = np.argmax(y_pred_probs, axis=1)
    labels = list(range(len(class_labels)))

    f1_macro = f1_score(y_true, y_pred_classes, average='macro', zero_division=0.0)
    recall_per_class = recall_score(y_true, y_pred_classes, average=None, labels=labels, zero_division=0.0)

    pr_auc_scores = []
    for i in labels:
        is_class = y_true == i
        if not is_class.any():
            pr_auc_scores.append(np.nan)  # class absent from this test set
            continue
        precision, recall, _ = precision_recall_curve(is_class, y_pred_probs[:, i])
        pr_auc_scores.append(auc(recall, precision))

    return {
        'F1_Macro': f1_macro,
        'Recall_Per_Class': recall_per_class,
        'PR_AUC_Mean': np.nanmean(pr_auc_scores),
        'PR_AUC_Per_Class': pr_auc_scores,
        # Fixed label set keeps the matrix shape stable across folds.
        'Confusion_Matrix': confusion_matrix(y_true, y_pred_classes, labels=labels),
    }

def _split_validation(X_train, Y_train, Y_train_enc, val_fraction=0.1):
    """Carve a validation subset out of a training fold for early stopping.

    Returns ``[X_fit, X_val, Y_fit_enc, Y_val_enc]`` from sklearn's train_test_split.
    The split is stratified on ``Y_train`` when every class has at least two samples.
    Otherwise, or if sklearn rejects the stratified split, a plain seeded shuffled
    split is used.
    """
    from sklearn.model_selection import train_test_split

    _, counts = np.unique(Y_train, return_counts=True)
    stratify = Y_train if counts.min() >= 2 else None
    try:
        return train_test_split(
            X_train, Y_train_enc,
            test_size=val_fraction,
            random_state=RANDOM_STATE,
            stratify=stratify,
        )
    except ValueError:
        return train_test_split(
            X_train, Y_train_enc,
            test_size=val_fraction,
            random_state=RANDOM_STATE,
        )


def _summarize_fold_metrics(fold_metrics):
    """Aggregate per-fold metric dicts into a DataFrame (rows: metric, cols: Mean/Std).

    Row keys are 'F1_Macro', 'PR_AUC_Mean', 'Recall_<X>' and 'PR_AUC_<X>', where <X>
    is the first word of each CLASS_LABELS value. NaN entries are ignored by the
    nan-aware mean and std.
    """
    short_names = {i: label.split(" ")[0] for i, label in CLASS_LABELS.items()}

    final_report = {
        'F1_Macro': [fm['F1_Macro'] for fm in fold_metrics],
        'PR_AUC_Mean': [fm['PR_AUC_Mean'] for fm in fold_metrics],
    }
    for i, short in short_names.items():
        final_report[f'Recall_{short}'] = [fm['Recall_Per_Class'][i] for fm in fold_metrics]
    for i, short in short_names.items():
        final_report[f'PR_AUC_{short}'] = [fm['PR_AUC_Per_Class'][i] for fm in fold_metrics]

    final_metrics_data = {}
    for key, values in final_report.items():
        # Sanitize: coerce non-numeric entries to NaN before aggregating.
        sanitized = pd.to_numeric(np.asarray(values), errors='coerce')
        final_metrics_data[key] = [np.nanmean(sanitized), np.nanstd(sanitized)]

    return pd.DataFrame(final_metrics_data, index=['Mean', 'Std']).T


def _print_final_report(results_df):
    """Print the cross-validation summary produced by _summarize_fold_metrics."""
    print("\n=======================================================================")
    print("                 FINAL 1D-CNN CLASSIFICATION REPORT")
    print("=======================================================================")
    print(f"Cross-Validation: {KFOLD_SPLITS}-Fold Stratified CV")
    print("Imbalance Handled: Class Weighting + N-Class Downsampling")
    print("-----------------------------------------------------------------------")

    print(f"Overall F1-Macro: {results_df.loc['F1_Macro', 'Mean']:.4f} (+/- {results_df.loc['F1_Macro', 'Std']:.4f})")
    print(f"Mean PR-AUC:      {results_df.loc['PR_AUC_Mean', 'Mean']:.4f} (+/- {results_df.loc['PR_AUC_Mean', 'Std']:.4f})")

    print("\n--- Per-Class Recall (Crucial for Rare-Classes) ---")
    for i, label in CLASS_LABELS.items():
        key = f'Recall_{label.split(" ")[0]}'
        print(f"  {label:<18} (Class {i}): {results_df.loc[key, 'Mean']:.4f} (+/- {results_df.loc[key, 'Std']:.4f})")

    print("\n--- Per-Class PR-AUC ---")
    for i, label in CLASS_LABELS.items():
        key = f'PR_AUC_{label.split(" ")[0]}'
        print(f"  {label:<18} (Class {i}): {results_df.loc[key, 'Mean']:.4f} (+/- {results_df.loc[key, 'Std']:.4f})")

    print("=======================================================================")


def main():
    """Run stratified k-fold cross-validation of the 1D-CNN and print a summary.

    In each fold:
    - Class weights are computed from the training labels only.
    - A slice of the training fold is held out to drive early stopping.
    - The test fold is used only for the reported metrics.
    """
    X, Y = load_and_preprocess_data()
    if X.shape[0] == 0:
        return

    y_encoded = to_categorical(Y, num_classes=NUM_CLASSES)
    skf = StratifiedKFold(n_splits=KFOLD_SPLITS, shuffle=True, random_state=RANDOM_STATE)
    fold_metrics = []

    print(f"\n--- Starting {KFOLD_SPLITS}-Fold Stratified Cross-Validation ---")

    for fold, (train_index, test_index) in enumerate(skf.split(X, Y)):
        print(f"\n[FOLD {fold + 1}/{KFOLD_SPLITS}]")

        X_train, X_test = X[train_index], X[test_index]
        Y_train, Y_test = Y[train_index], Y[test_index]
        Y_train_enc = y_encoded[train_index]

        print(f"  Data Split: Train samples={len(X_train)}, Test samples={len(X_test)}")
        print(f"  Training distribution: {Counter(Y_train)}")

        # Weights come from the training fold only, so test labels never shape the loss.
        class_weights = calculate_class_weights(Y_train)

        # EarlyStopping(restore_best_weights=True) picks the epoch with the best
        # val_loss. Monitoring the test fold would leak it into model selection,
        # so a held-out slice of the training fold is used instead.
        X_fit, X_val, Y_fit_enc, Y_val_enc = _split_validation(X_train, Y_train, Y_train_enc)
        print(f"  Early-stopping split: Fit samples={len(X_fit)}, Validation samples={len(X_val)}")

        model = create_cnn_model(X_train.shape[1:], NUM_CLASSES)
        callbacks = [
            tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
        ]

        print(f"  Training Model (Max Epochs={MAX_EPOCHS}, Batch Size={BATCH_SIZE})...")
        history = model.fit(
            X_fit, Y_fit_enc,
            epochs=MAX_EPOCHS,
            batch_size=BATCH_SIZE,
            validation_data=(X_val, Y_val_enc),
            class_weight=class_weights,
            callbacks=callbacks,
            verbose=1
        )
        print(f"  Trained for {len(history.history['loss'])} epochs. Best validation loss: {min(history.history['val_loss']):.4f}")

        Y_pred_probs = model.predict(X_test, verbose=0)

        metrics = evaluate_metrics(Y_test, Y_pred_probs, CLASS_LABELS)
        metrics['Fold'] = fold + 1
        fold_metrics.append(metrics)

        print(f"  F1-Macro: {metrics['F1_Macro']:.4f}, Mean PR-AUC: {metrics['PR_AUC_Mean']:.4f}")
        # Focus on rare classes
        print(f"  Rare Class Recall: F ({metrics['Recall_Per_Class'][3]:.3f}), Q ({metrics['Recall_Per_Class'][4]:.3f})")

    _print_final_report(_summarize_fold_metrics(fold_metrics))
    
if __name__ == '__main__':
    # Only surface TensorFlow log records at ERROR level or above during the run.
    _tf_logger = tf.get_logger()
    _tf_logger.setLevel('ERROR')
    main()
