In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import shutil
import os

In [2]:
# CNN classifier over per-channel STFT "images".
# Hidden-layer widths and the number of classes are constructor arguments;
# the defaults reproduce the previously hard-coded 120 / 84 / 2.

class CNN(nn.Module):
    """Two conv/pool stages followed by a three-layer fully connected head.

    Forward pass: BatchNorm -> conv1 -> ReLU -> 2x2 max-pool -> Dropout2d
    -> BatchNorm -> conv2 -> ReLU -> 2x2 max-pool -> flatten
    -> fc1 -> ReLU -> Dropout -> fc2 -> ReLU -> fc3 (raw logits).

    Args:
        input_size: Per-sample input shape ``(channels, height, width)``,
            e.g. ``batch.shape[1:]``.
        conv_channels: Output channels of ``conv1`` and ``conv2``.
        kernel: Kernel sizes of ``conv1`` and ``conv2`` (square, stride 1,
            no padding).
        dropout: Probability used by both the 2-D dropout after the first
            stage and the 1-D dropout after ``fc1``.
        hidden: Output widths of ``fc1`` and ``fc2``.
        num_classes: Number of logits produced by ``fc3``.

    Raises:
        ValueError: if ``input_size`` is too small for the two conv/pool stages.
    """

    def __init__(self, input_size, conv_channels=(3, 10), kernel=(5, 5), dropout=0.0,
                 hidden=(120, 84), num_classes=2):
        super().__init__()
        out_h = self._conv_pool_len(self._conv_pool_len(input_size[1], kernel[0]), kernel[1])
        out_w = self._conv_pool_len(self._conv_pool_len(input_size[2], kernel[0]), kernel[1])
        # Check each side separately: two negative side lengths would multiply
        # to a positive (and meaningless) flattened size.
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small for kernels {tuple(kernel)}: "
                f"feature map after the second conv/pool stage would be {out_h}x{out_w}"
            )
        num_flattened = out_h * out_w * conv_channels[1]

        # Modules are created in the same order as before so that, for a given
        # torch seed, parameter initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(input_size[0], conv_channels[0], kernel[0])
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(conv_channels[0], conv_channels[1], kernel[1])
        self.fc1 = nn.Linear(num_flattened, hidden[0])
        self.fc2 = nn.Linear(hidden[0], hidden[1])
        self.fc3 = nn.Linear(hidden[1], num_classes)
        self.dropout1d = nn.Dropout(p=dropout)
        self.dropout2d = nn.Dropout2d(p=dropout)
        self.batchnorm1 = nn.BatchNorm2d(input_size[0])
        self.batchnorm2 = nn.BatchNorm2d(conv_channels[0])

    @staticmethod
    def _conv_pool_len(length, kernel):
        """Spatial length after an unpadded stride-1 conv followed by a 2x2/stride-2 max-pool."""
        return (length - kernel + 1) // 2

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for a ``(batch, C, H, W)`` input."""
        x = self.batchnorm1(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.dropout2d(x)
        x = self.batchnorm2(x)
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = self.dropout1d(x)
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [3]:
from scipy.signal import stft
import warnings
def pre_process(data, fs, dtype=np.float64, nperseg=256, magnitude=False):
    """Convert windowed multi-channel signals into per-channel STFT images.

    Args:
        data: 3-D array. It is transposed with ``(0, 2, 1)`` and the STFT is
            taken along the last axis, so axis 1 of ``data`` is time and axis 2
            is iterated over (presumably channels — confirm against the loader).
        fs: Sampling frequency forwarded to ``scipy.signal.stft``.
        dtype: dtype of the returned array.
        nperseg: STFT window length (scipy's default overlap ``nperseg // 2``
            is kept).
        magnitude: If False (default, the historical behaviour) keep only the
            real part of the complex STFT; if True return ``abs(STFT)``.

    Returns:
        Array of shape ``(data.shape[0], data.shape[2], n_freqs, n_frames)``
        with the requested ``dtype``.
    """
    signals = np.transpose(data, (0, 2, 1))
    n_windows, n_channels, n_samples = signals.shape

    # Ask scipy for the output grid instead of predicting it: the previous
    # `n_samples // hop + 1` frame count was off by one (and the assignment
    # crashed) whenever n_samples was not a multiple of the hop.
    _, _, probe = stft(np.zeros(n_samples), fs=fs, nperseg=nperseg)
    Zxx = np.zeros((n_windows, n_channels) + probe.shape, dtype=dtype)

    # NOTE(review): the complex STFT used to be assigned into a real array with
    # warnings silenced, which silently discarded the imaginary part. That is
    # now explicit (np.real); `magnitude=True` gives the conventional |STFT|.
    # Warnings are no longer suppressed, so e.g. overflow when casting to a
    # narrow dtype such as float16 may now be reported.
    to_real = np.abs if magnitude else np.real
    for ch in range(n_channels):  # one channel at a time keeps the complex intermediate small
        _, _, spec = stft(signals[:, ch, :], fs=fs, nperseg=nperseg)
        Zxx[:, ch] = to_real(spec)

    return Zxx

In [4]:
from swec_utils import training_model
def grid_search(epochs, device, weights, train_loader, val_loader, dtype, kernel_sizes, n_filters, dropouts, save_path, plot_results=False, verbose=False):
    """Train one CNN per (kernel size, filter pair, dropout) combination.

    Args:
        epochs, device, weights, train_loader, val_loader, dtype, plot_results,
            verbose: Forwarded to ``swec_utils.training_model``; ``dtype`` (a
            numpy dtype) also selects the torch dtype the model is cast to.
        kernel_sizes: Square kernel sizes to try (used for both conv layers).
        n_filters: ``(conv1_channels, conv2_channels)`` pairs to try.
        dropouts: Dropout probabilities to try.
        save_path: Filename prefix for saved state dicts; falsy disables saving.
            Each model is written to
            ``save_path + 'CNN_k{k}_f{f0}_{f1}_d{d without dots}'``.

    Returns:
        ``(t_res, v_res)`` returned by ``training_model`` for the LAST
        configuration in the grid.

    Raises:
        ValueError: if the grid is empty.
    """
    from itertools import product

    # Same iteration order as nested loops over kernel -> filters -> dropout.
    grid = list(product(kernel_sizes, n_filters, dropouts))
    if not grid:
        raise ValueError("grid_search needs at least one kernel size, filter pair and dropout value")

    # Imported here so this cell does not depend on a later cell's import.
    from swec_utils import numpy_to_torch_dtype_dict

    # Loop-invariant: every configuration sees the same input shape and dtype.
    input_size = next(iter(train_loader))[0].shape[1:]
    torch_dtype = numpy_to_torch_dtype_dict[dtype]

    t_res = v_res = None
    for k, f, d in grid:
        if verbose:
            print(f"Training model with kernel size = {k}, filter_sizes = {f}, dropout = {d}.")
        model = CNN(input_size=input_size, conv_channels=f, kernel=(k, k), dropout=d)
        model.to(torch_dtype)
        t_res, v_res = training_model(model, epochs, device, weights, train_loader, val_loader, dtype, plot_results=plot_results, verbose=verbose)
        if save_path:
            print("Saving model")
            # Build every file name from the unchanged prefix; mutating
            # save_path here would append all earlier names to later ones.
            model_path = save_path + 'CNN_k' + str(k) + '_f' + str(f[0]) + '_' + str(f[1]) + '_d' + str(d).replace('.', '')
            torch.save(model.state_dict(), model_path)

    return t_res, v_res

In [5]:
import pandas as pd
from swec_utils import Results
def write_results_to_excel(writer, sheet_name, tr_res, vl_res):
    """Write final-epoch training and validation metrics to one sheet.

    Args:
        writer: Target passed to ``DataFrame.to_excel`` (e.g. an open
            ``pd.ExcelWriter``).
        sheet_name: Name of the sheet to write.
        tr_res, vl_res: Result objects exposing sequences ``sen``, ``spe``,
            ``auc``, ``f1`` and ``pr_auc``; only the last entry of each is written.
    """
    # Column header -> attribute name on the result objects.
    metrics = {
        'Sensitivity': 'sen',
        'Specificity': 'spe',
        'ROC AUC': 'auc',
        'F1 Score': 'f1',
        'PR AUC': 'pr_auc',
    }
    df = pd.DataFrame(
        data={col: [getattr(tr_res, attr)[-1], getattr(vl_res, attr)[-1]] for col, attr in metrics.items()},
        index=['Training', 'Validation'],
    )
    # Keyword argument: pandas >= 2.1 warns that arguments after excel_writer
    # will become keyword-only.
    df.to_excel(writer, sheet_name=sheet_name)


In [6]:
from swec_utils import load_data, load_data_partitioned, STFTDataset, numpy_to_torch_dtype_dict
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# high level function to evaluate different data-based hyperparameters
def tune(subject_path, file_name, seq_duration, pil, sph, i_distance, partitioned=True, device='cpu', dtype=np.float64,
         epochs=50, kernel=5, filters=(5, 5), dropout=0.5):
    """Sweep data-level hyperparameters and log final metrics to an Excel workbook.

    For every combination of the four swept values the data are loaded, split
    into train/validation, STFT-transformed, and used to train one CNN
    configuration; the final metrics go to one sheet per combination.

    Args:
        subject_path: Path prefix passed to the swec_utils loaders.
        file_name: Output workbook path for ``pd.ExcelWriter``.
        seq_duration, pil, sph, i_distance: Iterables of values to sweep
            (printed as sequence duration / PIL / SPH / interictal distance).
            i_distance is apparently in seconds (it is compared with 24*3600).
        partitioned: Use ``load_data_partitioned`` (one held-out seizure)
            instead of a stratified random 80/20 split of ``load_data``.
        device: Device for the class-weight tensor; also passed to training.
        dtype: numpy dtype for the data, STFT output and model.
        epochs, kernel, filters, dropout: The single CNN configuration trained
            for every data setting (defaults match the previous constants).
    """
    from itertools import product

    torch_dtype = numpy_to_torch_dtype_dict[dtype]
    with pd.ExcelWriter(file_name) as writer:
        # Same order as nested loops n -> p -> s -> d.
        for n, p, s, d in product(seq_duration, pil, sph, i_distance):
            # Shorter sequences get larger batches; max() avoids a division by
            # zero for n < 5.
            batch_size = 512 // max(1, n // 5)
            # Passed to the loaders as `ds`: 10 at d == 0, decreasing by one per
            # day of d, floored at 1. Presumably it thins interictal segments
            # (author's note: "sketchy formula to get a reasonable number of
            # interictal segments").
            ds = 10 - min(9, d // (24 * 3600))
            print(f"Sequence duration: {n} | PIL: {p} | SPH: {s} | Interictal distance: {d}")
            # The held-out test split is still drawn (so the training subset is
            # unchanged) but is not transformed: it was never used here and the
            # STFT is the expensive step.
            if partitioned:
                tr_data, tr_labels, _ts_data, _ts_labels, fs = load_data_partitioned(subject_path, n, s, p, d, ds=ds, dtype=dtype, num_test_sz=1)
            else:
                data, labels, fs = load_data(subject_path, n, s, p, d, ds=ds, dtype=dtype)
                tr_data, _ts_data, tr_labels, _ts_labels = train_test_split(data, labels, test_size=0.2, stratify=labels, random_state=0)

            tr_data, vl_data, tr_labels, vl_labels = train_test_split(tr_data, tr_labels, test_size=0.2, stratify=tr_labels, random_state=0)
            tr_Zxx = pre_process(tr_data, fs, dtype)
            vl_Zxx = pre_process(vl_data, fs, dtype)

            weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.unique(tr_labels), y=tr_labels), dtype=torch_dtype, device=device)
            # NOTE(review): shuffle is off, so every epoch sees the batches in the
            # same order (train_test_split shuffled once). Consider shuffle=True
            # with a seeded generator.
            train_loader = torch.utils.data.DataLoader(STFTDataset(tr_Zxx, tr_labels, dtype=dtype), batch_size=batch_size)
            val_loader = torch.utils.data.DataLoader(STFTDataset(vl_Zxx, vl_labels, dtype=dtype), batch_size=batch_size)

            tr_res, vl_res = grid_search(epochs, device, weights, train_loader, val_loader, dtype, [kernel], [filters], [dropout], save_path='', plot_results=False)

            sheet_name = 'n' + str(n) + ' p' + str(p) + ' s' + str(s) + ' d' + str(d)
            write_results_to_excel(writer, sheet_name, tr_res, vl_res)




In [7]:
def copy_subject(src, dst):
    """Copy every regular file directly inside ``src`` into ``dst``.

    Sub-directories are skipped (no recursion). Each copied file's name is
    printed as it is copied.
    """
    for entry_name in os.listdir(src):
        source_file = os.path.join(src, entry_name)
        if not os.path.isfile(source_file):
            continue  # not a regular file: skip it
        shutil.copy(source_file, dst)
        print(entry_name)

# Stage a local copy of one subject's files under ./swec/<subject>.
# The machine-specific source root can be overridden with the SWEC_SRC_ROOT
# environment variable instead of editing this cell.
subject = 'ID12'
src_root = os.environ.get('SWEC_SRC_ROOT', 'D:/research/swec')
src = src_root + '/' + subject
dst = './swec/' + subject
data_path = dst + '/data.p'

# makedirs also creates './swec' on a fresh checkout, where os.mkdir would
# raise FileNotFoundError.
os.makedirs(dst, exist_ok=True)
# Copy only when data.p is absent and dst holds at most one entry
# (presumably treated as "not yet copied").
if not os.path.exists(data_path) and len(os.listdir(dst)) <= 1:
    copy_subject(src, dst)

In [8]:
# Run the data-level sweep for this subject; one Excel sheet per configuration.
subject_path = './swec/' + subject + '/' + subject
use_cuda = True
# Fall back to CPU when no GPU is present instead of failing later when
# tensors are moved to CUDA.
# NOTE(review): float16 on CPU may be slow or unsupported for some ops,
# depending on the torch version — confirm if running without a GPU.
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
tune(
    subject_path,
    'partitioned.xlsx',
    seq_duration=[5, 10, 30],
    pil=[600, 1200, 1800, 2400, 3600],
    sph=[0, 300, 600],
    i_distance=[0, 3600 * 24],
    partitioned=True,
    device=device,
    dtype=np.float16,
)
# Smaller grid for a quick check:
# tune(subject_path, 'test.xlsx', [30], [3600], [0, 300], [0, 3600*24], True, device, np.float16)

Sequence duration: 30 | PIL: 3600 | SPH: 0 | Interictal distance: 0
Training:
	Loss: 2.373046875
	Sensitivity: 0.9618055555555556
	Specificity: 0.9419986023759609
	AUC: 0.9935760249242954
	F1 Score: 0.9134377576257214
	PR AUC: 0.984520834504591
Validation:
	Loss: 6.89990234375
	Sensitivity: 0.5555555555555556
	Specificity: 0.9888268156424581
	AUC: 0.9579841713221601
	F1 Score: 0.7017543859649122
	PR AUC: 0.9089267699407824
Sequence duration: 30 | PIL: 3600 | SPH: 0 | Interictal distance: 86400
Training:
	Loss: 0.837890625
	Sensitivity: 0.9704861111111112
	Specificity: 0.9784615384615385
	AUC: 0.9981757478632478
	F1 Score: 0.9730200174064404
	PR AUC: 0.9979461472004871
Validation:
	Loss: 0.2551116943359375
	Sensitivity: 0.9652777777777778
	Specificity: 0.9877300613496932
	AUC: 0.9989349011588275
	F1 Score: 0.9754385964912281
	PR AUC: 0.9987976037868691
Sequence duration: 30 | PIL: 3600 | SPH: 300 | Interictal distance: 0
Training:
	Loss: 3.05078125
	Sensitivity: 0.9392361111111112
	Spec