In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
from tensorflow.keras import layers, models, callbacks
import json
import logging
from tqdm import tqdm
import seaborn as sns
import time
from scipy.signal import resample

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Early stopping callback
early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Dataset definitions
datasets = [
    {"name": "CUSPH-SR-AF", "data": "data/data_2class_normal.npy", "labels": "data/labels_2class_normal.npy", "num_classes": 2, "input_shape": (5000, 12)},
    {"name": "CSPC18-SR-AF", "data": "data/data_2class_cpsc18.npy", "labels": "data/labels_2class_cpsc18.npy", "num_classes": 2, "input_shape": (15000, 12)},
]

filter_combinations = [[32, 64, 128]]

def save_results(results, dataset_name, experiment_type, train_condition, test_condition):
    filename = f"robustness/results/{dataset_name}_{experiment_type}_train_{train_condition}_test_{test_condition}.json"
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    
    with open(filename, 'w') as f:
        json.dump(results, f, indent=2)

def save_confusion_matrix(cm, dataset_name, experiment_name):
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'Confusion Matrix - {dataset_name} - {experiment_name}')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(f"robustness/results/{dataset_name}_{experiment_name}_confusion_matrix.png")
    plt.close()

def calculate_dataset_stats(data, labels):
    stats = {
        'overall': {
            'mean': float(np.mean(data)),
            'std': float(np.std(data)),
            'min': float(np.min(data)),
            'max': float(np.max(data)),
        },
        'channel_wise': [],
        'label_distribution': np.bincount(labels.astype(int)).tolist()
    }
    
    for i in range(data.shape[-1]):
        channel_data = data[:, :, i]
        channel_stats = {
            'mean': float(np.mean(channel_data)),
            'std': float(np.std(channel_data)),
            'min': float(np.min(channel_data)),
            'max': float(np.max(channel_data)),
        }
        stats['channel_wise'].append(channel_stats)
    
    return stats

def downsample_block(x, filters):
    x = layers.Conv1D(filters // 2, 1, strides=1, padding='same')(x)
    x = mixed_pool_operator(x)
    return x

def branched_nodal_operator(x, filters, kernel_size=5, activation='relu'):
    y1 = layers.Conv1D(filters // 2, kernel_size, dilation_rate=2, padding='same')(x)
    y1 = layers.BatchNormalization()(y1)
    y1 = layers.Activation(activation)(y1)

    y2 = layers.SeparableConv1D(filters // 2, kernel_size, padding='same')(x)
    y2 = layers.BatchNormalization()(y2)
    y2 = layers.Activation(activation)(y2)

    y = layers.Concatenate()([y1, y2])
    return y

def mixed_pool_operator(x, pool_size=2, strides=1):
    y1 = layers.AveragePooling1D(pool_size, strides, padding='same')(x)
    y2 = layers.MaxPooling1D(pool_size, strides, padding='same')(x)
    y = layers.Concatenate()([y1, y2])
    return y

def squeeze_and_excitation_block(x, ratio=16):
    num_channels = x.shape[-1]
    squeeze = layers.GlobalAveragePooling1D()(x)
    excitation = layers.Dense(num_channels // ratio, activation='relu')(squeeze)
    excitation = layers.Dense(num_channels, activation='sigmoid')(excitation)
    excitation = layers.Reshape((1, num_channels))(excitation)
    scale = layers.Multiply()([x, excitation])
    return scale

def residual_block_SERN_AwGOP(x, filters, kernel_size=5, downsample=False):
    y = branched_nodal_operator(x, filters, kernel_size)
    y = branched_nodal_operator(y, filters, kernel_size)

    if downsample:
        x = downsample_block(x, filters)

    y = squeeze_and_excitation_block(y)

    attention_weights = layers.Dense(1, activation='sigmoid')(x)
    gop_out = layers.Multiply()([attention_weights, y])
    gop_out = layers.Add()([gop_out, x])
    gop_out = layers.Activation('relu')(gop_out)
    return gop_out

def create_SERN_AwGOP(input_shape, num_classes, filters):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(filters[0], 5, strides=2, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)

    for f in filters[1:]:
        x = residual_block_SERN_AwGOP(x, f, downsample=True)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)

    x = layers.GlobalAveragePooling1D()(x)
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid')(x)
    else:
        outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    
    # Compile the model
    if num_classes == 2:
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    else:
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    return model

def train_model(model, X_train, y_train, X_val, y_val):
    try:
        history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val), callbacks=[early_stopping])
        return history
    except Exception as e:
        logging.error(f"Error in train_model: {str(e)}")
        return None

def evaluate_model(model, X_test, y_test, num_classes):
    try:
        y_pred = model.predict(X_test)
        y_pred_classes = np.argmax(y_pred, axis=1) if num_classes > 2 else (y_pred > 0.5).astype(int).flatten()

        results = {
            'accuracy': float(accuracy_score(y_test, y_pred_classes)),
            'precision_weighted': float(precision_score(y_test, y_pred_classes, average='weighted')),
            'recall_weighted': float(recall_score(y_test, y_pred_classes, average='weighted')),
            'f1_weighted': float(f1_score(y_test, y_pred_classes, average='weighted')),
        }

        # Add unweighted sensitivity and specificity
        cm = confusion_matrix(y_test, y_pred_classes)
        sensitivity = cm[1,1] / (cm[1,1] + cm[1,0])
        specificity = cm[0,0] / (cm[0,0] + cm[0,1])
        results['sensitivity_unweighted'] = float(sensitivity)
        results['specificity_unweighted'] = float(specificity)

        if num_classes == 2:
            results['auc_roc'] = float(roc_auc_score(y_test, y_pred))
            results['average_precision'] = float(average_precision_score(y_test, y_pred))

        return results, cm
    except Exception as e:
        logging.error(f"Error in evaluate_model: {str(e)}")
        return {}, None

def add_gaussian_noise(data, noise_factor=0.05):
    noise = np.random.normal(loc=0, scale=1, size=data.shape)
    return data + noise_factor * noise

def data_efficiency_analysis(dataset, model_fn, filters, fractions=[0.1, 0.25, 0.5, 0.75, 1.0]):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

        results = {}
        for fraction in fractions:
            n_samples = int(len(X_train) * fraction)
            X_train_subset = X_train[:n_samples]
            y_train_subset = y_train[:n_samples]

            model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
            history = train_model(model, X_train_subset, y_train_subset, X_test, y_test)
            test_results, _ = evaluate_model(model, X_test, y_test, dataset['num_classes'])
            results[fraction] = test_results

        return results
    except Exception as e:
        logging.error(f"Error in data_efficiency_analysis: {str(e)}")
        return {}

def longitudinal_analysis(dataset, model_fn, filters, num_time_points=5):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])

        # Assume data is sorted chronologically
        time_point_size = len(data) // num_time_points
        results = []

        for i in range(num_time_points):
            start_idx = i * time_point_size
            end_idx = (i + 1) * time_point_size
            X_train, X_test, y_train, y_test = train_test_split(data[start_idx:end_idx], labels[start_idx:end_idx], test_size=0.2, random_state=42)

            model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
            history = train_model(model, X_train, y_train, X_test, y_test)
            test_results, _ = evaluate_model(model, X_test, y_test, dataset['num_classes'])
            results.append(test_results)

        return results
    except Exception as e:
        logging.error(f"Error in longitudinal_analysis: {str(e)}")
        return []

def time_efficiency_analysis(dataset, model_fn, filters):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

        model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)

        start_time = time.time()
        history = train_model(model, X_train, y_train, X_test, y_test)
        training_time = time.time() - start_time

        start_time = time.time()
        _ = model.predict(X_test)
        inference_time = (time.time() - start_time) / len(X_test)

        return {"training_time": float(training_time), "inference_time": float(inference_time)}
    except Exception as e:
        logging.error(f"Error in time_efficiency_analysis: {str(e)}")
        return {}

def cross_dataset_validation(train_dataset, test_dataset, model_fn, filters):
    try:
        train_data = np.load(train_dataset['data'])
        train_labels = np.load(train_dataset['labels'])
        test_data = np.load(test_dataset['data'])
        test_labels = np.load(test_dataset['labels'])

        # Ensure compatible shapes
        if train_data.shape[1:] != test_data.shape[1:]:
            test_data = resample(test_data, train_data.shape[1], axis=1)

        model = model_fn(train_dataset['input_shape'], train_dataset['num_classes'], filters)
        history = train_model(model, train_data, train_labels, test_data, test_labels)
        test_results, _ = evaluate_model(model, test_data, test_labels, test_dataset['num_classes'])

        return test_results
    except Exception as e:
        logging.error(f"Error in cross_dataset_validation: {str(e)}")
        return {}

def add_powerline_interference(data, frequency=50, amplitude=0.1):
    t = np.arange(data.shape[1]) / 1000  # Assume 1000 Hz sampling rate
    noise = amplitude * np.sin(2 * np.pi * frequency * t)
    return data + noise.reshape(1, -1, 1)

def add_electrode_motion_artifact(data, artifact_duration=100, amplitude=0.5):
    artifact = np.zeros(data.shape)
    for i in range(data.shape[0]):
        start = np.random.randint(0, data.shape[1] - artifact_duration)
        artifact[i, start:start+artifact_duration, :] = amplitude * np.random.randn(artifact_duration, data.shape[2])
    return data + artifact

def robustness_testing(dataset, model_fn, filters):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

        model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
        history = train_model(model, X_train, y_train, X_test, y_test)

        results = {}

        # Test on clean data
        clean_results, _ = evaluate_model(model, X_test, y_test, dataset['num_classes'])
        results['clean'] = clean_results

        # Test with Gaussian noise
        noisy_data = add_gaussian_noise(X_test)
        noise_results, _ = evaluate_model(model, noisy_data, y_test, dataset['num_classes'])
        results['gaussian_noise'] = noise_results

        # Test with powerline interference
        powerline_data = add_powerline_interference(X_test)
        powerline_results, _ = evaluate_model(model, powerline_data, y_test, dataset['num_classes'])
        results['powerline_interference'] = powerline_results

        # Test with electrode motion artifact
        motion_data = add_electrode_motion_artifact(X_test)
        motion_results, _ = evaluate_model(model, motion_data, y_test, dataset['num_classes'])
        results['electrode_motion'] = motion_results

        return results
    except Exception as e:
        logging.error(f"Error in robustness_testing: {str(e)}")
        return {}

def add_random_baseline_wander(data, max_amplitude=0.1):
    t = np.linspace(0, 1, data.shape[1])
    baseline = max_amplitude * np.sin(2 * np.pi * np.random.rand() * t)
    return data + baseline.reshape(1, -1, 1)

def time_warp(data, sigma=0.2, knot=4):
    orig_steps = np.arange(data.shape[1])
    warp_steps = np.linspace(0, data.shape[1]-1, num=knot+2)
    warper = np.zeros_like(data)
    
    for i in range(data.shape[0]):  # iterate over samples
        for j in range(data.shape[2]):  # iterate over channels
            random_warps = np.random.normal(loc=1.0, scale=sigma, size=(knot+2,))
            warper[i, :, j] = np.interp(orig_steps, warp_steps, random_warps)
    
    return data * warper

def apply_augmentations(data, labels):
    augmented_data = []
    augmented_labels = []

    # Original data
    augmented_data.append(data)
    augmented_labels.append(labels)

    # Gaussian noise
    augmented_data.append(add_gaussian_noise(data))
    augmented_labels.append(labels)

    # Baseline wander
    augmented_data.append(add_random_baseline_wander(data))
    augmented_labels.append(labels)

    # Time warping
    augmented_data.append(time_warp(data))
    augmented_labels.append(labels)

    return np.concatenate(augmented_data, axis=0), np.concatenate(augmented_labels, axis=0)

def simulate_hardware_issues(data, corruption_rate=0.1):
    corrupted_data = data.copy()
    mask = np.random.choice([0, 1], size=data.shape, p=[1-corruption_rate, corruption_rate])
    corrupted_data[mask.astype(bool)] = np.random.normal(loc=0, scale=1, size=corrupted_data[mask.astype(bool)].shape)
    return corrupted_data

def run_experiment(dataset, model_fn, filters, experiment_type, use_augmentation=False, single_electrode=None, pretrain_dataset=None):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])

        # Truncate data to (5000, 12) if it's larger
        if data.shape[1] > 5000:
            data = data[:, :5000, :]
            logging.info(f"Truncated {dataset['name']} data to shape {data.shape}")

        if single_electrode is not None:
            data = data[:, :, single_electrode:single_electrode+1]
            input_shape = (5000, 1)
            experiment_type = f"single_electrode_{single_electrode}"
            train_condition = "original"
        elif use_augmentation:
            train_condition = "aug_combined"
        elif pretrain_dataset is not None:
            train_condition = f"pretrained_{pretrain_dataset['name']}"
        else:
            train_condition = "original"
        
        input_shape = (5000, 12) if single_electrode is None else input_shape
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)
        X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

        if use_augmentation:
            X_train, y_train = apply_augmentations(X_train, y_train)

        if pretrain_dataset is not None:
            pretrain_data = np.load(pretrain_dataset['data'])
            pretrain_labels = np.load(pretrain_dataset['labels'])
            if pretrain_data.shape[1] > 5000:
                pretrain_data = pretrain_data[:, :5000, :]
            pretrain_model = model_fn(input_shape, pretrain_dataset['num_classes'], filters)
            
            pretrain_model.fit(pretrain_data, pretrain_labels, epochs=100, batch_size=32, validation_split=0.2, callbacks=[early_stopping])
            
            # Evaluate pretrained model
            pretrain_results, pretrain_cm = evaluate_model(pretrain_model, pretrain_data, pretrain_labels, pretrain_dataset['num_classes'])
            save_results(pretrain_results, pretrain_dataset['name'], "pretrain", "original", "clean")
            save_confusion_matrix(pretrain_cm, pretrain_dataset['name'], "pretrain_original_clean")
            
            # Transfer learned weights to the new model
            model = model_fn(input_shape, dataset['num_classes'], filters)
            for i, layer in enumerate(pretrain_model.layers[:-1]):  # Exclude the last layer
                model.layers[i].set_weights(layer.get_weights())
        else:
            model = model_fn(input_shape, dataset['num_classes'], filters)

        history = train_model(model, X_train, y_train, X_val, y_val)

        # Evaluate on normal test set
        normal_results, normal_cm = evaluate_model(model, X_test, y_test, dataset['num_classes'])
        save_results(normal_results, dataset['name'], experiment_type, train_condition, "clean")
        save_confusion_matrix(normal_cm, dataset['name'], f"{experiment_type}_train_{train_condition}_test_clean")

        # Evaluate on corrupted test set
        X_test_corrupted = simulate_hardware_issues(X_test)
        corrupted_results, corrupted_cm = evaluate_model(model, X_test_corrupted, y_test, dataset['num_classes'])
        save_results(corrupted_results, dataset['name'], experiment_type, train_condition, "hw_corrupted")
        save_confusion_matrix(corrupted_cm, dataset['name'], f"{experiment_type}_train_{train_condition}_test_hw_corrupted")

        # Ensure all required metrics are present
        for metric in ['accuracy', 'precision_weighted', 'recall_weighted', 'f1_weighted', 'sensitivity_unweighted', 'specificity_unweighted']:
            if metric not in normal_results:
                normal_results[metric] = None
                logging.warning(f"Metric {metric} not found in normal results for {dataset['name']} {experiment_type}")
            if metric not in corrupted_results:
                corrupted_results[metric] = None
                logging.warning(f"Metric {metric} not found in corrupted results for {dataset['name']} {experiment_type}")

        return {"normal": normal_results, "corrupted": corrupted_results}

    except Exception as e:
        logging.error(f"Error in run_experiment for {dataset['name']} {experiment_type}: {str(e)}")
        return {"normal": {}, "corrupted": {}}

def main():
    results_dict = {}
    for dataset in datasets:
        results_dict[dataset['name']] = {}
        for filters in filter_combinations:
            logging.info(f"Processing dataset: {dataset['name']} with filters: {filters}")

            try:
                # Base experiment
                base_results, _ = run_experiment(dataset, create_SERN_AwGOP, filters, "base")
                results_dict[dataset['name']]["base"] = base_results

                # Augmented data experiment
                aug_results, _ = run_experiment(dataset, create_SERN_AwGOP, filters, "augmented", use_augmentation=True)
                results_dict[dataset['name']]["augmented"] = aug_results

                # Single electrode experiments
                for i in range(12):
                    single_results, _ = run_experiment(dataset, create_SERN_AwGOP, filters, f"single_electrode_{i}", single_electrode=i)
                    results_dict[dataset['name']][f"single_electrode_{i}"] = single_results

                # Pretrained model experiments
                for pretrain_dataset in datasets:
                    if pretrain_dataset['name'] != dataset['name']:
                        pretrain_results, _ = run_experiment(dataset, create_SERN_AwGOP, filters, f"pretrained_{pretrain_dataset['name']}", pretrain_dataset=pretrain_dataset)
                        results_dict[dataset['name']][f"pretrained_{pretrain_dataset['name']}"] = pretrain_results

                # 2. Data Efficiency Analysis
                efficiency_results = data_efficiency_analysis(dataset, create_SERN_AwGOP, filters)
                results_dict[dataset['name']]['data_efficiency'] = efficiency_results

                # 3. Longitudinal Analysis
                longitudinal_results = longitudinal_analysis(dataset, create_SERN_AwGOP, filters)
                results_dict[dataset['name']]['longitudinal'] = longitudinal_results

                # 4. Time Efficiency Analysis
                time_results = time_efficiency_analysis(dataset, create_SERN_AwGOP, filters)
                results_dict[dataset['name']]['time_efficiency'] = time_results

                # 5. Cross-dataset Validation
                for test_dataset in datasets:
                    if test_dataset['name'] != dataset['name']:
                        cross_results = cross_dataset_validation(dataset, test_dataset, create_SERN_AwGOP, filters)
                        results_dict[dataset['name']][f'cross_validation_{test_dataset["name"]}'] = cross_results

                # 6. Robustness Testing
                robustness_results = robustness_testing(dataset, create_SERN_AwGOP, filters)
                results_dict[dataset['name']]['robustness'] = robustness_results

            except Exception as e:
                logging.error(f"Error processing {dataset['name']}: {str(e)}")
                continue

            # Save results after each dataset
            save_results(results_dict[dataset['name']], dataset['name'], "all_experiments", "various", "various")

    logging.info("All experiments and analyses completed successfully.")

if __name__ == "__main__":
    main()

2024-08-22 17:29:24.917458: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-22 17:29:30,742 - INFO - Processing dataset: CUSPH-SR-AF with filters: [32, 64, 128]


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100


2024-08-23 17:37:10,757 - INFO - Processing dataset: CSPC18-SR-AF with filters: [32, 64, 128]
2024-08-23 17:37:12,571 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100


2024-08-23 17:57:54,982 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100


2024-08-23 20:29:54,453 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 20:44:43,646 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100


  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 21:11:43,937 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 21:31:11,103 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 21:45:12,368 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 22:01:23,005 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 22:12:02,544 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 22:24:18,599 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 22:51:45,859 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 23:17:33,173 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 23:33:06,198 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-23 23:58:39,874 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100


  _warn_prf(average, modifier, msg_start, len(result))




  _warn_prf(average, modifier, msg_start, len(result))
2024-08-24 00:10:50,480 - INFO - Truncated CSPC18-SR-AF data to shape (2016, 5000, 12)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epo



Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100




Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100


2024-08-24 17:31:15,171 - INFO - All experiments and analyses completed successfully.


In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
from tensorflow.keras import layers, models, callbacks
import json
import logging
import seaborn as sns
from scipy.signal import resample

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Early stopping callback
early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Dataset definitions
datasets = [
    {"name": "CUSPH-SR-AF", "data": "data/data_2class_normal.npy", "labels": "data/labels_2class_normal.npy", "num_classes": 2, "input_shape": (5000, 12)},
    {"name": "CSPC18-SR-AF", "data": "data/data_2class_cpsc18.npy", "labels": "data/labels_2class_cpsc18.npy", "num_classes": 2, "input_shape": (15000, 12)},
]

filter_combinations = [[32, 64, 128]]

def save_results(results, dataset_name, experiment_type, train_condition, test_condition):
    filename = f"robustness/results/{dataset_name}_{experiment_type}_train_{train_condition}_test_{test_condition}.json"
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    
    with open(filename, 'w') as f:
        json.dump(results, f, indent=2)

def save_confusion_matrix(cm, dataset_name, experiment_name):
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'Confusion Matrix - {dataset_name} - {experiment_name}')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(f"robustness/results/{dataset_name}_{experiment_name}_confusion_matrix.png")
    plt.close()

def downsample_block(x, filters):
    x = layers.Conv1D(filters // 2, 1, strides=1, padding='same')(x)
    x = mixed_pool_operator(x)
    return x

def branched_nodal_operator(x, filters, kernel_size=5, activation='relu'):
    y1 = layers.Conv1D(filters // 2, kernel_size, dilation_rate=2, padding='same')(x)
    y1 = layers.BatchNormalization()(y1)
    y1 = layers.Activation(activation)(y1)

    y2 = layers.SeparableConv1D(filters // 2, kernel_size, padding='same')(x)
    y2 = layers.BatchNormalization()(y2)
    y2 = layers.Activation(activation)(y2)

    y = layers.Concatenate()([y1, y2])
    return y

def mixed_pool_operator(x, pool_size=2, strides=1):
    y1 = layers.AveragePooling1D(pool_size, strides, padding='same')(x)
    y2 = layers.MaxPooling1D(pool_size, strides, padding='same')(x)
    y = layers.Concatenate()([y1, y2])
    return y

def squeeze_and_excitation_block(x, ratio=16):
    num_channels = x.shape[-1]
    squeeze = layers.GlobalAveragePooling1D()(x)
    excitation = layers.Dense(num_channels // ratio, activation='relu')(squeeze)
    excitation = layers.Dense(num_channels, activation='sigmoid')(excitation)
    excitation = layers.Reshape((1, num_channels))(excitation)
    scale = layers.Multiply()([x, excitation])
    return scale

def residual_block_SERN_AwGOP(x, filters, kernel_size=5, downsample=False):
    y = branched_nodal_operator(x, filters, kernel_size)
    y = branched_nodal_operator(y, filters, kernel_size)

    if downsample:
        x = downsample_block(x, filters)

    y = squeeze_and_excitation_block(y)

    attention_weights = layers.Dense(1, activation='sigmoid')(x)
    gop_out = layers.Multiply()([attention_weights, y])
    gop_out = layers.Add()([gop_out, x])
    gop_out = layers.Activation('relu')(gop_out)
    return gop_out

def create_SERN_AwGOP(input_shape, num_classes, filters):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(filters[0], 5, strides=2, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)

    for f in filters[1:]:
        x = residual_block_SERN_AwGOP(x, f, downsample=True)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)

    x = layers.GlobalAveragePooling1D()(x)
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid')(x)
    else:
        outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    
    # Compile the model
    if num_classes == 2:
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    else:
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    return model

def train_model(model, X_train, y_train, X_val, y_val):
    try:
        history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val), callbacks=[early_stopping])
        return history
    except Exception as e:
        logging.error(f"Error in train_model: {str(e)}")
        return None

def evaluate_model(model, X_test, y_test, num_classes):
    try:
        y_pred = model.predict(X_test)
        y_pred_classes = np.argmax(y_pred, axis=1) if num_classes > 2 else (y_pred > 0.5).astype(int).flatten()

        results = {
            'accuracy': float(accuracy_score(y_test, y_pred_classes)),
            'precision_weighted': float(precision_score(y_test, y_pred_classes, average='weighted')),
            'recall_weighted': float(recall_score(y_test, y_pred_classes, average='weighted')),
            'f1_weighted': float(f1_score(y_test, y_pred_classes, average='weighted')),
        }

        cm = confusion_matrix(y_test, y_pred_classes)
        sensitivity = cm[1,1] / (cm[1,1] + cm[1,0])
        specificity = cm[0,0] / (cm[0,0] + cm[0,1])
        results['sensitivity_unweighted'] = float(sensitivity)
        results['specificity_unweighted'] = float(specificity)

        if num_classes == 2:
            results['auc_roc'] = float(roc_auc_score(y_test, y_pred))
            results['average_precision'] = float(average_precision_score(y_test, y_pred))

        return results, cm
    except Exception as e:
        logging.error(f"Error in evaluate_model: {str(e)}")
        return {}, None

def data_efficiency_analysis(dataset, model_fn, filters, fractions=[0.1, 0.25, 0.5, 0.75, 1.0]):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

        results = {}
        for fraction in fractions:
            n_samples = int(len(X_train) * fraction)
            X_train_subset = X_train[:n_samples]
            y_train_subset = y_train[:n_samples]

            model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
            history = train_model(model, X_train_subset, y_train_subset, X_test, y_test)
            test_results, cm = evaluate_model(model, X_test, y_test, dataset['num_classes'])
            results[fraction] = test_results
            
            # Save confusion matrix for each fraction
            save_confusion_matrix(cm, dataset['name'], f"data_efficiency_{fraction}")

        return results
    except Exception as e:
        logging.error(f"Error in data_efficiency_analysis: {str(e)}")
        return {}

def longitudinal_analysis(dataset, model_fn, filters, num_time_points=5):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])

        time_point_size = len(data) // num_time_points
        results = []

        for i in range(num_time_points):
            start_idx = i * time_point_size
            end_idx = (i + 1) * time_point_size
            X_train, X_test, y_train, y_test = train_test_split(data[start_idx:end_idx], labels[start_idx:end_idx], test_size=0.2, random_state=42)

            model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
            history = train_model(model, X_train, y_train, X_test, y_test)
            test_results, cm = evaluate_model(model, X_test, y_test, dataset['num_classes'])
            results.append(test_results)
            
            # Save confusion matrix for each time point
            save_confusion_matrix(cm, dataset['name'], f"longitudinal_timepoint_{i}")

        return results
    except Exception as e:
        logging.error(f"Error in longitudinal_analysis: {str(e)}")
        return []

def add_gaussian_noise(data, noise_factor=0.05):
    noise = np.random.normal(loc=0, scale=1, size=data.shape)
    return data + noise_factor * noise

def add_powerline_interference(data, frequency=50, amplitude=0.1):
    t = np.arange(data.shape[1]) / 1000  # Assume 1000 Hz sampling rate
    noise = amplitude * np.sin(2 * np.pi * frequency * t)
    return data + noise.reshape(1, -1, 1)

def add_electrode_motion_artifact(data, artifact_duration=100, amplitude=0.5):
    artifact = np.zeros(data.shape)
    for i in range(data.shape[0]):
        start = np.random.randint(0, data.shape[1] - artifact_duration)
        artifact[i, start:start+artifact_duration, :] = amplitude * np.random.randn(artifact_duration, data.shape[2])
    return data + artifact

def robustness_testing(dataset, model_fn, filters):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

        model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
        history = train_model(model, X_train, y_train, X_test, y_test)

        results = {}

        # Test on clean data
        clean_results, clean_cm = evaluate_model(model, X_test, y_test, dataset['num_classes'])
        results['clean'] = clean_results
        save_confusion_matrix(clean_cm, dataset['name'], "robustness_clean")

        # Test with Gaussian noise
        noisy_data = add_gaussian_noise(X_test)
        noise_results, noise_cm = evaluate_model(model, noisy_data, y_test, dataset['num_classes'])
        results['gaussian_noise'] = noise_results
        save_confusion_matrix(noise_cm, dataset['name'], "robustness_gaussian_noise")

        # Test with powerline interference
        powerline_data = add_powerline_interference(X_test)
        powerline_results, powerline_cm = evaluate_model(model, powerline_data, y_test, dataset['num_classes'])
        results['powerline_interference'] = powerline_results
        save_confusion_matrix(powerline_cm, dataset['name'], "robustness_powerline_interference")

        # Test with electrode motion artifact
        motion_data = add_electrode_motion_artifact(X_test)
        motion_results, motion_cm = evaluate_model(model, motion_data, y_test, dataset['num_classes'])
        results['electrode_motion'] = motion_results
        save_confusion_matrix(motion_cm, dataset['name'], "robustness_electrode_motion")

        return results
    except Exception as e:
        logging.error(f"Error in robustness_testing: {str(e)}")
        return {}

def cross_dataset_validation(train_dataset, test_dataset, model_fn, filters):
    try:
        train_data = np.load(train_dataset['data'])
        train_labels = np.load(train_dataset['labels'])
        test_data = np.load(test_dataset['data'])
        test_labels = np.load(test_dataset['labels'])

        # Ensure compatible shapes
        if train_data.shape[1:] != test_data.shape[1:]:
            test_data = resample(test_data, train_data.shape[1], axis=1)

        model = model_fn(train_dataset['input_shape'], train_dataset['num_classes'], filters)
        history = train_model(model, train_data, train_labels, test_data, test_labels)
        test_results, cm = evaluate_model(model, test_data, test_labels, test_dataset['num_classes'])
        
        # Save confusion matrix for cross-dataset validation
        save_confusion_matrix(cm, f"{train_dataset['name']}_{test_dataset['name']}", "cross_validation")

        return test_results
    except Exception as e:
        logging.error(f"Error in cross_dataset_validation: {str(e)}")
        return {}

def main():
    results_dict = {}
    for dataset in datasets:
        results_dict[dataset['name']] = {}
        for filters in filter_combinations:
            logging.info(f"Processing dataset: {dataset['name']} with filters: {filters}")

            try:
                # Data Efficiency Analysis
                efficiency_results = data_efficiency_analysis(dataset, create_SERN_AwGOP, filters)
                results_dict[dataset['name']]['data_efficiency'] = efficiency_results
                save_results(efficiency_results, dataset['name'], "data_efficiency", "various", "clean")

                # Longitudinal Analysis
                longitudinal_results = longitudinal_analysis(dataset, create_SERN_AwGOP, filters)
                results_dict[dataset['name']]['longitudinal'] = longitudinal_results
                save_results(longitudinal_results, dataset['name'], "longitudinal", "various", "clean")

                # Robustness Testing
                robustness_results = robustness_testing(dataset, create_SERN_AwGOP, filters)
                results_dict[dataset['name']]['robustness'] = robustness_results
                save_results(robustness_results, dataset['name'], "robustness", "original", "various")

                # Cross-dataset Validation
                for test_dataset in datasets:
                    if test_dataset['name'] != dataset['name']:
                        cross_results = cross_dataset_validation(dataset, test_dataset, create_SERN_AwGOP, filters)
                        results_dict[dataset['name']][f'cross_validation_{test_dataset["name"]}'] = cross_results
                        save_results(cross_results, f"{dataset['name']}_{test_dataset['name']}", "cross_validation", dataset['name'], test_dataset['name'])

            except Exception as e:
                logging.error(f"Error processing {dataset['name']}: {str(e)}")
                continue

    logging.info("All analyses completed successfully.")

if __name__ == "__main__":
    main()

2024-08-25 01:06:25.057510: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-25 01:06:30,206 - INFO - Processing dataset: CUSPH-SR-AF with filters: [32, 64, 128]


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epo

  _warn_prf(average, modifier, msg_start, len(result))
2024-08-25 07:21:30,300 - INFO - Processing dataset: CSPC18-SR-AF with filters: [32, 64, 128]


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch



Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100




Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100


2024-08-25 17:46:56,422 - INFO - All analyses completed successfully.


In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import resample

# Load data
data = np.load('data/data_2class_normal.npy')
labels = np.load('data/labels_2class_normal.npy')

# Select a random AF sample
af_indices = np.where(labels == 1)[0]
random_af_index = np.random.choice(af_indices)
sample = data[random_af_index]

# Ensure sample is (5000, 12)
if sample.shape[0] > 5000:
    sample = resample(sample, 5000, axis=0)

# Define transformation functions
def add_gaussian_noise(data, noise_factor=0.05):
    noise = np.random.normal(loc=0, scale=1, size=data.shape)
    return data + noise_factor * noise

def add_random_baseline_wander(data, max_amplitude=0.1):
    t = np.linspace(0, 1, data.shape[0])
    baseline = max_amplitude * np.sin(2 * np.pi * np.random.rand() * t)
    return data + baseline.reshape(-1, 1)

def time_warp(data, sigma=0.2, knot=4):
    orig_steps = np.arange(data.shape[0])
    random_warps = np.random.normal(loc=1.0, scale=sigma, size=(knot+2,))
    warp_steps = np.linspace(0, data.shape[0]-1, num=knot+2)
    warper = np.interp(orig_steps, warp_steps, random_warps)
    return data * warper.reshape(-1, 1)

def simulate_hardware_issues(data, corruption_rate=0.1):
    corrupted_data = data.copy()
    mask = np.random.choice([0, 1], size=data.shape, p=[1-corruption_rate, corruption_rate])
    corrupted_data[mask.astype(bool)] = np.random.normal(loc=0, scale=1, size=corrupted_data[mask.astype(bool)].shape)
    return corrupted_data

# Create figure
fig = plt.figure(figsize=(20, 15))

# First subplot: All 12 electrodes
ax1 = fig.add_subplot(221)
ax1.set_title("Original 12-lead ECG")
for i in range(12):
    ax1.plot(sample[:, i] + i*4, label=f'Lead {i+1}')
ax1.set_yticks(np.arange(0, 48, 4))
ax1.set_yticklabels([f'Lead {i+1}' for i in range(12)])
ax1.set_xlabel('Time')
ax1.set_ylabel('Amplitude')

# Second subplot: Transformations on specific electrodes
ax2 = fig.add_subplot(222)
ax2.set_title("Transformations on Specific Leads")

# Apply transformations
noisy = add_gaussian_noise(sample[:, 1])
baseline_wander = add_random_baseline_wander(sample[:, 3])
time_warped = time_warp(sample[:, 5])
all_combined = time_warp(add_random_baseline_wander(add_gaussian_noise(sample[:, 7])))

ax2.plot(sample[:, 1], label='Original (Lead 2)')
ax2.plot(noisy, label='Noisy (Lead 2)')
ax2.plot(baseline_wander, label='Baseline Wander (Lead 4)')
ax2.plot(time_warped, label='Time Warped (Lead 6)')
ax2.plot(all_combined, label='All Combined (Lead 8)')
ax2.legend()
ax2.set_xlabel('Time')
ax2.set_ylabel('Amplitude')

# Third subplot: Hardware corruption on all electrodes
ax3 = fig.add_subplot(223)
ax3.set_title("Hardware Corruption on All Leads")
corrupted = simulate_hardware_issues(sample)
for i in range(12):
    ax3.plot(corrupted[:, i] + i*4, label=f'Lead {i+1}')
ax3.set_yticks(np.arange(0, 48, 4))
ax3.set_yticklabels([f'Lead {i+1}' for i in range(12)])
ax3.set_xlabel('Time')
ax3.set_ylabel('Amplitude')

# Fourth subplot: Longitudinal segments
ax4 = fig.add_subplot(224)
ax4.set_title("Longitudinal Segments")
num_segments = 5
segment_length = sample.shape[0] // num_segments
for i in range(12):
    for j in range(num_segments):
        start = j * segment_length
        end = (j + 1) * segment_length
        ax4.plot(range(start, end), sample[start:end, i] + i*4, label=f'Lead {i+1}, Segment {j+1}' if j == 0 else "")
ax4.set_yticks(np.arange(0, 48, 4))
ax4.set_yticklabels([f'Lead {i+1}' for i in range(12)])
ax4.set_xlabel('Time')
ax4.set_ylabel('Amplitude')

plt.tight_layout()
plt.savefig('ecg_transformations_visualization.png', dpi=300, bbox_inches='tight')
plt.close()

print("Visualization saved as 'ecg_transformations_visualization.png'")

  plt.tight_layout()
  plt.tight_layout()
  plt.savefig('ecg_transformations_visualization.png', dpi=300, bbox_inches='tight')


ValueError: Image size of 4964x942782 pixels is too large. It must be less than 2^16 in each direction.

  func(*args, **kwargs)
  fig.canvas.print_figure(bytes_io, **kw)


ValueError: Image size of 1654x314261 pixels is too large. It must be less than 2^16 in each direction.

<Figure size 2000x1500 with 4 Axes>

In [5]:
import os
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import json
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)


def downsample_block(x, filters):
    x = layers.Conv1D(filters // 2, 1, strides=1, padding='same')(x)
    x = mixed_pool_operator(x)
    return x

def branched_nodal_operator(x, filters, kernel_size=5, activation='relu'):
    y1 = layers.Conv1D(filters // 2, kernel_size, dilation_rate=2, padding='same')(x)
    y1 = layers.BatchNormalization()(y1)
    y1 = layers.Activation(activation)(y1)

    y2 = layers.SeparableConv1D(filters // 2, kernel_size, padding='same')(x)
    y2 = layers.BatchNormalization()(y2)
    y2 = layers.Activation(activation)(y2)

    y = layers.Concatenate()([y1, y2])
    return y

def mixed_pool_operator(x, pool_size=2, strides=1):
    y1 = layers.AveragePooling1D(pool_size, strides, padding='same')(x)
    y2 = layers.MaxPooling1D(pool_size, strides, padding='same')(x)
    y = layers.Concatenate()([y1, y2])
    return y

def squeeze_and_excitation_block(x, ratio=16):
    num_channels = x.shape[-1]
    squeeze = layers.GlobalAveragePooling1D()(x)
    excitation = layers.Dense(num_channels // ratio, activation='relu')(squeeze)
    excitation = layers.Dense(num_channels, activation='sigmoid')(excitation)
    excitation = layers.Reshape((1, num_channels))(excitation)
    scale = layers.Multiply()([x, excitation])
    return scale

def residual_block_SERN_AwGOP(x, filters, kernel_size=5, downsample=False):
    y = branched_nodal_operator(x, filters, kernel_size)
    y = branched_nodal_operator(y, filters, kernel_size)

    if downsample:
        x = downsample_block(x, filters)

    y = squeeze_and_excitation_block(y)

    attention_weights = layers.Dense(1, activation='sigmoid')(x)
    gop_out = layers.Multiply()([attention_weights, y])
    gop_out = layers.Add()([gop_out, x])
    gop_out = layers.Activation('relu')(gop_out)
    return gop_out

def create_SERN_AwGOP(input_shape, num_classes, filters):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(filters[0], 5, strides=2, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)

    for f in filters[1:]:
        x = residual_block_SERN_AwGOP(x, f, downsample=True)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)

    x = layers.GlobalAveragePooling1D()(x)
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid')(x)
    else:
        outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    
    # Compile the model
    if num_classes == 2:
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    else:
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    return model

def train_model(model, X_train, y_train, X_val, y_val):
    try:
        history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val), callbacks=[early_stopping])
        return history
    except Exception as e:
        logging.error(f"Error in train_model: {str(e)}")
        return None


def evaluate_model(model, X_test, y_test, num_classes):
    try:
        y_pred = model.predict(X_test)
        y_pred_classes = np.argmax(y_pred, axis=1) if num_classes > 2 else (y_pred > 0.5).astype(int).flatten()

        results = {
            'accuracy': float(accuracy_score(y_test, y_pred_classes)),
            'precision_weighted': float(precision_score(y_test, y_pred_classes, average='weighted')),
            'recall_weighted': float(recall_score(y_test, y_pred_classes, average='weighted')),
            'f1_weighted': float(f1_score(y_test, y_pred_classes, average='weighted')),
        }

        cm = confusion_matrix(y_test, y_pred_classes)
        results['confusion_matrix'] = cm.tolist()  # Convert to list for JSON serialization
        
        if num_classes == 2:
            sensitivity = cm[1,1] / (cm[1,1] + cm[1,0])
            specificity = cm[0,0] / (cm[0,0] + cm[0,1])
            results['sensitivity_unweighted'] = float(sensitivity)
            results['specificity_unweighted'] = float(specificity)
            results['auc_roc'] = float(roc_auc_score(y_test, y_pred))
            results['average_precision'] = float(average_precision_score(y_test, y_pred))

        return results
    except Exception as e:
        logging.error(f"Error in evaluate_model: {str(e)}")
        return {}

def simulate_hardware_issues(data, corruption_level):
    corrupted_data = data.copy()
    corruption_rate = corruption_level * 0.05
    disconnection_rate = corruption_level * 0.02
    saturation_threshold = 7.0 / corruption_level

    # Simulate random noise (general corruption)
    noise_mask = np.random.choice([0, 1], size=data.shape, p=[1-corruption_rate, corruption_rate])
    corrupted_data[noise_mask.astype(bool)] = np.random.normal(loc=0, scale=0.5, size=corrupted_data[noise_mask.astype(bool)].shape)
    
    # Simulate electrode disconnections (set segments of channels to zero)
    for channel in range(corrupted_data.shape[2]):  # Assuming shape is (samples, time_steps, channels)
        if np.random.rand() < disconnection_rate:
            segment_length = np.random.randint(low=1, high=data.shape[1] // 4)  # Random segment length
            start = np.random.randint(low=0, high=data.shape[1] - segment_length)
            corrupted_data[:, start:start+segment_length, channel] = 0
    
    # Simulate amplifier saturation (clipping)
    corrupted_data = np.clip(corrupted_data, -saturation_threshold, saturation_threshold)
    
    return corrupted_data

def save_results(results, dataset_name, experiment_type):
    filename = f"results/{dataset_name}_{experiment_type}_hw_corruption.json"
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, 'w') as f:
        json.dump(results, f, indent=2)

def hardware_corruption_test(dataset, model_fn, filters):
    try:
        data = np.load(dataset['data'])
        labels = np.load(dataset['labels'])
        X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

        # Train on clean data
        model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
        history = train_model(model, X_train, y_train, X_test, y_test)

        results = {}

        # Test with different levels of hardware corruption
        corruption_levels = [1, 2, 3]  # Low, medium, high corruption
        for level in corruption_levels:
            corrupted_data = simulate_hardware_issues(X_test, level)
            hw_results = evaluate_model(model, corrupted_data, y_test, dataset['num_classes'])
            results[f'hardware_corruption_level_{level}'] = hw_results

        return results
    except Exception as e:
        logging.error(f"Error in hardware_corruption_test: {str(e)}")
        return {}

def main():
    datasets = [
        {"name": "CUSPH-SR-AF", "data": "data/data_2class_normal.npy", "labels": "data/labels_2class_normal.npy", "num_classes": 2, "input_shape": (5000, 12)},
    ]

    filter_combinations = [[32, 64, 128]]

    for dataset in datasets:
        for filters in filter_combinations:
            logging.info(f"Processing dataset: {dataset['name']} with filters: {filters}")

            try:
                # Hardware Corruption Test
                hw_results = hardware_corruption_test(dataset, create_SERN_AwGOP, filters)
                save_results(hw_results, dataset['name'], "hardware_corruption")

            except Exception as e:
                logging.error(f"Error processing {dataset['name']}: {str(e)}")
                continue

    logging.info("All hardware corruption tests completed successfully.")

if __name__ == "__main__":
    main()

2024-08-26 09:00:01,219 - INFO - Processing dataset: CUSPH-SR-AF with filters: [32, 64, 128]


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100


2024-08-26 10:19:19,140 - INFO - All hardware corruption tests completed successfully.
