In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.models import Model
import tensorflow_model_optimization as tfmot
import json
import time
import logging
from tqdm import tqdm
import seaborn as sns

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Early stopping callback
early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Dataset definitions
# datasets = [
#     {"name": "CUSPH-AF-AFL", "data": "data/data_2class_arrhythmia.npy", "labels": "data/labels_2class_arrhythmia.npy", "num_classes": 2, "input_shape": (5000, 12)},
#     {"name": "CUSPH-SR-AF+AFL", "data": "data/data_2class_normal_combined.npy", "labels": "data/labels_2class_normal_combined.npy", "num_classes": 2, "input_shape": (5000, 12)},
#     {"name": "CUSPH-ALL", "data": "data/data_11class.npy", "labels": "data/labels_11class.npy", "num_classes": 11, "input_shape": (5000, 12)},
#     {"name": "CSPC18-SR-AF", "data": "data/data_2class_cpsc18.npy", "labels": "data/labels_2class_cpsc18.npy", "num_classes": 2, "input_shape": (15000, 12)},
#     {"name": "CUSPH-SR-AF", "data": "data/data_2class_normal.npy", "labels": "data/labels_2class_normal.npy", "num_classes": 2, "input_shape": (5000, 12)}
# ]

datasets = [
#     {"name": "Georgia-ALL", "data": "data/data_singleclass_georgia.npy", "labels": "data/labels_singleclass_georgia.npy", "num_classes": 56, "input_shape": (5000, 12)},
    {"name": "CPSC18-ALL", "data": "data/data_multiclass_cpsc18.npy", "labels": "data/labels_multiclass_cpsc18.npy", "num_classes": 9, "input_shape": (15000, 12)},
]

# datasets = [
#     {"name": "Georgia-ALL", "data": "data/data_singleclass_georgia.npy", "labels": "data/labels_singleclass_georgia.npy", "num_classes": 56, "input_shape": (5000, 12)},
# ]

filter_combinations = [[32, 64, 128]]

# def save_results(results, dataset_name, fold=None):
#     filename = f"results/{dataset_name}_results{'_fold_' + str(fold) if fold is not None else ''}.json"
#     os.makedirs(os.path.dirname(filename), exist_ok=True)
#     with open(filename, 'w') as f:
#         json.dump(results, f, indent=2)

import json

def convert_to_serializable(obj):
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    else:
        return obj
    
def save_results(results, filename):
    serializable_results = convert_to_serializable(results)
    with open(f"{filename}.json", 'w') as f:
        json.dump(serializable_results, f, indent=2)

def save_confusion_matrix(cm, dataset_name, fold=None):
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'Confusion Matrix - {dataset_name}')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(f"results/{dataset_name}_confusion_matrix{'_fold_' + str(fold) if fold is not None else ''}.png")
    plt.close()

def calculate_dataset_stats(data, labels):
    stats = {
        'overall': {
            'mean': np.mean(data),
            'std': np.std(data),
            'min': np.min(data),
            'max': np.max(data),
        },
        'channel_wise': [],
        'label_distribution': np.bincount(labels.astype(int)).tolist()
    }
    
    for i in range(data.shape[-1]):
        channel_data = data[:, :, i]
        channel_stats = {
            'mean': np.mean(channel_data),
            'std': np.std(channel_data),
            'min': np.min(channel_data),
            'max': np.max(channel_data),
        }
        stats['channel_wise'].append(channel_stats)
    
    return stats

def downsample_block(x, filters):
    x = layers.Conv1D(filters // 2, 1, strides=1, padding='same')(x)
    x = mixed_pool_operator(x)
    return x

def branched_nodal_operator(x, filters, kernel_size=5, activation='relu'):
    y1 = layers.Conv1D(filters // 2, kernel_size, dilation_rate=2, padding='same')(x)
    y1 = layers.BatchNormalization()(y1)
    y1 = layers.Activation(activation)(y1)

    y2 = layers.SeparableConv1D(filters // 2, kernel_size, padding='same')(x)
    y2 = layers.BatchNormalization()(y2)
    y2 = layers.Activation(activation)(y2)

    y = layers.Concatenate()([y1, y2])
    return y

def mixed_pool_operator(x, pool_size=2, strides=1):
    y1 = layers.AveragePooling1D(pool_size, strides, padding='same')(x)
    y2 = layers.MaxPooling1D(pool_size, strides, padding='same')(x)
    y = layers.Concatenate()([y1, y2])
    return y

def squeeze_and_excitation_block(x, ratio=16):
    num_channels = x.shape[-1]
    squeeze = layers.GlobalAveragePooling1D()(x)
    excitation = layers.Dense(num_channels // ratio, activation='relu')(squeeze)
    excitation = layers.Dense(num_channels, activation='sigmoid')(excitation)
    excitation = layers.Reshape((1, num_channels))(excitation)
    scale = layers.Multiply()([x, excitation])
    return scale

def residual_block_SERN_AwGOP(x, filters, kernel_size=5, downsample=False):
    y = branched_nodal_operator(x, filters, kernel_size)
    y = branched_nodal_operator(y, filters, kernel_size)

    if downsample:
        x = downsample_block(x, filters)

    y = squeeze_and_excitation_block(y)

    attention_weights = layers.Dense(1, activation='sigmoid')(x)
    gop_out = layers.Multiply()([attention_weights, y])
    gop_out = layers.Add()([gop_out, x])
    gop_out = layers.Activation('relu')(gop_out)
    return gop_out

def create_SERN_AwGOP(input_shape, num_classes, filters):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(filters[0], 5, strides=2, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)

    for f in filters[1:]:
        x = residual_block_SERN_AwGOP(x, f, downsample=True)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)
        x = residual_block_SERN_AwGOP(x, f)

    x = layers.GlobalAveragePooling1D()(x)
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid')(x)
    else:
        outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model


def create_model_without_se(input_shape, num_classes, filters):
    # Similar to create_SERN_AwGOP but without squeeze_and_excitation_block
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(filters[0], 5, strides=2, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)

    for f in filters[1:]:
        x = residual_block_without_se(x, f, downsample=True)
        x = residual_block_without_se(x, f)
        x = residual_block_without_se(x, f)
        x = residual_block_without_se(x, f)

    x = layers.GlobalAveragePooling1D()(x)
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid')(x)
    else:
        outputs = layers.Dense(num_classes, activation='softmax')(x)
    return models.Model(inputs, outputs)

def residual_block_without_se(x, filters, kernel_size=5, downsample=False):
    y = branched_nodal_operator(x, filters, kernel_size)
    y = branched_nodal_operator(y, filters, kernel_size)

    if downsample:
        x = downsample_block(x, filters)

    attention_weights = layers.Dense(1, activation='sigmoid')(x)
    gop_out = layers.Multiply()([attention_weights, y])
    gop_out = layers.Add()([gop_out, x])
    gop_out = layers.Activation('relu')(gop_out)
    return gop_out

def create_model_without_attention(input_shape, num_classes, filters):
    # Similar to create_SERN_AwGOP but without attention mechanism
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(filters[0], 5, strides=2, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)

    for f in filters[1:]:
        x = residual_block_without_attention(x, f, downsample=True)
        x = residual_block_without_attention(x, f)
        x = residual_block_without_attention(x, f)
        x = residual_block_without_attention(x, f)

    x = layers.GlobalAveragePooling1D()(x)
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid')(x)
    else:
        outputs = layers.Dense(num_classes, activation='softmax')(x)
    return models.Model(inputs, outputs)

def residual_block_without_attention(x, filters, kernel_size=5, downsample=False):
    y = branched_nodal_operator(x, filters, kernel_size)
    y = branched_nodal_operator(y, filters, kernel_size)

    if downsample:
        x = downsample_block(x, filters)

    y = squeeze_and_excitation_block(y)
    y = layers.Add()([y, x])
    y = layers.Activation('relu')(y)
    return y

def create_model_simple_conv(input_shape, num_classes, filters):
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv1D(filters[0], 5, strides=2, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(x)

    for f in filters[1:]:
        x = layers.Conv1D(f, 3, padding='same', activation='relu')(x)
        x = layers.MaxPooling1D(pool_size=2, strides=2, padding='same')(x)

    x = layers.GlobalAveragePooling1D()(x)
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid')(x)
    else:
        outputs = layers.Dense(num_classes, activation='softmax')(x)
    return models.Model(inputs, outputs)


def analyze_computational_complexity(model):
    total_params = model.count_params()
    total_flops = 0

    for layer in model.layers:
        if isinstance(layer, layers.Conv1D):
            output_shape = layer.output_shape
            kernel_size = layer.kernel_size[0]
            input_channels = layer.input_shape[-1]
            output_channels = layer.filters
            flops = output_shape[1] * output_shape[2] * kernel_size * input_channels * output_channels
            total_flops += flops
        elif isinstance(layer, layers.Dense):
            input_shape = layer.input_shape
            output_shape = layer.output_shape
            flops = input_shape[-1] * output_shape[-1]
            total_flops += flops

    return {
        'total_params': total_params,
        'total_flops': total_flops
    }

def train_model(model, X_train, y_train, X_val, y_val):
    if model.output_shape[-1] == 1:  # Binary classification
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    else:  # Multi-class classification
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_data=(X_val, y_val), callbacks=[early_stopping])
    return history

def evaluate_model(model, X_test, y_test, num_classes):
    y_pred = model.predict(X_test)
    y_pred_classes = np.argmax(y_pred, axis=1) if num_classes > 2 else (y_pred > 0.5).astype(int).flatten()

    results = {
        'accuracy': accuracy_score(y_test, y_pred_classes),
        'precision': precision_score(y_test, y_pred_classes, average='weighted'),
        'recall': recall_score(y_test, y_pred_classes, average='weighted'),
        'f1': f1_score(y_test, y_pred_classes, average='weighted'),
    }

    if num_classes == 2:
        results['auc_roc'] = roc_auc_score(y_test, y_pred)
        results['average_precision'] = average_precision_score(y_test, y_pred)

    cm = confusion_matrix(y_test, y_pred_classes)

    return results, cm

def perform_cross_validation(dataset, model_fn, filters):
    data = np.load(dataset['data'])
    labels = np.load(dataset['labels'])
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    # Calculate and save overall dataset statistics
    overall_stats = calculate_dataset_stats(data, labels)
    save_results(overall_stats, f"{dataset['name']}_overall_stats")

    for fold, (train_index, val_index) in enumerate(kf.split(data)):
        logging.info(f"Processing fold {fold + 1} for dataset {dataset['name']}")

        X_train, X_val = data[train_index], data[val_index]
        y_train, y_val = labels[train_index], labels[val_index]

        # Calculate and save train/val set statistics
        train_stats = calculate_dataset_stats(X_train, y_train)
        val_stats = calculate_dataset_stats(X_val, y_val)
        save_results(train_stats, f"{dataset['name']}_train_stats_fold_{fold}")
        save_results(val_stats, f"{dataset['name']}_val_stats_fold_{fold}")

        model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
        history = train_model(model, X_train, y_train, X_val, y_val)

        results, cm = evaluate_model(model, X_val, y_val, dataset['num_classes'])
        fold_results.append(results)

        save_results(results, dataset['name'], fold)
        save_confusion_matrix(cm, dataset['name'], fold)

        complexity_results = analyze_computational_complexity(model)
        save_results(complexity_results, f"{dataset['name']}_complexity", fold)

    return calculate_average_results(fold_results)

def calculate_average_results(fold_results):
    avg_results = {}
    for key in fold_results[0].keys():
        avg_results[key] = np.mean([fold[key] for fold in fold_results])
        avg_results[f'{key}_std'] = np.std([fold[key] for fold in fold_results])
    return avg_results

def analyze_model_components(dataset, model_fn, filters):
    data = np.load(dataset['data'])
    labels = np.load(dataset['labels'])
    X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

    base_model = model_fn(dataset['input_shape'], dataset['num_classes'], filters)
    base_results, _ = evaluate_model(base_model, X_test, y_test, dataset['num_classes'])

    component_results = {'base_model': base_results}

    # Analyze without SE blocks
    model_without_se = create_model_without_se(dataset['input_shape'], dataset['num_classes'], filters)
    results_without_se, _ = evaluate_model(model_without_se, X_test, y_test, dataset['num_classes'])
    component_results['without_se'] = results_without_se

    # Analyze without attention mechanism
    model_without_attention = create_model_without_attention(dataset['input_shape'], dataset['num_classes'], filters)
    results_without_attention, _ = evaluate_model(model_without_attention, X_test, y_test, dataset['num_classes'])
    component_results['without_attention'] = results_without_attention

    # Analyze with simpler convolution blocks
    model_simple_conv = create_model_simple_conv(dataset['input_shape'], dataset['num_classes'], filters)
    results_simple_conv, _ = evaluate_model(model_simple_conv, X_test, y_test, dataset['num_classes'])
    component_results['simple_conv'] = results_simple_conv

    save_results(component_results, f"{dataset['name']}_component_analysis")
    return component_results

def main():
    for dataset in datasets:
        logging.info(f"Processing dataset: {dataset['name']}")
        
        for filters in filter_combinations:
            logging.info(f"Using filter combination: {filters}")
            
            try:
                # Perform cross-validation
                avg_results = perform_cross_validation(dataset, create_SERN_AwGOP, filters)
                save_results(avg_results, f"{dataset['name']}_average_results")
                
                # Analyze model components
                component_results = analyze_model_components(dataset, create_SERN_AwGOP, filters)
                logging.info(f"Component analysis results: {component_results}")
                
            except Exception as e:
                logging.error(f"An error occurred while processing {dataset['name']}: {str(e)}")
                continue

if __name__ == "__main__":
    main()

2024-08-25 18:38:55.354808: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-25 18:39:00,900 - INFO - Processing dataset: CPSC18-ALL
2024-08-25 18:39:00,901 - INFO - Using filter combination: [32, 64, 128]
2024-08-25 18:41:12,876 - INFO - Processing fold 1 for dataset CPSC18-ALL


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100


2024-08-26 02:04:52,413 - ERROR - An error occurred while processing CPSC18-ALL: save_results() takes 2 positional arguments but 3 were given
