# Medical Data Analysis with Deep Learning

This notebook demonstrates deep learning for medical data analysis. It supports **medical images**, **ECG signals**, **voice recordings**, and **tabular data**, for both classification and regression tasks.

## Pipeline Overview:
1. **Data preprocessing** (image enhancement or ECG signal processing)
2. **Dataset creation and data loading** (with proper train/val/test splits)
3. **Model architecture definition** (2D CNN for images, 1D CNN for ECG and voice, fully connected network for tabular data)
4. **Hyperparameter tuning** (grid search optimization)
5. **Model training and evaluation** (with early stopping and metrics)
6. **Results visualization** (training curves, confusion matrices, etc.)

## Key Features:
- **Unified interface** for image, ECG, voice, and tabular data
- **Modular design** with separate preprocessing, dataset, and model modules
- **Robust ECG processing** including PhysioNet format support
- **Comprehensive evaluation** with classification and regression metrics
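Among the classification metrics is ROC AUC (`auc` in the logs). To make that number concrete, here is a minimal sketch of binary ROC AUC computed as the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counting as half. The function name `roc_auc` is illustrative and not part of this project's modules; the quadratic pairwise loop favours clarity over speed.

```python
def roc_auc(labels, scores):
    """Binary ROC AUC via pairwise comparison (Mann-Whitney U formulation).

    labels: 0/1 ground-truth labels
    scores: predicted scores for the positive class, same length as labels
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0  # positive ranked above negative
            elif p == n:
                wins += 0.5  # a tie counts as half a win
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

An AUC of 0.5 corresponds to random ranking and 1.0 to perfectly separating the two classes.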

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import json

# Import preprocessing functions (image, ECG, and voice)
from preprocessing import (
    process_single_image, visualize_preprocessing, process_all_images,
    # ECG processing functions
    read_physionet_data, process_ecg_signal, process_all_ecg_signals, preprocess_jiachen_files,
    visualize_ecg_preprocessing,
    # voice processing functions
    read_voice_data, read_voice_metadata, preprocess_voice_signal,
    process_voice_file, process_all_voice_signals, create_voice_labels_file,
    visualize_voice_preprocessing
)

# Import dataset functions
from dataset import (
    MedicalImageDataset, ECGDataset, 
    create_data_loaders, create_ecg_data_loaders, create_voice_data_loaders,
    create_asd_data_loaders, ASDDataset  # Added ASD dataset support
)

# Import model classes
from model import MedicalCNN, ECG1DCNN, ModelTrainer, Voice1DCNN, ASDTabularModel  # Added ASD model

# Import hyperparameter tuning
from hyperparameter_tuning import HyperparameterTuner

## Configuration

Set up the configuration for the experiment. You can modify these parameters to experiment with different settings.
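Some combinations of these settings silently have no effect. For example, with `num_epochs = 10` and `early_stopping_patience = 15`, early stopping can never trigger, because at most 9 non-improving epochs can follow the first one. A quick sanity check can catch issues like this before a long run. The sketch below is a hypothetical helper (`check_config` is not part of the project's modules) and only checks a few of the keys:

```python
import math

VALID_DATA_TYPES = {'image', 'ECG', 'voice', 'tabular'}
VALID_TASK_TYPES = {'classification', 'regression'}


def check_config(cfg):
    """Return a list of human-readable problems found in an experiment config dict."""
    problems = []
    if cfg.get('data_type') not in VALID_DATA_TYPES:
        problems.append(f"unknown data_type: {cfg.get('data_type')!r}")
    if cfg.get('task_type') not in VALID_TASK_TYPES:
        problems.append(f"unknown task_type: {cfg.get('task_type')!r}")
    total = cfg['train_ratio'] + cfg['val_ratio'] + cfg['test_ratio']
    if not math.isclose(total, 1.0):
        problems.append(f"split ratios sum to {total}, expected 1.0")
    # The first epoch always improves on an initial best of +inf, so at most
    # num_epochs - 1 non-improving epochs can follow it.
    if cfg['early_stopping_patience'] >= cfg['num_epochs']:
        problems.append("early_stopping_patience >= num_epochs: early stopping can never trigger")
    return problems


example = {
    'data_type': 'tabular', 'task_type': 'classification',
    'train_ratio': 0.7, 'val_ratio': 0.15, 'test_ratio': 0.15,
    'num_epochs': 10, 'early_stopping_patience': 15,
}
for problem in check_config(example):
    print("Config warning:", problem)
```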

In [9]:
student_name = '' # don't change this

# Experiment configuration
config = {
    'data_type': 'tabular',  # 'image', 'voice', 'ECG', or 'tabular'; selects the data processing pipeline
    'task_type': 'classification',  # 'classification' or 'regression'
    'num_classes': 2,  # for classification only
    'image_dir': f'./data/{student_name}after_processed',
    'labels_file': f'./data/{student_name}labels.csv',
    'batch_size': 32,
    'num_epochs': 10,
    'learning_rate': 0.001,
    'train_ratio': 0.7,
    'val_ratio': 0.15,
    'test_ratio': 0.15,
    'random_seed': 42,
    'early_stopping_patience': 15,  # epochs without validation improvement before stopping; no effect when >= num_epochs
    'save_dir': './results',
    
    # ECG-specific parameters
    'ecg_max_length': 500,  # Target sequence length for ECG signals (uniform sampling)
    
    # voice-specific parameters
    'voice_max_length': 500,  # Target sequence length for voice signals, in samples (5 s at 8 kHz would be 40,000 samples)
    'target_variable': 'Voice Handicap Index (VHI) Score',  # 'VHI Score', 'RSI Score', 'Diagnosis'
    
    # Grid search parameters
    'grid_search': {
        'num_conv_layers': [3],
        'conv_channels': [64],
        'fc_layers': [[128, 64], [128, 32]],  # also used as hidden_layers for tabular data; other option: [1024, 256, 64]
        'learning_rate': [0.01]  # other option: 0.0001
    }
}

# Directly set the class names according to your data
class_nms = {
    0: 'Normal',
    1: 'Abnormal'
}
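The `fc_layers` grid (also used as `hidden_layers` for the tabular model) largely determines model size. Assuming the tabular model is a plain stack of fully connected layers with biases and one output logit per class (the actual `ASDTabularModel` may add normalization or dropout layers, which this sketch ignores), its parameter count can be estimated from the 18 input features reported by the data loader:

```python
def mlp_param_count(input_dim, hidden_layers, output_dim):
    """Count weights + biases of a plain fully connected network."""
    sizes = [input_dim, *hidden_layers, output_dim]
    # Each Linear(n_in, n_out) layer has n_in * n_out weights and n_out biases
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


# 18 tabular features in, 2 output logits for num_classes = 2
for hidden in ([128, 64], [128, 32]):
    print(hidden, mlp_param_count(18, hidden, 2))
# [128, 64] 10818
# [128, 32] 6626
```

Both candidates are small relative to the 560 training samples only in absolute terms; roughly 10k parameters for 560 rows is still ample capacity to overfit, which is one reason to watch the gap between training and validation metrics.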

## Data Preprocessing

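Preprocess the data according to `config['data_type']`: image enhancement for images, signal processing for ECG and voice recordings, and a quick inspection of the CSV file for tabular data.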
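For the 1D signal pipelines, `ecg_max_length` (and likewise `voice_max_length`) fixes the number of samples per recording, which the configuration describes as uniform sampling. How `process_ecg_signal` actually does this is not shown here. The sketch below assumes simple linear interpolation onto `target_length` evenly spaced points, which keeps the first and last samples and stretches or compresses everything in between; `resample_linear` is a hypothetical name.

```python
def resample_linear(signal, target_length):
    """Resample a 1D sequence to target_length evenly spaced points by linear interpolation."""
    if len(signal) < 2 or target_length < 2:
        raise ValueError("need at least 2 input samples and target_length >= 2")
    step = (len(signal) - 1) / (target_length - 1)  # input-index distance between output points
    out = []
    for i in range(target_length):
        x = i * step  # fractional position in the input signal
        left = int(x)
        right = min(left + 1, len(signal) - 1)
        frac = x - left
        out.append(signal[left] * (1 - frac) + signal[right] * frac)
    return out


print(resample_linear([0.0, 10.0], 5))  # [0.0, 2.5, 5.0, 7.5, 10.0]
```

Unlike truncation or zero-padding, resampling keeps the whole recording but changes its effective sampling rate, so recordings of different durations end up on different time scales.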

In [3]:
if config['data_type'] == 'image':
    # Visualize preprocessing on a sample image
    ori_image_dir = f'./data/{student_name}ori_images'
    sample_image_path = list(Path(ori_image_dir).glob('*.png'))[0]
    visualize_preprocessing(sample_image_path)
    
elif config['data_type'] == 'ECG':
    label_df = pd.read_csv(config['labels_file'])
    label_df.columns = [x.strip() for x in label_df.columns]
    label_df['y'] = label_df['cause of death'].astype(int)
    # label_df['id'] = label_df['id'].apply(lambda x: x.replace('Jiachen_', 'P'))
    label_df.to_csv(config['labels_file'], index=False)
    
    # ECG data directories
    ecg_raw_dir = f'./data/{student_name}ori_images'
    ecg_processed_dir = f'./data/{student_name}after_processed'
    
    # Handle Jiachen's special file naming
    preprocess_jiachen_files(ecg_raw_dir, config['labels_file'])
    
    # Process all ECG files
    processed_count = process_all_ecg_signals(
        ecg_raw_dir, 
        ecg_processed_dir, 
        config['ecg_max_length']
    )
    
    # Visualize preprocessing on a sample ECG
    sample_files = list(Path(ecg_raw_dir).glob('*.dat'))
    if sample_files:
        sample_dat = sample_files[0]
        sample_hea = sample_dat.with_suffix('.hea')  # matching PhysioNet header file
        if sample_hea.exists():
            visualize_ecg_preprocessing(sample_dat, sample_hea, config['ecg_max_length'])
        else:
            print(f"Warning: Header file {sample_hea} not found")
    else:
        print("No ECG files found for visualization")

elif config['data_type'] == 'voice':
    # voice data directories
    voice_raw_dir = f'./data/{student_name}ori_images'
    voice_processed_dir = f'./data/{student_name}after_processed'
    
    # Process all voice files
    processed_count = process_all_voice_signals(
        voice_raw_dir, 
        voice_processed_dir, 
        config['voice_max_length']
    )

    # Create labels file
    labels_df = create_voice_labels_file(
        voice_raw_dir, 
        config['labels_file'], 
        config['target_variable']
    )
    print(labels_df.head())
    labels_df.to_csv(config['labels_file'], index=False)

    # Visualize preprocessing on a sample voice signal
    sample_files = list(Path(voice_raw_dir).glob('*.dat'))
    if sample_files:
        sample_dat = sample_files[0]
        sample_hea = sample_dat.with_suffix('.hea')
        sample_txt = sample_dat.with_suffix('.txt')
        if sample_hea.exists() and sample_txt.exists():
            visualize_voice_preprocessing(sample_txt, sample_hea, config['voice_max_length'], 3)
        else:
            print(f"Warning: {sample_hea} or {sample_txt} not found")
    else:
        print("No voice files found for visualization")

elif config['data_type'] == 'tabular':
    # For tabular data (ASD), no preprocessing is needed - data is already in CSV format
    print("Tabular data preprocessing:")
    print("- No preprocessing required for tabular data")
    print("- Data is loaded directly from CSV file")
    print(f"- ASD labels file: {config['labels_file']}")
    
    # Check if ASD labels file exists
    asd_file_path = Path(config['labels_file'])
    if asd_file_path.exists():
        # Load and show basic info about the ASD dataset
        asd_data = pd.read_csv(config['labels_file'])
        print(f"- Dataset shape: {asd_data.shape}")
        print(f"- Features: {list(asd_data.columns)}")
        print(f"- Target distribution: {asd_data['y'].value_counts().to_dict()}")
        print(f"- Sample data:")
        print(asd_data.head())
    else:
        print(f"Warning: ASD labels file not found at {config['labels_file']}")
        
else:
    print(f"Unknown data type: {config['data_type']}")


Tabular data preprocessing:
- No preprocessing required for tabular data
- Data is loaded directly from CSV file
- ASD labels file: ./data/Jingyi/labels.csv
- Dataset shape: (800, 20)
- Features: ['ID', 'A1_Score', 'A2_Score', 'A3_Score', 'A4_Score', 'A5_Score', 'A6_Score', 'A7_Score', 'A8_Score', 'A9_Score', 'A10_Score', 'age', 'gender', 'ethnicity', 'jaundice', 'austim', 'contry_of_res', 'result', 'age_desc', 'y']
- Target distribution: {0: 639, 1: 161}
- Sample data:
   ID  A1_Score  A2_Score  A3_Score  A4_Score  A5_Score  A6_Score  A7_Score  \
0   1         1         0         1         0         1         0         1   
1   2         0         0         0         0         0         0         0   
2   3         1         1         1         1         1         1         1   
3   4         0         0         0         0         0         0         0   
4   5         0         0         0         0         0         0         0   

   A8_Score  A9_Score  A10_Score        age gender

For image data, process all images in the dataset (this cell does nothing for other data types):
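`process_all_images` is called with `target_size = (224, 224)`, and its internals are not shown. To illustrate what resizing to a fixed size means, here is a dependency-free nearest-neighbour resize of a grayscale image stored as a list of rows. The helper name is hypothetical, `target_size` is treated as (height, width) (for the square 224x224 size the order makes no difference), and real pipelines usually rely on an image library's resampler with smoother interpolation.

```python
def resize_nearest(image, target_size):
    """Nearest-neighbour resize of a 2D image (list of equal-length rows) to (height, width)."""
    out_h, out_w = target_size
    in_h, in_w = len(image), len(image[0])
    # Map each output pixel back to the input pixel that covers it
    return [
        [image[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


small = [[1, 2],
         [3, 4]]
for row in resize_nearest(small, (4, 4)):
    print(row)
# [1, 1, 2, 2]
# [1, 1, 2, 2]
# [3, 3, 4, 4]
# [3, 3, 4, 4]
```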

In [4]:
if config['data_type'] == 'image':
    # Process all images
    input_dir = ori_image_dir
    output_dir = config['image_dir']
    target_size = (224, 224)  # Standard size for many CNN architectures

    process_all_images(input_dir, output_dir, target_size)

## Data Loading

Create data loaders for training, validation, and testing.
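The split ratios in `config` (0.7 / 0.15 / 0.15) translate directly into sample counts: for the 800-row tabular dataset they give the 560 / 120 / 120 split reported in the output of the next cell. The loader functions handle this internally. The sketch below only illustrates a seeded, reproducible index split, assuming samples are shuffled once and then sliced (the real loaders might, for example, stratify by label instead); `split_indices` is a hypothetical name.

```python
import random


def split_indices(n, train_ratio, val_ratio, seed):
    """Shuffle range(n) with a fixed seed and slice it into train/val/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG: reproducible, no global side effects
    n_train = round(n * train_ratio)
    n_val = round(n * val_ratio)
    # The test set takes whatever remains, so the three parts always cover all n samples
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(800, 0.7, 0.15, seed=42)
print(len(train_idx), len(val_idx), len(test_idx))  # 560 120 120
```

Deriving the test size as the remainder, rather than from `test_ratio`, avoids losing or double-counting a sample to rounding.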

In [10]:
# Create data loaders based on data type
if config['data_type'] == 'image':
    data_loaders = create_data_loaders(
        data_dir=config['image_dir'],
        labels_file=config['labels_file'],
        task_type=config['task_type'],
        batch_size=config['batch_size'],
        train_ratio=config['train_ratio'],
        val_ratio=config['val_ratio'],
        test_ratio=config['test_ratio'],
        random_seed=config['random_seed']
    )
elif config['data_type'] == 'ECG':
    # For ECG data, use processed ECG data directory
    data_loaders = create_ecg_data_loaders(
        data_dir=config['image_dir'],
        labels_file=config['labels_file'],
        task_type=config['task_type'],
        batch_size=config['batch_size'],
        train_ratio=config['train_ratio'],
        val_ratio=config['val_ratio'],
        test_ratio=config['test_ratio'],
        random_seed=config['random_seed']
    )
elif config['data_type'] == 'voice':
    print('Creating voice data loaders')
    data_loaders = create_voice_data_loaders(
        data_dir=config['image_dir'],
        labels_file=config['labels_file'],
        task_type=config['task_type'],
        batch_size=config['batch_size'],
        train_ratio=config['train_ratio'],
        val_ratio=config['val_ratio'],
        test_ratio=config['test_ratio'],
        random_seed=config['random_seed'],
        target_length=config['voice_max_length']
    )
elif config['data_type'] == 'tabular':
    print('Creating ASD tabular data loaders')
    data_loaders_info = create_asd_data_loaders(
        labels_file=config['labels_file'],
        task_type=config['task_type'],
        batch_size=config['batch_size'],
        train_ratio=config['train_ratio'],
        val_ratio=config['val_ratio'],
        test_ratio=config['test_ratio'],
        random_seed=config['random_seed']
    )
    # Extract data loaders and feature dimension for tabular data
    data_loaders = {
        'train': data_loaders_info['train'],
        'val': data_loaders_info['val'],
        'test': data_loaders_info['test']
    }
    feature_dim = data_loaders_info['feature_dim']
    print(f"Feature dimension: {feature_dim}")

train_loader, val_loader, test_loader = data_loaders['train'], data_loaders['val'], data_loaders['test']

# Print dataset sizes
print(f"Training set size: {len(data_loaders['train'].dataset)}")
print(f"Validation set size: {len(data_loaders['val'].dataset)}")
print(f"Test set size: {len(data_loaders['test'].dataset)}")
print(f"Data type: {config['data_type']}")

# Show sample data shape
sample_data, sample_label = next(iter(train_loader))
print(f"Sample data shape: {sample_data.shape}")
print(f"Sample label shape: {sample_label.shape}")
if config['data_type'] == 'tabular':
    print(f"Sample features (first 5): {sample_data[0, :5]}")
    print(f"Sample labels (first 5): {sample_label[:5]}")
else:
    print(f"Sample data (first 3): {sample_data[:3]}")

Creating ASD tabular data loaders
Feature dimension: 18
Training set size: 560
Validation set size: 120
Test set size: 120
Data type: tabular
Sample data shape: torch.Size([32, 18])
Sample label shape: torch.Size([32])
Sample features (first 5): tensor([ 0.8980,  0.9377, -0.8980, -0.8597, -0.8256])
Sample labels (first 5): tensor([0, 0, 0, 0, 0])
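The labels file has 20 columns, yet the loader reports a feature dimension of 18, which is consistent with dropping `ID` and the target `y`. The sample feature values (for example `0.8980` and `-0.8980`) also look standardized. Since `create_asd_data_loaders` is not shown, the following is only a sketch of that kind of preprocessing: z-score standardization whose mean and standard deviation are fitted on the training rows alone, so no validation or test information leaks into training. Categorical columns such as `gender` or `ethnicity` would need a numeric encoding first; this sketch assumes that has already happened, and the function names are hypothetical.

```python
import statistics


def fit_standardizer(rows):
    """Compute per-column mean and population std from numeric training rows."""
    columns = list(zip(*rows))
    means = [statistics.fmean(col) for col in columns]
    # Guard constant columns (std 0) to avoid division by zero
    stds = [statistics.pstdev(col) or 1.0 for col in columns]
    return means, stds


def standardize(rows, means, stds):
    """Apply (x - mean) / std column-wise using previously fitted statistics."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


train_rows = [[1.0, 10.0], [3.0, 10.0]]
means, stds = fit_standardizer(train_rows)
print(standardize(train_rows, means, stds))     # [[-1.0, 0.0], [1.0, 0.0]]
print(standardize([[5.0, 12.0]], means, stds))  # [[3.0, 2.0]], using the training statistics
```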


## Grid Search

Perform grid search to find the best model architecture and hyperparameters.
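Conceptually, grid search trains one model per combination in the Cartesian product of the parameter lists and keeps the combination with the best validation result. In the log of this run, the winning combination is the one whose best epoch reached the lowest validation loss (0.2347 for `[128, 64]` versus 0.2374 for `[128, 32]`), so the sketch selects by minimum validation loss. It is a minimal illustration, not the `HyperparameterTuner` implementation, and `evaluate` stands in for a full training run.

```python
from itertools import product


def iter_grid(param_grid):
    """Yield every combination of a dict-of-lists grid as a dict."""
    keys = list(param_grid)
    for values in product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


def grid_search(param_grid, evaluate):
    """Return (best_params, best_loss) for the combination with the lowest validation loss."""
    best_params, best_loss = None, float('inf')
    for params in iter_grid(param_grid):
        loss = evaluate(params)  # placeholder for "train a model, return its best val loss"
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss


# Tabular grid from the configuration; the losses are the best validation
# losses logged for each combination in this run
example_grid = {'hidden_layers': [[128, 64], [128, 32]], 'learning_rate': [0.01]}
logged_best_val_loss = {(128, 64): 0.2347, (128, 32): 0.2374}

best, loss = grid_search(example_grid, lambda p: logged_best_val_loss[tuple(p['hidden_layers'])])
print(best, loss)  # {'hidden_layers': [128, 64], 'learning_rate': 0.01} 0.2347
```

The number of training runs is the product of the list lengths, so adding one more option to each of four parameters can multiply the cost many times over.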

In [11]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Create universal hyperparameter tuner
print(f"Performing {config['data_type']} grid search...")

# Prepare parameters for different data types
if config['data_type'] == 'tabular':
    # For tabular data, use specific grid search parameters
    param_grid = {
        'hidden_layers': config['grid_search']['fc_layers'],
        'learning_rate': config['grid_search']['learning_rate']
    }
    
    tuner = HyperparameterTuner(
        train_loader=train_loader,
        val_loader=val_loader,
        task_type=config['task_type'],
        model_type=config['data_type'],
        num_classes=config['num_classes'],
        input_dim=feature_dim,  # Pass feature dimension for tabular data
        device=device,
        save_dir=Path(config['save_dir']) / 'grid_search'
    )
else:
    # For CNN models (image, ECG, voice)
    param_grid = {
        'num_conv_layers': config['grid_search']['num_conv_layers'],
        'conv_channels': config['grid_search']['conv_channels'],
        'fc_layers': config['grid_search']['fc_layers'],
        'learning_rate': config['grid_search']['learning_rate']
    }
    
    # Sequence length is only needed for the 1D signal models (ECG, voice)
    if config['data_type'] == 'ECG':
        input_length = config.get('ecg_max_length', 5000)
    elif config['data_type'] == 'voice':
        input_length = config.get('voice_max_length', 5000)
    else:
        input_length = None

    tuner = HyperparameterTuner(
        train_loader=train_loader,
        val_loader=val_loader,
        task_type=config['task_type'],
        model_type=config['data_type'],  # 'image', 'ECG', or 'voice'
        num_classes=config['num_classes'],
        input_length=input_length,
        device=device,
        save_dir=Path(config['save_dir']) / 'grid_search'
    )

# Perform grid search
grid_search_results = tuner.grid_search(
    param_grid=param_grid,
    num_epochs=config['num_epochs'],
    early_stopping_patience=config['early_stopping_patience']
)

# Plot results
tuner.plot_results()

# Print best combination
print('\nBest combination:')
print(json.dumps(grid_search_results['best_combination'], indent=2))
print('\nBest validation metrics:')
print(json.dumps(grid_search_results['best_val_metrics'], indent=2))


Using device: cpu
Performing tabular grid search...

Trying combination 1/2:
{
  "hidden_layers": [
    128,
    64
  ],
  "learning_rate": 0.01
}


Epochs:  10%|█         | 1/10 [00:44<06:44, 44.99s/it, Train Loss=0.5184, Val Loss=0.2637, Train Acc=0.784, Val Acc=0.858, LR=0.010000]


Epoch [1/10]
Train Loss: 0.5184, Train Metrics: {'accuracy': 0.7839285714285714, 'precision': 0.8062122070243365, 'recall': 0.7839285714285714, 'f1': 0.7929139882475614, 'aupr': 0.3500303653181351, 'auc': 0.7075757575757575}
Val Loss: 0.2637, Val Metrics: {'accuracy': 0.8583333333333333, 'precision': 0.8953125, 'recall': 0.8583333333333333, 'f1': 0.868867924528302, 'aupr': 0.5071428571428571, 'auc': 0.8578643578643578}
New best model saved with validation loss: 0.2637


Epochs:  20%|██        | 2/10 [01:29<05:57, 44.67s/it, Train Loss=0.3805, Val Loss=0.2366, Train Acc=0.846, Val Acc=0.867, LR=0.010000]


Epoch [2/10]
Train Loss: 0.3805, Train Metrics: {'accuracy': 0.8464285714285714, 'precision': 0.8381211180124223, 'recall': 0.8464285714285714, 'f1': 0.8410918822974734, 'aupr': 0.4252682100508187, 'auc': 0.7258585858585859}
Val Loss: 0.2366, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8778947368421052, 'recall': 0.8666666666666667, 'f1': 0.8710891976692067, 'aupr': 0.47857142857142854, 'auc': 0.8066378066378067}
New best model saved with validation loss: 0.2366


Epochs:  30%|███       | 3/10 [02:13<05:11, 44.55s/it, Train Loss=0.3229, Val Loss=0.2356, Train Acc=0.834, Val Acc=0.908, LR=0.010000]


Epoch [3/10]
Train Loss: 0.3229, Train Metrics: {'accuracy': 0.8339285714285715, 'precision': 0.8232982486632382, 'recall': 0.8339285714285715, 'f1': 0.8270593310900711, 'aupr': 0.3887111484021596, 'auc': 0.7009090909090909}
Val Loss: 0.2356, Val Metrics: {'accuracy': 0.9083333333333333, 'precision': 0.9258152173913043, 'recall': 0.9083333333333333, 'f1': 0.9132011967090501, 'aupr': 0.6306122448979592, 'auc': 0.9069264069264068}
New best model saved with validation loss: 0.2356


Epochs:  40%|████      | 4/10 [02:58<04:27, 44.50s/it, Train Loss=0.3306, Val Loss=0.2347, Train Acc=0.832, Val Acc=0.892, LR=0.010000]


Epoch [4/10]
Train Loss: 0.3306, Train Metrics: {'accuracy': 0.8321428571428572, 'precision': 0.8266614906832298, 'recall': 0.8321428571428572, 'f1': 0.8290685504971219, 'aupr': 0.3986753246753246, 'auc': 0.716969696969697}
Val Loss: 0.2347, Val Metrics: {'accuracy': 0.8916666666666667, 'precision': 0.8986979166666667, 'recall': 0.8916666666666667, 'f1': 0.8944444444444445, 'aupr': 0.5496031746031745, 'auc': 0.8405483405483406}
New best model saved with validation loss: 0.2347


Epochs:  50%|█████     | 5/10 [03:43<03:43, 44.65s/it, Train Loss=0.3081, Val Loss=0.2409, Train Acc=0.848, Val Acc=0.850, LR=0.010000]


Epoch [5/10]
Train Loss: 0.3081, Train Metrics: {'accuracy': 0.8482142857142857, 'precision': 0.8389508848698, 'recall': 0.8482142857142857, 'f1': 0.841935947770495, 'aupr': 0.4265121114840216, 'auc': 0.7235353535353535}
Val Loss: 0.2409, Val Metrics: {'accuracy': 0.85, 'precision': 0.8762220538082608, 'recall': 0.85, 'f1': 0.8588421052631579, 'aupr': 0.4620279146141215, 'auc': 0.8152958152958154}


Epochs:  60%|██████    | 6/10 [04:28<02:59, 44.99s/it, Train Loss=0.2807, Val Loss=0.2448, Train Acc=0.854, Val Acc=0.867, LR=0.010000]


Epoch [6/10]
Train Loss: 0.2807, Train Metrics: {'accuracy': 0.8535714285714285, 'precision': 0.8497929739581422, 'recall': 0.8535714285714285, 'f1': 0.851453685122956, 'aupr': 0.45691685765215173, 'auc': 0.7543434343434343}
Val Loss: 0.2448, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8778947368421052, 'recall': 0.8666666666666667, 'f1': 0.8710891976692067, 'aupr': 0.47857142857142854, 'auc': 0.8066378066378067}


Epochs:  70%|███████   | 7/10 [05:13<02:14, 44.94s/it, Train Loss=0.3205, Val Loss=0.2424, Train Acc=0.841, Val Acc=0.875, LR=0.010000]


Epoch [7/10]
Train Loss: 0.3205, Train Metrics: {'accuracy': 0.8410714285714286, 'precision': 0.8462313496827303, 'recall': 0.8410714285714286, 'f1': 0.8433918095490148, 'aupr': 0.4457601222307105, 'auc': 0.7637373737373736}
Val Loss: 0.2424, Val Metrics: {'accuracy': 0.875, 'precision': 0.877435064935065, 'recall': 0.875, 'f1': 0.8761362294888443, 'aupr': 0.48257575757575755, 'auc': 0.7929292929292929}


Epochs:  80%|████████  | 8/10 [05:58<01:29, 44.81s/it, Train Loss=0.3044, Val Loss=0.2388, Train Acc=0.841, Val Acc=0.858, LR=0.010000]


Epoch [8/10]
Train Loss: 0.3044, Train Metrics: {'accuracy': 0.8410714285714286, 'precision': 0.828270150556163, 'recall': 0.8410714285714286, 'f1': 0.8314867564144781, 'aupr': 0.3972763347763348, 'auc': 0.6984848484848485}
Val Loss: 0.2388, Val Metrics: {'accuracy': 0.8583333333333333, 'precision': 0.8733019639934534, 'recall': 0.8583333333333333, 'f1': 0.8640337338771912, 'aupr': 0.46208791208791206, 'auc': 0.8015873015873016}
Epoch 00008: reducing learning rate of group 0 to 5.0000e-03.


Epochs:  90%|█████████ | 9/10 [06:43<00:44, 44.85s/it, Train Loss=0.3034, Val Loss=0.2460, Train Acc=0.843, Val Acc=0.867, LR=0.005000]


Epoch [9/10]
Train Loss: 0.3034, Train Metrics: {'accuracy': 0.8428571428571429, 'precision': 0.8377950310559005, 'recall': 0.8428571428571429, 'f1': 0.8399790685504972, 'aupr': 0.4257727272727273, 'auc': 0.733939393939394}
Val Loss: 0.2460, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8988039144617617, 'recall': 0.8666666666666667, 'f1': 0.8759410801963993, 'aupr': 0.5226958525345622, 'auc': 0.862914862914863}


Epochs: 100%|██████████| 10/10 [07:27<00:00, 44.76s/it, Train Loss=0.2914, Val Loss=0.2364, Train Acc=0.846, Val Acc=0.867, LR=0.005000]



Epoch [10/10]
Train Loss: 0.2914, Train Metrics: {'accuracy': 0.8464285714285714, 'precision': 0.8453873437280517, 'recall': 0.8464285714285714, 'f1': 0.8458943719069028, 'aupr': 0.44523809523809527, 'auc': 0.7533333333333334}
Val Loss: 0.2364, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.871985656656208, 'recall': 0.8666666666666667, 'f1': 0.8690166975881262, 'aupr': 0.46413043478260874, 'auc': 0.7878787878787877}

Trying combination 2/2:
{
  "hidden_layers": [
    128,
    32
  ],
  "learning_rate": 0.01
}


Epochs:  10%|█         | 1/10 [00:44<06:43, 44.89s/it, Train Loss=0.5080, Val Loss=0.2540, Train Acc=0.759, Val Acc=0.867, LR=0.010000]


Epoch [1/10]
Train Loss: 0.5080, Train Metrics: {'accuracy': 0.7589285714285714, 'precision': 0.7746147155781321, 'recall': 0.7589285714285714, 'f1': 0.7659113869000409, 'aupr': 0.29412815866304237, 'auc': 0.6508080808080807}
Val Loss: 0.2540, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8913224706328154, 'recall': 0.8666666666666667, 'f1': 0.8745263157894737, 'aupr': 0.5078817733990147, 'auc': 0.8441558441558441}
New best model saved with validation loss: 0.2540


Epochs:  20%|██        | 2/10 [01:29<05:57, 44.65s/it, Train Loss=0.3032, Val Loss=0.2476, Train Acc=0.839, Val Acc=0.850, LR=0.010000]


Epoch [2/10]
Train Loss: 0.3032, Train Metrics: {'accuracy': 0.8392857142857143, 'precision': 0.834083850931677, 'recall': 0.8392857142857143, 'f1': 0.8363422291993721, 'aupr': 0.4165584415584415, 'auc': 0.7282828282828283}
Val Loss: 0.2476, Val Metrics: {'accuracy': 0.85, 'precision': 0.8920062695924765, 'recall': 0.85, 'f1': 0.8618279569892473, 'aupr': 0.4925324675324675, 'auc': 0.8528138528138529}
New best model saved with validation loss: 0.2476


Epochs:  30%|███       | 3/10 [02:14<05:12, 44.68s/it, Train Loss=0.3150, Val Loss=0.2433, Train Acc=0.836, Val Acc=0.867, LR=0.010000]


Epoch [3/10]
Train Loss: 0.3150, Train Metrics: {'accuracy': 0.8357142857142857, 'precision': 0.8204498850431865, 'recall': 0.8357142857142857, 'f1': 0.8236891358049055, 'aupr': 0.37673444976076553, 'auc': 0.6814141414141415}
Val Loss: 0.2433, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8988039144617617, 'recall': 0.8666666666666667, 'f1': 0.8759410801963993, 'aupr': 0.5226958525345622, 'auc': 0.862914862914863}
New best model saved with validation loss: 0.2433


Epochs:  40%|████      | 4/10 [02:58<04:27, 44.58s/it, Train Loss=0.3189, Val Loss=0.2439, Train Acc=0.857, Val Acc=0.867, LR=0.010000]


Epoch [4/10]
Train Loss: 0.3189, Train Metrics: {'accuracy': 0.8571428571428571, 'precision': 0.8552430031941295, 'recall': 0.8571428571428571, 'f1': 0.856136161445896, 'aupr': 0.47156946826758145, 'auc': 0.7668686868686868}
Val Loss: 0.2439, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8988039144617617, 'recall': 0.8666666666666667, 'f1': 0.8759410801963993, 'aupr': 0.5226958525345622, 'auc': 0.862914862914863}


Epochs:  50%|█████     | 5/10 [03:42<03:42, 44.49s/it, Train Loss=0.3123, Val Loss=0.2449, Train Acc=0.832, Val Acc=0.867, LR=0.010000]


Epoch [5/10]
Train Loss: 0.3123, Train Metrics: {'accuracy': 0.8321428571428572, 'precision': 0.8246356732348111, 'recall': 0.8321428571428572, 'f1': 0.8277245316345825, 'aupr': 0.39339826839826836, 'auc': 0.7101010101010101}
Val Loss: 0.2449, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.871985656656208, 'recall': 0.8666666666666667, 'f1': 0.8690166975881262, 'aupr': 0.46413043478260874, 'auc': 0.7878787878787877}


Epochs:  60%|██████    | 6/10 [04:27<02:57, 44.47s/it, Train Loss=0.3122, Val Loss=0.2404, Train Acc=0.845, Val Acc=0.867, LR=0.010000]


Epoch [6/10]
Train Loss: 0.3122, Train Metrics: {'accuracy': 0.8446428571428571, 'precision': 0.8312999112688554, 'recall': 0.8446428571428571, 'f1': 0.8336823475370856, 'aupr': 0.4023022432113341, 'auc': 0.6972727272727273}
Val Loss: 0.2404, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8913224706328154, 'recall': 0.8666666666666667, 'f1': 0.8745263157894737, 'aupr': 0.5078817733990147, 'auc': 0.8441558441558441}
New best model saved with validation loss: 0.2404


Epochs:  70%|███████   | 7/10 [05:12<02:13, 44.58s/it, Train Loss=0.3074, Val Loss=0.2377, Train Acc=0.832, Val Acc=0.883, LR=0.010000]


Epoch [7/10]
Train Loss: 0.3074, Train Metrics: {'accuracy': 0.8321428571428572, 'precision': 0.8246356732348111, 'recall': 0.8321428571428572, 'f1': 0.8277245316345825, 'aupr': 0.39339826839826836, 'auc': 0.7101010101010101}
Val Loss: 0.2377, Val Metrics: {'accuracy': 0.8833333333333333, 'precision': 0.9064228874573703, 'recall': 0.8833333333333333, 'f1': 0.8902105263157895, 'aupr': 0.5570197044334976, 'auc': 0.873015873015873}
New best model saved with validation loss: 0.2377


Epochs:  80%|████████  | 8/10 [05:56<01:29, 44.58s/it, Train Loss=0.2888, Val Loss=0.2405, Train Acc=0.832, Val Acc=0.867, LR=0.010000]


Epoch [8/10]
Train Loss: 0.2888, Train Metrics: {'accuracy': 0.8321428571428572, 'precision': 0.8217916244511989, 'recall': 0.8321428571428572, 'f1': 0.8255745341614907, 'aupr': 0.385523088023088, 'auc': 0.6997979797979799}
Val Loss: 0.2405, Val Metrics: {'accuracy': 0.8666666666666667, 'precision': 0.8666666666666667, 'recall': 0.8666666666666667, 'f1': 0.8666666666666667, 'aupr': 0.44988662131519275, 'auc': 0.7691197691197692}


Epochs:  90%|█████████ | 9/10 [06:41<00:44, 44.61s/it, Train Loss=0.2856, Val Loss=0.2416, Train Acc=0.834, Val Acc=0.858, LR=0.010000]


Epoch [9/10]
Train Loss: 0.2856, Train Metrics: {'accuracy': 0.8339285714285715, 'precision': 0.8224418257497649, 'recall': 0.8339285714285715, 'f1': 0.8263031920535316, 'aupr': 0.3861210628452008, 'auc': 0.6974747474747475}
Val Loss: 0.2416, Val Metrics: {'accuracy': 0.8583333333333333, 'precision': 0.8610621521335807, 'recall': 0.8583333333333333, 'f1': 0.8596210600873567, 'aupr': 0.4324675324675325, 'auc': 0.764069264069264}


Epochs: 100%|██████████| 10/10 [07:26<00:00, 44.60s/it, Train Loss=0.2744, Val Loss=0.2374, Train Acc=0.861, Val Acc=0.875, LR=0.010000]



Epoch [10/10]
Train Loss: 0.2744, Train Metrics: {'accuracy': 0.8607142857142858, 'precision': 0.859773727583462, 'recall': 0.8607142857142858, 'f1': 0.8602297791713769, 'aupr': 0.4838864838864839, 'auc': 0.7759595959595958}
Val Loss: 0.2374, Val Metrics: {'accuracy': 0.875, 'precision': 0.877435064935065, 'recall': 0.875, 'f1': 0.8761362294888443, 'aupr': 0.48257575757575755, 'auc': 0.7929292929292929}
New best model saved with validation loss: 0.2374

Best combination:
{
  "hidden_layers": [
    128,
    64
  ],
  "learning_rate": 0.01
}

Best validation metrics:
{
  "accuracy": 0.8916666666666667,
  "precision": 0.8986979166666667,
  "recall": 0.8916666666666667,
  "f1": 0.8944444444444445,
  "aupr": 0.5496031746031745,
  "auc": 0.8405483405483406
}


## Train Best Model

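Train the model with the best hyperparameters found during grid search. If `best_model.pth` already exists under `save_dir/best_model/`, it is loaded instead of being retrained.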
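Two mechanisms shape training: early stopping (`early_stopping_patience`) and a learning-rate reduction on plateau, visible in the grid-search log where the rate halves from 0.01 to 0.005 after epoch 8 of the first combination. That log is consistent with PyTorch's `ReduceLROnPlateau` using `factor=0.5` and `patience=3`, but those values are inferred from the log, not read from `ModelTrainer`. The plain-Python sketch below replays both mechanisms over a list of validation losses (simplified: no improvement threshold or cooldown) and reproduces the logged behaviour; `simulate_schedule` is a hypothetical name.

```python
def simulate_schedule(val_losses, lr=0.01, factor=0.5, lr_patience=3, stop_patience=15):
    """Replay plateau-based LR decay and early stopping over a sequence of validation losses.

    Returns (LR used in each epoch, best loss, epoch at which training stopped or None).
    """
    best = float('inf')
    bad_lr = bad_stop = 0  # epochs without improvement, tracked per mechanism
    lrs, stopped_at = [], None
    for epoch, loss in enumerate(val_losses, start=1):
        lrs.append(lr)  # learning rate used during this epoch
        if loss < best:
            best, bad_lr, bad_stop = loss, 0, 0  # improvement: the best model would be saved here
        else:
            bad_lr += 1
            bad_stop += 1
        if bad_lr > lr_patience:  # plateau: decay the LR and restart the LR counter
            lr *= factor
            bad_lr = 0
        if bad_stop >= stop_patience:  # no improvement for stop_patience epochs: stop
            stopped_at = epoch
            break
    return lrs, best, stopped_at


# Validation losses of combination 1 ([128, 64]) from the grid-search log
combo1 = [0.2637, 0.2366, 0.2356, 0.2347, 0.2409, 0.2448, 0.2424, 0.2388, 0.2460, 0.2364]
lrs, best, stopped_at = simulate_schedule(combo1)
print(lrs[7], lrs[8], best, stopped_at)  # 0.01 0.005 0.2347 None
```

With `stop_patience=15` and only 10 epochs, `stopped_at` stays `None`: early stopping cannot fire in this configuration.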

In [14]:
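device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Get best parameters from the saved grid search summary, so this cell can run without re-running the search
grid_search_res_path = Path(config['save_dir']) / 'grid_search' / 'grid_search_summary.json'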
with open(grid_search_res_path, 'r') as f:
    grid_search_results = json.load(f)

best_params = grid_search_results['best_combination']

# Create model with best parameters based on data type
if config['data_type'] == 'image':
    model = MedicalCNN(
        task_type=config['task_type'],
        num_classes=config['num_classes'],
        num_conv_layers=best_params['num_conv_layers'],
        conv_channels=best_params['conv_channels'],
        fc_layers=best_params['fc_layers']
    )
elif config['data_type'] == 'ECG':
    model = ECG1DCNN(
        task_type=config['task_type'],
        num_classes=config['num_classes'],
        input_length=config['ecg_max_length'],
        num_conv_layers=best_params['num_conv_layers'],
        conv_channels=best_params['conv_channels'],
        fc_layers=best_params['fc_layers']
    )
elif config['data_type'] == 'voice':
    # Create voice model with best parameters
    model = Voice1DCNN(
        task_type=config['task_type'],
        num_classes=config['num_classes'],
        input_length=config['voice_max_length'],
        num_conv_layers=best_params['num_conv_layers'],
        conv_channels=best_params['conv_channels'],
        fc_layers=best_params['fc_layers']
    )
elif config['data_type'] == 'tabular':
    # Create ASD tabular model with best parameters
    model = ASDTabularModel(
        task_type=config['task_type'],
        input_dim=feature_dim,
        num_classes=config['num_classes'],
        hidden_layers=best_params['hidden_layers'],
    )

# Define loss function and optimizer
if config['task_type'] == 'classification':
    criterion = nn.CrossEntropyLoss()
else:
    criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'])

# Initialize trainer
trainer = ModelTrainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    task_type=config['task_type']
)

# If a best model already exists, load it; otherwise train and save one
best_model_path = Path(config['save_dir']) / 'best_model' / 'best_model.pth'
if best_model_path.exists():
    print(f"Found existing best model at {best_model_path}, loading...")
    trainer.load_model(str(best_model_path))
    history = None  # No new training history
else:
    print("No existing best model found, training...")
    history = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['num_epochs'],
        save_dir=Path(config['save_dir']) / 'best_model',
        early_stopping_patience=config['early_stopping_patience']
    )
    # Plot training history
    trainer.plot_training_history(Path(config['save_dir']) / 'best_model')

# Set class names for classification tasks
if config['task_type'] == 'classification':
    # Define class name mapping
    class_names = class_nms # Modify according to your specific classes
    # Or you can set it according to your needs, for example:
    # class_names = {0: 'aa', 1: 'bb'}
    
    # Set class names to trainer
    trainer.set_class_names(class_names)
    print(f"Class names set: {class_names}")

# Keep a reference to the trainer for the evaluation and comparison cells
best_trainer = trainer


Using device: cpu
Found existing best model at results/best_model/best_model.pth, loading...
Class names set: {0: 'Normal', 1: 'Abnormal'}
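For classification, the model is trained with `nn.CrossEntropyLoss`, which applies log-softmax to the raw model outputs (logits) and, with default settings, averages the negative log-probability of the true class over the batch. The stdlib re-implementation below exists only to make the formula concrete; `cross_entropy` is a hypothetical helper, not a PyTorch API.

```python
import math


def cross_entropy(logits_batch, targets):
    """Mean cross-entropy over a batch, matching nn.CrossEntropyLoss with default settings."""
    total = 0.0
    for logits, target in zip(logits_batch, targets):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[target]  # equals -log(softmax(logits)[target])
    return total / len(targets)


print(cross_entropy([[0.0, 0.0]], [0]))  # log(2) ≈ 0.6931
```

Two equal logits give a loss of log(2), the value of a classifier that assigns 50% to each of two classes.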


## Evaluation

Evaluate the best model on the test set.
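In every logged metric dictionary, `recall` equals `accuracy` exactly. That is what support-weighted averaging produces (as with scikit-learn's `average='weighted'`): each class recall TP_c / n_c is weighted by n_c / n, so the sum collapses to the overall fraction of correct predictions. Assuming the trainer uses that averaging, the sketch below shows how weighted precision, recall and F1 are computed; `weighted_prf` is a hypothetical helper. Because the weights follow class frequency, these scores are dominated by the majority class (639 of the 800 samples are class 0), so AUPR is often the more informative number for the minority class.

```python
from collections import Counter


def weighted_prf(y_true, y_pred):
    """Accuracy plus support-weighted precision, recall and F1 for integer class labels."""
    n = len(y_true)
    support = Counter(y_true)  # number of true samples per class
    precision = recall = f1 = 0.0
    for cls, n_cls in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        n_pred = sum(1 for p in y_pred if p == cls)
        p_c = tp / n_pred if n_pred else 0.0
        r_c = tp / n_cls
        f_c = 2 * p_c * r_c / (p_c + r_c) if p_c + r_c else 0.0
        weight = n_cls / n  # each class is weighted by its share of the true labels
        precision += weight * p_c
        recall += weight * r_c
        f1 += weight * f_c
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}


# Weighted recall = sum_c (n_c / n) * (TP_c / n_c) = sum_c TP_c / n = accuracy
print(weighted_prf([0, 0, 0, 1, 1], [0, 0, 1, 1, 0]))
```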

In [7]:
# Evaluate on test set
print("Evaluating model on test set...")
test_metrics = trainer.evaluate(test_loader, Path(config['save_dir']) / 'best_model')
print('\nTest Set Metrics:')
print(json.dumps(test_metrics, indent=2))

# Print some debug information
print(f"\nModel device: {next(trainer.model.parameters()).device}")
print(f"Task type: {trainer.task_type}")
print(f"Test set size: {len(test_loader.dataset)}")

Evaluating model on test set...


Evaluating: 100%|██████████████████████████████| 4/4 [00:22<00:00,  5.63s/it, Samples=120/120]


Test Set Metrics:
{
  "accuracy": 0.875,
  "precision": 0.8737211064797272,
  "recall": 0.875,
  "f1": 0.874285981833505,
  "aupr": 0.6229885057471264,
  "auc": 0.8277777777777778
}

Model device: cpu
Task type: classification
Test set size: 120

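## Comparison with Traditional Machine Learning

Compare the best deep learning model against traditional machine learning baselines (Random Forest, SVM, and Logistic Regression).
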
In [15]:
# Performance comparison with traditional machine learning methods
print("="*60)
print("STARTING MODEL COMPARISON WITH TRADITIONAL ML METHODS")
print("="*60)

# Import required modules for comparison

from model import compare_models_performance

# DEBUG CONTROL - Set to True for detailed debugging information
DEBUG_MODE = False  # Change to True if you want to see detailed debug information

# Set class names for comparison (if classification task)
comparison_class_names = None
if config['task_type'] == 'classification':
    comparison_class_names = class_names  # Modify according to your classes

print(f"Configuration:")
print(f"  Task type: {config['task_type']}")
print(f"  Train set size: {len(train_loader.dataset)}")
print(f"  Validation set size: {len(val_loader.dataset)}")
print(f"  Test set size: {len(test_loader.dataset)}")
print(f"  Class names: {comparison_class_names}")
print(f"  Debug mode: {DEBUG_MODE}")
print(f"  Save directory: ./results/model_evaluation")

try:
    # Compare models performance
    comparison_results = compare_models_performance(
        best_cnn_trainer=trainer,
        train_loader=train_loader,
        val_loader=val_loader, 
        test_loader=test_loader,
        save_dir='./results/model_evaluation',
        task_type=config['task_type'],
        class_names=comparison_class_names,
        debug=DEBUG_MODE  # Control debug output
    )

    # Display comparison results
    print("\n" + "="*50)
    print("MODEL PERFORMANCE COMPARISON RESULTS")
    print("="*50)

    for model_name, metrics in comparison_results.items():
        print(f"\n{model_name}:")
        for metric_name, value in metrics.items():
            print(f"  {metric_name.upper()}: {value:.4f}")

    print(f"\nAll comparison plots saved to: ./results/model_evaluation/")
    print("Generated files:")
    if config['task_type'] == 'classification':
        print("- aupr_comparison.png (AUPR curves)")
        print("- auc_comparison.png (ROC curves)")
        print("- accuracy_comparison.png")
        print("- precision_comparison.png") 
        print("- recall_comparison.png")
        print("- f1_comparison.png")
    else:
        print("- mse_comparison.png")
        print("- mae_comparison.png")
        print("- r2_comparison.png")
    print("- model_comparison_results.json")
    
except Exception as e:
    print(f"Error during model comparison: {e}")
    if DEBUG_MODE:
        import traceback
        traceback.print_exc()

STARTING MODEL COMPARISON WITH TRADITIONAL ML METHODS
Configuration:
  Task type: classification
  Train set size: 560
  Validation set size: 120
  Test set size: 120
  Class names: {0: 'Normal', 1: 'Abnormal'}
  Debug mode: False
  Save directory: ./results/model_evaluation



Model comparison results saved to ./results/model_evaluation

MODEL PERFORMANCE COMPARISON RESULTS

Deep Learning (CNN):
  ACCURACY: 0.8833
  PRECISION: 0.8833
  RECALL: 0.8833
  F1: 0.8833
  AUPR: 0.6922
  AUC: 0.9148

Random Forest:
  ACCURACY: 0.8583
  PRECISION: 0.8539
  RECALL: 0.8583
  F1: 0.8492
  AUPR: 0.7876
  AUC: 0.9319

SVM:
  ACCURACY: 0.8917
  PRECISION: 0.8892
  RECALL: 0.8917
  F1: 0.8897
  AUPR: 0.7409
  AUC: 0.9078

Logistic Regression:
  ACCURACY: 0.8667
  PRECISION: 0.8625
  RECALL: 0.8667
  F1: 0.8634
  AUPR: 0.7576
  AUC: 0.9300

All comparison plots saved to: ./results/model_evaluation/
Generated files:
- aupr_comparison.png (AUPR curves)
- auc_comparison.png (ROC curves)
- accuracy_comparison.png
- precision_comparison.png
- recall_comparison.png
- f1_comparison.png
- model_comparison_results.json
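No single model wins every metric in this run: SVM has the highest accuracy, precision, recall and F1, Random Forest has the highest AUPR and AUC, and the deep learning model (labelled "Deep Learning (CNN)" in the output, although for tabular data it is the fully connected `ASDTabularModel`) is best on none of them. The sketch below ranks models per metric using the values printed above, typed in by hand. The layout of `model_comparison_results.json` is not shown, so the dict structure here is an assumption.

```python
comparison = {
    'Deep Learning (CNN)': {'accuracy': 0.8833, 'precision': 0.8833, 'recall': 0.8833,
                            'f1': 0.8833, 'aupr': 0.6922, 'auc': 0.9148},
    'Random Forest': {'accuracy': 0.8583, 'precision': 0.8539, 'recall': 0.8583,
                      'f1': 0.8492, 'aupr': 0.7876, 'auc': 0.9319},
    'SVM': {'accuracy': 0.8917, 'precision': 0.8892, 'recall': 0.8917,
            'f1': 0.8897, 'aupr': 0.7409, 'auc': 0.9078},
    'Logistic Regression': {'accuracy': 0.8667, 'precision': 0.8625, 'recall': 0.8667,
                            'f1': 0.8634, 'aupr': 0.7576, 'auc': 0.9300},
}


def best_model_per_metric(results):
    """Map each metric name to (model_name, value) of the highest-scoring model.

    All metrics here are higher-is-better; regression metrics such as MSE or MAE
    would need min() instead of max().
    """
    metrics = next(iter(results.values())).keys()
    return {
        metric: max(((name, scores[metric]) for name, scores in results.items()),
                    key=lambda item: item[1])
        for metric in metrics
    }


for metric, (name, value) in best_model_per_metric(comparison).items():
    print(f"{metric.upper():<10} {name:<20} {value:.4f}")
```

With only 120 test samples, a difference of 0.0083 in accuracy corresponds to a single prediction, so small gaps between these models should be read with caution.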
