In [5]:
import h5py
import os
from utils.datasets_statistics import get_channel_index_for_dataset
from sklearn.utils.class_weight import compute_class_weight
from utils.utils import *
from utils.dataloader import get_dataloader, get_balanced_dataloader
from utils.model_training import *
from utils.tiny_sleep_net import TinySleepNet
import torch
import torch.nn as nn
from collections import Counter
import random
import numpy as np 


def main(dataset_path, model_save_dir, model_save_name, dataset_index, nch_dataset_path):
    """Train TinySleepNet from scratch on a single dataset, once per seed.

    For every seed, the dataset selected by ``dataset_index`` is processed as follows:

    * It is loaded from ``dataset_path`` (channel F3-C3).
    * It is split with ``stratified_train_val_test_split`` and standardized
      with one mean/SD.
    * The model is trained for ``n_epochs`` epochs with class-weighted
      cross-entropy.
    * The weights with the lowest validation loss are written to
      ``<model_save_dir>/<model_save_name>_<seed>.pth``. They are reloaded
      before the test split is scored.

    The per-seed test accuracies and their mean/SD are printed at the end.

    Args:
        dataset_path: Path of the HDF5 file for the selected dataset.
        model_save_dir: Checkpoint directory; created if it does not exist.
        model_save_name: Checkpoint base name; ``_<seed>.pth`` is appended.
        dataset_index: Row of the dataset table below
            (0 = chb, 1 = helsinki, 2 = nch, 3 = sienna). It decides the name
            used for the channel lookup and the ``scale_to_nch`` flag.
        nch_dataset_path: NCH HDF5 file, forwarded to ``load_patientwise_file``.
    """
    seeds = [42, 43, 44, 45, 46]
    n_epochs = 10
    patience = 25  # NOTE(review): >= n_epochs, so early stopping can never trigger.
    channel_name = 'F3-C3'
    batch_size = 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (dataset name, scale_to_nch flag passed to load_patientwise_file)
    datasets_to_merge = [('chb', True), ('helsinki', True), ('nch', False), ('sienna', True)]
    datasets_to_merge = datasets_to_merge[dataset_index:dataset_index + 1]

    os.makedirs(model_save_dir, exist_ok=True)
    all_combined_accuracies = []

    for seed in seeds:
        set_seed(seed)
        best_val_loss = float("inf")
        epochs_no_improve = 0
        checkpoint_saved = False

        # Keep the per-seed name in its own variable. Reassigning
        # model_save_name here would compound suffixes across seeds
        # (e.g. "helsinki_42.pth_43.pth").
        checkpoint_path = f"{model_save_dir}/{model_save_name}_{seed}.pth"

        all_X_train, all_y_train = [], []
        all_X_val, all_y_val = [], []
        all_X_test, all_y_test = [], []
        train_names = []  # dataset name of every training sample, for the balanced loader
        all_data = []

        for dataset_name, scale_to_nch in datasets_to_merge:
            print('Dealing with:', dataset_name)
            channel_index = get_channel_index_for_dataset(dataset_name, channel_name)
            print(f"Loading {dataset_name} dataset and selecting {channel_name} (index: {channel_index})...")

            X_sub, y_sub = load_patientwise_file(str(dataset_path), channel_index, scale_to_nch, nch_dataset_path)
            (X_train, y_train), (X_val, y_val), (X_test, y_test) = stratified_train_val_test_split(
                X_sub, y_sub, random_state=seed
            )

            all_X_train.append(X_train)
            all_y_train.append(y_train)
            all_X_val.append(X_val)
            all_y_val.append(y_val)
            all_X_test.append(X_test)
            all_y_test.append(y_test)
            all_data.extend([X_train, X_val, X_test])
            train_names.extend([dataset_name] * len(X_train))

        X_train = np.concatenate(all_X_train, axis=0)
        y_train = np.concatenate(all_y_train, axis=0)
        X_val = np.concatenate(all_X_val, axis=0)
        y_val = np.concatenate(all_y_val, axis=0)
        X_test = np.concatenate(all_X_test, axis=0)
        y_test = np.concatenate(all_y_test, axis=0)

        # NOTE(review): the mean/SD also include the val and test splits, so test
        # statistics leak into preprocessing. Kept as-is to preserve results.
        all_data = np.concatenate(all_data, axis=0)
        global_mean, global_sd = np.mean(all_data, axis=None), np.std(all_data, axis=None)
        X_train = standardize_data(X_train, global_mean, global_sd)
        X_val = standardize_data(X_val, global_mean, global_sd)
        X_test = standardize_data(X_test, global_mean, global_sd)

        train_names = np.array(train_names)

        unique_labels = np.unique(y_train)
        class_weights = compute_class_weight(class_weight='balanced', classes=unique_labels, y=y_train)
        class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

        train_dataloader = get_balanced_dataloader(X_train, y_train, train_names, batch_size, shuffle=True)
        val_dataloader = get_dataloader(X_val, y_val, batch_size, shuffle=False)
        test_dataloader = get_dataloader(X_test, y_test, batch_size, shuffle=False)

        sleep_model = TinySleepNet(num_classes=2, Fs=12, kernel_size=4).to(device)

        print("Training from scratch...")

        loss = nn.CrossEntropyLoss(weight=class_weights_tensor)
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, sleep_model.parameters()), lr=4e-5, weight_decay=1e-6
        )

        for epoch in range(n_epochs):
            train_loss, train_acc = train(sleep_model, device, train_dataloader, loss, optimizer)
            val_loss, val_acc = validate(sleep_model, device, val_dataloader, loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                print("best model saved...")
                torch.save(sleep_model.state_dict(), checkpoint_path)
                checkpoint_saved = True
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch+1}. No improvement for {patience} epochs.")
                break

            print(f"Epoch {epoch+1}/{n_epochs}\nTrain Loss: {train_loss} | Train Accuracy: {train_acc}")
            print(f"Val Loss: {val_loss} | Val Accuracy: {val_acc}")
            print("-----------------------------")

        # Score the saved (lowest val-loss) weights rather than the last epoch's,
        # so the reported accuracy describes the checkpoint written to disk.
        if checkpoint_saved:
            sleep_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

        # Other cells unpack test() into 4 values (loss, acc, f1, auroc). Keep only
        # the first two so either return shape works here.
        test_loss, test_acc = test(sleep_model, device, test_dataloader, loss)[:2]
        print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
        all_combined_accuracies.append(test_acc)

        print(f"SEED = {seed} DONE")
        print("===========")

    all_combined_accuracies = np.array(all_combined_accuracies)
    print(all_combined_accuracies)
    print(f"Mean Accuracy from {len(seeds)} runs:", np.mean(all_combined_accuracies))
    print(f"SD of Accuracy from {len(seeds)} runs:", np.std(all_combined_accuracies))
    

# Train on Helsinki alone: row 1 of main's dataset table, which has
# scale_to_nch=True, so the NCH file is handed over as the scaling reference.
path = '../DATA/helsinki_patientwise.h5'
nch_dataset_path = '../DATA/nch_patientwise.h5'
model_save_dir = 'saved_models_local'
model_save_name = 'helsinki'
dataset_index = 1

main(
    dataset_path=path,
    model_save_dir=model_save_dir,
    model_save_name=model_save_name,
    dataset_index=dataset_index,
    nch_dataset_path=nch_dataset_path,
)

Dealing with: helsinki
Loading helsinki dataset and selecting F3-C3 (index: 2)...
Training from scratch...
best model saved...
Epoch 1/10
Train Loss: 0.5556192168822656 | Train Accuracy: 83.16203143893591
Val Loss: 0.34800218160335833 | Val Accuracy: 90.20556227327691
-----------------------------
best model saved...
Epoch 2/10
Train Loss: 0.2561234768050221 | Train Accuracy: 89.20798065296252
Val Loss: 0.2529312876554636 | Val Accuracy: 90.44740024183797
-----------------------------
best model saved...
Epoch 3/10
Train Loss: 0.23854990367992565 | Train Accuracy: 89.29866989117292
Val Loss: 0.23232682622396028 | Val Accuracy: 91.17291414752115
-----------------------------
best model saved...
Epoch 4/10
Train Loss: 0.22188698493230802 | Train Accuracy: 90.43228536880291
Val Loss: 0.22213270114018366 | Val Accuracy: 90.44740024183797
-----------------------------
best model saved...
Epoch 5/10
Train Loss: 0.2126688057413468 | Train Accuracy: 90.85550181378477
Val Loss: 0.21282552870420

In [10]:
from utils.utils import *
import h5py
import numpy as np



# Query the NCH file with nch_min_max. The 7 is presumably a channel index and
# the pair a value range — TODO confirm against utils.utils.nch_min_max.
min_val, max_val = nch_min_max(
    7,
    '../DATA/nch_patientwise.h5',
)

print(min_val, max_val)

## Train All

In [None]:
import h5py
import os
from utils.datasets_statistics import get_channel_index_for_dataset
from sklearn.utils.class_weight import compute_class_weight
from utils.utils import *
from utils.dataloader import get_dataloader, get_balanced_dataloader
from utils.model_training import *
from utils.tiny_sleep_net import TinySleepNet
import torch
import torch.nn as nn
from collections import Counter
import random
import numpy as np 


def main(chb, helsinki, nch, sienna, model_save_dir, model_save_name):
    """Train one TinySleepNet on all four datasets combined, once per seed.

    For every seed, each dataset is processed as follows:

    * It is loaded (channel F3-C3) and split with
      ``stratified_train_val_test_split``.
    * The splits are concatenated across datasets and standardized with one
      global mean/SD.
    * The model is trained for ``n_epochs`` epochs.
    * The lowest-val-loss weights are written to
      ``<model_save_dir>/<model_save_name>_<seed>.pth``. They are reloaded
      before the combined test split is scored.

    Per-seed test accuracies and their mean/SD are printed.

    Args:
        chb, helsinki, nch, sienna: HDF5 paths of the four datasets. ``nch`` is
            also forwarded to ``load_patientwise_file`` for every dataset.
        model_save_dir: Checkpoint directory; created if it does not exist.
        model_save_name: Checkpoint base name; ``_<seed>.pth`` is appended.
    """
    seeds = [42, 43, 44, 45, 46]
    n_epochs = 10
    patience = 25  # NOTE(review): >= n_epochs, so early stopping can never trigger.
    channel_name = 'F3-C3'
    batch_size = 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (dataset name, scale_to_nch flag passed to load_patientwise_file), aligned with dataset_paths
    datasets_to_merge = [('chb', True), ('helsinki', True), ('nch', False), ('sienna', True)]
    dataset_paths = [chb, helsinki, nch, sienna]

    os.makedirs(model_save_dir, exist_ok=True)
    all_combined_accuracies = []

    for seed in seeds:
        set_seed(seed)
        best_val_loss = float("inf")
        epochs_no_improve = 0
        checkpoint_saved = False

        # Keep the per-seed name in its own variable. Reassigning
        # model_save_name here would compound suffixes across seeds
        # (e.g. "train_all_42.pth_43.pth").
        checkpoint_path = f"{model_save_dir}/{model_save_name}_{seed}.pth"

        all_X_train, all_y_train = [], []
        all_X_val, all_y_val = [], []
        all_X_test, all_y_test = [], []
        train_names = []  # dataset name of every training sample, for the balanced loader
        all_data = []

        for dataset_path, (dataset_name, scale_to_nch) in zip(dataset_paths, datasets_to_merge):
            print('Dealing with:', dataset_name)
            channel_index = get_channel_index_for_dataset(dataset_name, channel_name)
            print(f"Loading {dataset_name} dataset and selecting {channel_name} (index: {channel_index})...")

            X_sub, y_sub = load_patientwise_file(str(dataset_path), channel_index, scale_to_nch, nch)
            (X_train, y_train), (X_val, y_val), (X_test, y_test) = stratified_train_val_test_split(
                X_sub, y_sub, random_state=seed
            )

            all_X_train.append(X_train)
            all_y_train.append(y_train)
            all_X_val.append(X_val)
            all_y_val.append(y_val)
            all_X_test.append(X_test)
            all_y_test.append(y_test)
            all_data.extend([X_train, X_val, X_test])
            train_names.extend([dataset_name] * len(X_train))

        X_train = np.concatenate(all_X_train, axis=0)
        y_train = np.concatenate(all_y_train, axis=0)
        X_val = np.concatenate(all_X_val, axis=0)
        y_val = np.concatenate(all_y_val, axis=0)
        X_test = np.concatenate(all_X_test, axis=0)
        y_test = np.concatenate(all_y_test, axis=0)

        # NOTE(review): the mean/SD also include the val and test splits, so test
        # statistics leak into preprocessing. The combined-evaluation cell
        # recomputes the statistics the same way. Change both together if
        # train-only statistics are wanted.
        all_data = np.concatenate(all_data, axis=0)
        global_mean, global_sd = np.mean(all_data, axis=None), np.std(all_data, axis=None)
        X_train = standardize_data(X_train, global_mean, global_sd)
        X_val = standardize_data(X_val, global_mean, global_sd)
        X_test = standardize_data(X_test, global_mean, global_sd)

        train_names = np.array(train_names)

        unique_labels = np.unique(y_train)
        class_weights = compute_class_weight(class_weight='balanced', classes=unique_labels, y=y_train)
        class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

        train_dataloader = get_balanced_dataloader(X_train, y_train, train_names, batch_size, shuffle=True)
        val_dataloader = get_dataloader(X_val, y_val, batch_size, shuffle=False)
        test_dataloader = get_dataloader(X_test, y_test, batch_size, shuffle=False)

        sleep_model = TinySleepNet(num_classes=2, Fs=12, kernel_size=4).to(device)

        print("Training from scratch...")

        loss = nn.CrossEntropyLoss(weight=class_weights_tensor)
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, sleep_model.parameters()), lr=4e-5, weight_decay=1e-6
        )

        for epoch in range(n_epochs):
            train_loss, train_acc = train(sleep_model, device, train_dataloader, loss, optimizer)
            val_loss, val_acc = validate(sleep_model, device, val_dataloader, loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                print("best model saved...")
                torch.save(sleep_model.state_dict(), checkpoint_path)
                checkpoint_saved = True
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch+1}. No improvement for {patience} epochs.")
                break

            print(f"Epoch {epoch+1}/{n_epochs}\nTrain Loss: {train_loss} | Train Accuracy: {train_acc}")
            print(f"Val Loss: {val_loss} | Val Accuracy: {val_acc}")
            print("-----------------------------")

        # Score the saved (lowest val-loss) weights rather than the last epoch's,
        # so this accuracy describes the checkpoint the evaluation cells load.
        if checkpoint_saved:
            sleep_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

        # The combined-evaluation cell unpacks test() into 4 values
        # (loss, acc, f1, auroc). Keep only the first two so either shape works.
        test_loss, test_acc = test(sleep_model, device, test_dataloader, loss)[:2]
        print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")
        all_combined_accuracies.append(test_acc)

        print(f"SEED = {seed} DONE")
        print("===========")

    all_combined_accuracies = np.array(all_combined_accuracies)
    print(all_combined_accuracies)
    print(f"Mean Accuracy from {len(seeds)} runs:", np.mean(all_combined_accuracies))
    print(f"SD of Accuracy from {len(seeds)} runs:", np.std(all_combined_accuracies))
    

# Dataset files in the order main() expects: chb, helsinki, nch, sienna.
chb_path = '../DATA/chb_patientwise.h5'
helsinki_path = '../DATA/helsinki_patientwise.h5'
nch_path = '../DATA/nch_patientwise.h5'
sienna_path = '../DATA/siena_patientwise.h5'

# Checkpoints land in <model_save_dir>/<model_save_name>_<seed>.pth.
model_save_dir = 'saved_models_local'
model_save_name = 'train_all'

main(
    chb=chb_path,
    helsinki=helsinki_path,
    nch=nch_path,
    sienna=sienna_path,
    model_save_dir=model_save_dir,
    model_save_name=model_save_name,
)

## Test Individual

In [None]:
import argparse
import h5py
import os
import logging
from utils.datasets_statistics import get_channel_index_for_dataset
from sklearn.utils.class_weight import compute_class_weight
from utils.utils import *
from utils.dataloader import get_dataloader, get_balanced_dataloader
from utils.model_training import *
from utils.tiny_sleep_net import TinySleepNet
import torch
import torch.nn as nn
import numpy as np



def setup_logger(log_path):
    """Send root-logger INFO output to both the console and ``log_path``.

    Handlers already attached to the root logger are removed *and closed*
    first. Re-running the cell therefore neither duplicates log lines nor
    leaks the previous log file's open handle. Both new handlers emit only the
    bare message (``'%(message)s'``).

    Args:
        log_path: File to log to (FileHandler's default append mode).
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Same cleanup logging.basicConfig(force=True) performs: detach, then close.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('%(message)s')

    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)

    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)



all_test_accuracies = []  # NOTE(review): main() accumulates locally now; kept so other cells referencing it still resolve.
batch_size = 64


def main(chb, helsinki, nch, sienna, test_index, use_index, model_save_dir, model_save_name):
    """Score each per-seed checkpoint on every recording of one dataset.

    The dataset chosen by ``test_index`` is loaded in full, with no
    train/val/test split, and standardized with its own mean/SD. It is then
    evaluated with ``<model_save_dir>/<model_save_name>_<seed>.pth`` for every
    seed. The per-seed accuracy and the mean/SD over seeds are printed.

    Args:
        chb, helsinki, nch, sienna: HDF5 paths of the four datasets. ``nch`` is
            also forwarded to ``load_patientwise_file``, as the training cells do.
        test_index: Dataset to evaluate (0 = chb, 1 = helsinki, 2 = nch, 3 = sienna).
        use_index: Same indexing. It is only printed as "LOAD:" in the summary;
            the checkpoint actually loaded is selected by ``model_save_name``.
        model_save_dir: Directory holding the checkpoints.
        model_save_name: Checkpoint base name; ``_<seed>.pth`` is appended.
    """
    seeds = [42, 43, 44, 45, 46]
    channel_name = 'F3-C3'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    datasets_to_merge = [('chb', True), ('helsinki', True), ('nch', False), ('sienna', True)]
    dataset_paths = [chb, helsinki, nch, sienna]

    TEST_ON = datasets_to_merge[test_index]
    USE = datasets_to_merge[use_index]
    # Entries are (name, scale_to_nch) tuples: compare/load by the name, not the tuple.
    test_dataset_name, scale_to_nch = TEST_ON

    # Nothing below depends on the seed, so load and standardize only once.
    print('Testing On:', test_dataset_name)
    channel_index = get_channel_index_for_dataset(test_dataset_name, channel_name)
    X_sub, y_sub = load_patientwise_file(f'{dataset_paths[test_index]}', channel_index, scale_to_nch, nch)

    # NOTE(review): standardized with this dataset's own statistics, whereas the
    # training cells use the mean/SD of all training-time data — confirm intended.
    test_mean = np.mean(X_sub, axis=None)
    test_sd = np.std(X_sub, axis=None)
    standardized_test_data = standardize_data(X_sub, test_mean, test_sd)
    test_dataloader = get_dataloader(standardized_test_data, y_sub, batch_size, shuffle=False)

    loss = nn.CrossEntropyLoss()
    # Local accumulator. Appending to the module-level list and later
    # assigning the same name inside this function made it local and raised
    # UnboundLocalError.
    test_accuracies = []

    for seed in seeds:
        print("SEED = ", seed)
        set_seed(seed)

        sleep_model = TinySleepNet(num_classes=2, Fs=12, kernel_size=4).to(device)
        # Build the path per seed without overwriting model_save_name.
        checkpoint_path = f'{model_save_dir}/{model_save_name}_{seed}.pth'
        sleep_model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

        print("Testing...")
        # test() is unpacked into 2 or 4 values in different cells; keep (loss, acc).
        test_loss, test_acc = test(sleep_model, device, test_dataloader, loss)[:2]
        print(test_acc)
        test_accuracies.append(test_acc)
        print("******************")

    print("***FINAL RESULTS:****")
    print("LOAD:", USE)
    print("EVALUATE:", TEST_ON)
    test_accuracies = np.array(test_accuracies)
    print("Mean:", np.mean(test_accuracies))
    print("Std:", np.std(test_accuracies))

# Dataset files in the order main() expects: chb, helsinki, nch, sienna.
chb_path = '../DATA/chb_patientwise.h5'
helsinki_path = '../DATA/helsinki_patientwise.h5'
nch_path = '../DATA/nch_patientwise.h5'
sienna_path = '../DATA/siena_patientwise.h5'
model_save_dir = 'saved_models_local'
model_save_name = 'train_all'
test_index = 2  # 0 = chb, 1 = helsinki, 2 = nch, 3 = sienna
use_index = 1
# Pass the variables above (the previous call hard-coded 2, 1 and was missing a comma).
main(chb_path, helsinki_path, nch_path, sienna_path, test_index, use_index, model_save_dir, model_save_name)

## Test Combined

In [None]:
import argparse
import os
import logging
import torch
import torch.nn as nn
import numpy as np

from utils.datasets_statistics import get_channel_index_for_dataset
from utils.utils import *
from utils.dataloader import get_dataloader
from utils.model_training import test
from utils.tiny_sleep_net import TinySleepNet


def setup_logger(log_path):
    """Send root-logger INFO output to both the console and ``log_path``.

    Handlers already attached to the root logger are removed *and closed*
    first. Re-running the cell therefore neither duplicates log lines nor
    leaks the previous log file's open handle. Both new handlers emit only the
    bare message (``'%(message)s'``).

    Args:
        log_path: File to log to (FileHandler's default append mode).
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Same cleanup logging.basicConfig(force=True) performs: detach, then close.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('%(message)s')

    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)

    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)


def main(chb, helsinki, nch, sienna, test_index, model_save_dir, model_save_name, log_path):
    """Evaluate the per-seed combined-training checkpoints on one dataset's test split.

    For every seed:

    * All four datasets are reloaded and split the way the training cell does,
      using ``stratified_train_val_test_split`` with ``random_state=seed``.
    * The global mean/SD is recomputed over every split of every dataset.
    * The test split of the dataset chosen by ``test_index`` is standardized
      with that mean/SD.
    * The test split is scored with ``<model_save_dir>/<model_save_name>_<seed>.pth``.

    Accuracy, F1 and AUROC are logged per seed, followed by their mean/SD.

    Args:
        chb, helsinki, nch, sienna: HDF5 paths of the four datasets. ``nch`` is
            also forwarded to ``load_patientwise_file``, matching training.
        test_index: Dataset to evaluate (0 = chb, 1 = helsinki, 2 = nch, 3 = sienna).
        model_save_dir: Directory holding the checkpoints.
        model_save_name: Checkpoint base name; ``_<seed>.pth`` is appended.
        log_path: Log file passed to ``setup_logger``.
    """
    setup_logger(log_path)

    seeds = [42, 43, 44, 45, 46]
    batch_size = 64
    channel_name = 'F3-C3'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (dataset name, scale_to_nch flag), aligned with dataset_paths
    datasets_to_merge = [('chb', True), ('helsinki', True), ('nch', False), ('sienna', True)]
    dataset_paths = [chb, helsinki, nch, sienna]
    test_dataset_name = datasets_to_merge[test_index][0]

    all_combined_accuracies, all_combined_f1s, all_combined_aurocs = [], [], []

    for seed in seeds:
        logging.info(f"\n=== SEED: {seed} ===")
        set_seed(seed)

        all_examples = []  # every split of every dataset, for the global statistics

        for dataset_path, (dataset_name, scale_to_nch) in zip(dataset_paths, datasets_to_merge):
            channel_index = get_channel_index_for_dataset(dataset_name, channel_name)
            # Forward the NCH file exactly as the training cell does, so the NCH
            # rescaling matches what the model was trained on.
            X_sub, y_sub = load_patientwise_file(dataset_path, channel_index, scale_to_nch, nch)
            (X_train, _), (X_val, _), (X_test, y_test) = stratified_train_val_test_split(
                X_sub, y_sub, random_state=seed
            )

            if dataset_name == test_dataset_name:
                save_test_data = X_test
                save_test_labels = y_test

            all_examples.extend([X_train, X_val, X_test])

        all_examples = np.concatenate(all_examples, axis=0)
        global_mean, global_sd = np.mean(all_examples), np.std(all_examples)

        logging.info(f"Standardizing dataset {test_dataset_name}")
        logging.info(f"Global mean: {global_mean:.4f}, Global std: {global_sd:.4f}")

        # Same helper the training cells use, so evaluation preprocessing cannot drift.
        standardized_test_data = standardize_data(save_test_data, global_mean, global_sd)
        test_dataloader = get_dataloader(standardized_test_data, save_test_labels, batch_size, shuffle=False)

        sleep_model = TinySleepNet(num_classes=2, Fs=12, kernel_size=4).to(device)
        model_path = os.path.join(model_save_dir, f"{model_save_name}_{seed}.pth")
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        sleep_model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

        loss = nn.CrossEntropyLoss()
        logging.info("Testing model...")
        test_loss, test_acc, f1, auroc = test(sleep_model, device, test_dataloader, loss)

        logging.info(f"Test Accuracy: {test_acc:.2f} | F1 Score: {f1:.2f} | AUROC: {auroc:.2f}")
        all_combined_accuracies.append(test_acc)
        all_combined_f1s.append(f1)
        all_combined_aurocs.append(auroc)

    logging.info("\n********** FINAL RESULTS ************")
    logging.info(f"Evaluated on: {test_dataset_name}")
    logging.info(f"Accuracy Mean: {np.mean(all_combined_accuracies):.2f} | Std: {np.std(all_combined_accuracies):.2f}")
    logging.info(f"F1 Score Mean: {np.mean(all_combined_f1s):.2f} | Std: {np.std(all_combined_f1s):.2f}")
    logging.info(f"AUROC Mean: {np.mean(all_combined_aurocs):.2f} | Std: {np.std(all_combined_aurocs):.2f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate pretrained TinySleepNet models on different EEG datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('--chb', type=str, default='../DATA/chb_patientwise.h5', help='Path to CHB dataset')
    parser.add_argument('--helsinki', type=str, default='../DATA/helsinki_patientwise.h5', help='Path to Helsinki dataset')
    parser.add_argument('--nch', type=str, default='../DATA/nch_patientwise.h5', help='Path to NCH dataset')
    parser.add_argument('--sienna', type=str, default='../DATA/siena_patientwise.h5', help='Path to Sienna dataset')
    parser.add_argument('--test_index', type=int, choices=[0, 1, 2, 3], default=2,
                        help='Index of dataset to test on (0 = chb, 1 = helsinki, 2 = nch, 3 = sienna)')
    parser.add_argument('--model_save_dir', type=str, default='saved_models_local', help='Directory with saved models')
    parser.add_argument('--model_save_name', type=str, default='train_all', help='Base name for saved models (seed will be appended)')
    parser.add_argument('--log_path', type=str, default='cross_eval_log.txt', help='Path to log file')

    # When this cell runs inside a notebook kernel it is also __main__, and
    # sys.argv holds the kernel's own launch options (e.g. "-f <connection file>").
    # parse_args() would exit on those. Parse the known flags and report the rest.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print("Ignoring unrecognized arguments:", unknown_args)

    main(args.chb, args.helsinki, args.nch, args.sienna, args.test_index,
         args.model_save_dir, args.model_save_name, args.log_path)
