In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
from dataclasses import dataclass
import utils
from torch_geometric_temporal import  DynamicGraphTemporalSignal,StaticGraphTemporalSignal, temporal_signal_split, DynamicGraphTemporalSignalBatch
import networkx as nx
from torch_geometric.utils import from_networkx
import scipy
import sklearn
import random
from tqdm import tqdm
from torch_geometric.data import Data
from torch_geometric.loader import DynamicBatchSampler,DataLoader
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.ops import sigmoid_focal_loss
from torch_geometric_temporal.nn.recurrent import DCRNN,  GConvGRU, A3TGCN, TGCN2, TGCN, A3TGCN2
from torch_geometric_temporal.nn.attention import STConv
from torchmetrics.classification import BinaryRecall,BinarySpecificity, AUROC, ROC
from torch_geometric.nn import global_mean_pool
import pytorch_lightning as pl
from torch_geometric.nn import GCNConv,BatchNorm,GATv2Conv
from sklearn.model_selection import KFold,StratifiedKFold, StratifiedShuffleSplit
from mne_features.univariate import compute_variance, compute_hjorth_complexity, compute_hjorth_mobility
import torchaudio
import mne_features
import torch_geometric
from imblearn.over_sampling import SMOTE
from collections import Counter
from librosa import zero_crossings
from scipy.signal import find_peaks, peak_prominences
from statistics import mean
from joblib import Parallel, delayed
import multiprocessing
import time

  from .autonotebook import tqdm as notebook_tqdm
2023-03-27 17:32:03.361234: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-27 17:32:03.667807: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-03-27 17:32:03.701808: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-03-27 17:32:03.701831: I tensorflow/compiler/x

In [2]:
torch_geometric.seed_everything(42)

In [3]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [4]:
# TODO think about using kwargs argument here to specify args for dataloader
@dataclass
class SeizureDataLoader:
    npy_dataset_path: Path
    event_tables_path: Path
    plv_values_path: Path
    loso_patient: str = None
    sampling_f: int = 256
    seizure_lookback: int = 600
    sample_timestep: int = 5
    inter_overlap: int = 0
    ictal_overlap: int = 0
    self_loops: bool = True
    balance: bool = True
    train_test_split: float = None
    fft: bool = False
    hjorth: bool = False
    teager_keiser: bool = False
    downsample: int = None
    buffer_time: int = 15
    batch_size: int = 32
    smote: bool = False
    """Class to prepare dataloaders for eeg seizure perdiction from stored files.

    Attributes:
        npy_dataset_path {Path} -- Path to folder with dataset preprocessed into .npy files.
        event_tables_path {Path} -- Path to folder with .csv files containing seizure events information for every patient.
        loso_patient {str} -- Name of patient to be selected for LOSO valdiation, specified in format "chb{patient_number}"",
        eg. "chb16". (default: {None}).
        samplin_f {int} -- Sampling frequency of the loaded eeg data. (default: {256}).
        seizure_lookback {int} -- Time horizon to sample pre-seizure data (length of period before seizure) in seconds. 
        (default: {600}).
        sample_timestep {int} -- Amounts of seconds analyzed in a single sample. (default: {5}).
        overlap {int} -- Amount of seconds overlap between samples. (default: {0}).
        self_loops {bool} -- Wheather to add self loops to nodes of the graph. (default: {True}).
        shuffle {bool} --  Wheather to shuffle training samples.


    """
    assert (fft and hjorth) == False, "When fft is True, hjorth should be False"

    def _get_event_tables(self, patient_name: str) -> tuple[dict, dict]:
        """Read events for given patient into start and stop times lists from .csv extracted files."""

        event_table_list = os.listdir(self.event_tables_path)
        patient_event_tables = [
            os.path.join(self.event_tables_path, ev_table)
            for ev_table in event_table_list
            if patient_name in ev_table
        ]
        patient_event_tables = sorted(patient_event_tables)
        patient_start_table = patient_event_tables[
            0
        ]  ## done terribly, but it has to be so for win/linux compat
        patient_stop_table = patient_event_tables[1]
        start_events_dict = pd.read_csv(patient_start_table).to_dict("index")
        stop_events_dict = pd.read_csv(patient_stop_table).to_dict("index")
        return start_events_dict, stop_events_dict

    def _get_recording_events(self, events_dict, recording) -> list[int]:
        """Read seizure times into list from event_dict"""
        recording_list = list(events_dict[recording + ".edf"].values())
        recording_events = [int(x) for x in recording_list if not np.isnan(x)]
        return recording_events

    def _get_graph(self, n_nodes: int) -> nx.Graph:
        """Creates Networx fully connected graph with self loops"""
        graph = nx.complete_graph(n_nodes)
        self_loops = [[node, node] for node in graph.nodes()]
        graph.add_edges_from(self_loops)
        return graph

    def _get_edge_weights_recording(self, plv_values: np.ndarray) -> np.ndarray:
        """Method that takes plv values for given recording and assigns them
        as edge attributes to a fc graph."""
        graph = self._get_graph(plv_values.shape[0])
        garph_dict = {}
        for edge in graph.edges():
            e_start, e_end = edge
            garph_dict[edge] = {"plv": plv_values[e_start, e_end]}
        nx.set_edge_attributes(graph, garph_dict)
        edge_weights = from_networkx(graph).plv.numpy()
        return edge_weights

    def _get_edges(self):
        """Method to assign edge attributes. Has to be called AFTER get_dataset() method."""
        graph = self._get_graph(self._features.shape[1])
        edges = np.expand_dims(from_networkx(graph).edge_index.numpy(), axis=0)
        edges_per_sample_train = np.repeat(
            edges, repeats=self._features.shape[0], axis=0
        )
        self._edges = torch.tensor(edges_per_sample_train)
        if self.loso_patient is not None:
            edges_per_sample_val = np.repeat(
                edges, repeats=self._val_features.shape[0], axis=0
            )
            self._val_edges = torch.tensor(edges_per_sample_val)

    def _array_to_tensor(self):
        """Method converting features, edges and weights to torch.tensors"""

        self._features = torch.tensor(self._features, dtype=torch.float32)
        self._labels = torch.tensor(self._labels)
        self._time_labels = torch.tensor(self._time_labels)
        self._edge_weights = torch.tensor(self._edge_weights)
        if self.loso_patient is not None:
            self._val_features = torch.tensor(self._val_features, dtype=torch.float32)
            self._val_labels = torch.tensor(self._val_labels)
            self._val_time_labels = torch.tensor(self._val_time_labels)
            self._val_edge_weights = torch.tensor(self._val_edge_weights)

    def _get_labels_count(self):
        labels, counts = np.unique(self._labels, return_counts=True)
        print(labels, counts)
        self._label_counts = {}
        for n, label in enumerate(labels):
            self._label_counts[int(label)] = counts[n]
        if self.loso_patient is not None:
            labels, counts = np.unique(self._val_labels, return_counts=True)
            self._val_label_counts = {}
            for n, label in enumerate(labels):
                self._val_label_counts[int(label)] = counts[n]

    def _perform_features_fft(self):
        self._features = torch.fft.rfft(self._features)
        if self.loso_patient is not None:
            self._val_features = torch.fft.rfft(self._val_features)

    def _downsample_features(self):
        resampler = torchaudio.transforms.Resample(self.sampling_f, self.downsample)
        self._features = resampler(self._features)
        if self.loso_patient is not None:
            self._val_features = resampler(self._val_features)

    def _calculate_hjorth_features(self, features):
        new_features = np.array(
            [
                np.concatenate(
                    [
                        np.expand_dims(compute_variance(feature), 1),
                        np.expand_dims(compute_hjorth_mobility(feature), 1),
                        np.expand_dims(compute_hjorth_complexity(feature), 1),
                        np.expand_dims([len(find_peaks(sig)[0]) for sig in feature], 1),
                        np.expand_dims(np.sum(zero_crossings(feature), axis=1), 1),
                        np.expand_dims(
                            [
                                peak_prominences(sig, find_peaks(sig)[0])[0].mean()
                                for sig in feature
                            ],
                            1,
                        ),
                    ],
                    axis=1,
                )
                for feature in features
            ]
        )
        return new_features

    def _features_to_data_list(self, features, edges, edge_weights, labels, time_label):
        data_list = [
            Data(
                x=features[i],
                edge_index=edges[i],
                edge_attr=edge_weights[i],
                y=labels[i],
                # time=time_label[i],
            )
            for i in range(len(features))
        ]
        return data_list

    def _split_data_list(self, data_list):
        class_labels = torch.tensor(
            [data.y.item() for data in data_list], dtype=torch.float32
        ).unsqueeze(1)
        patient_labels = torch.tensor(
            np.expand_dims(self._patient_number, 1), dtype=torch.float32
        )
        class_labels_patient_labels = torch.cat([class_labels, patient_labels], dim=1)
        splitter = StratifiedShuffleSplit(
            n_splits=1, test_size=self.train_test_split, random_state=42
        )
        train_indices, val_indices = next(
            splitter.split(data_list, class_labels_patient_labels)
        )
        self._indexes_to_later_delete = {"train": train_indices, "val": val_indices}
        data_list_train = [data_list[i] for i in train_indices]
        dataset_list_val = [data_list[i] for i in val_indices]
        return data_list_train, dataset_list_val

    def _initialize_dicts(self):
        self._features_dict = {}
        self._labels_dict = {}
        self._time_labels_dict = {}
        self._edge_weights_dict = {}
        self._patient_number_dict = {}
        if self.loso_patient:
            self._val_features_dict = {}
            self._val_labels_dict = {}
            self._val_time_labels_dict = {}
            self._val_edge_weights_dict = {}
            self._val_patient_number_dict = {}

    def _convert_dict_to_array(self):
        self._features = np.concatenate(
            [self._features_dict[key] for key in self._features_dict.keys()]
        )
        del self._features_dict
        self._labels = np.concatenate(
            [self._labels_dict[key] for key in self._labels_dict.keys()]
        )
        del self._labels_dict
        self._time_labels = np.concatenate(
            [self._time_labels_dict[key] for key in self._time_labels_dict.keys()]
        )
        del self._time_labels_dict
        self._edge_weights = np.concatenate(
            [self._edge_weights_dict[key] for key in self._edge_weights_dict.keys()]
        )
        del self._edge_weights_dict
        self._patient_number = np.concatenate(
            [self._patient_number_dict[key] for key in self._patient_number_dict.keys()]
        )
        del self._patient_number_dict
        if self.loso_patient:
            self._val_features = np.concatenate(
                [self._val_features_dict[key] for key in self._val_features_dict.keys()]
            )
            del self._val_features_dict
            self._val_labels = np.concatenate(
                [self._val_labels_dict[key] for key in self._val_labels_dict.keys()]
            )
            del self._val_labels_dict
            self._val_time_labels = np.concatenate(
                [
                    self._val_time_labels_dict[key]
                    for key in self._val_time_labels_dict.keys()
                ]
            )
            del self._val_time_labels_dict
            self._val_edge_weights = np.concatenate(
                [
                    self._val_edge_weights_dict[key]
                    for key in self._val_edge_weights_dict.keys()
                ]
            )
            del self._val_edge_weights_dict
            self._val_patient_number = np.concatenate(
                [
                    self._val_patient_number_dict[key]
                    for key in self._val_patient_number_dict.keys()
                ]
            )
            del self._val_patient_number_dict

    def _balance_classes(self):
        negative_label = self._label_counts[0]
        positive_label = self._label_counts[1]
        
        print(f"Number of negative samples pre removal {negative_label}")
        print(f"Number of positive samples pre removal {positive_label}")
        imbalance = negative_label - positive_label
        print(f"imbalance {imbalance}")
        negative_indices = np.where(self._labels == 0)[0]
        indices_to_discard = np.random.choice(
            negative_indices, size=imbalance, replace=False
        )

        self._features = np.delete(self._features, obj=indices_to_discard, axis=0)
        self._labels = np.delete(self._labels, obj=indices_to_discard, axis=0)
        self._time_labels = np.delete(self._time_labels, obj=indices_to_discard, axis=0)
        self._edge_weights = np.delete(
            self._edge_weights, obj=indices_to_discard, axis=0
        )
        self._patient_number = np.delete(self._patient_number, obj=indices_to_discard,axis=0)

    def _standardize_data(self, features, labels, loso_features=None):
        indexes = np.where(labels == 0)[0]  
        features_negative = features[indexes]
        channel_mean = features_negative.mean()
        channel_std = features_negative.std()
        # if features_negative.shape[0] == 1:
        #     channel_mean = features_negative.mean(2).squeeze()
        #     channel_std = features_negative.std(2).squeeze()
        # else:
        #    # print(features_negative[0].shape)
        #     channel_mean = features_negative.mean(axis=0).mean(1)
        #     channel_std = features_negative.std(axis=0).std(1)
        for i in range(features.shape[0]):
            for n in range(features.shape[1]):
                #        features[i,n,:] = (features[i,n,:] - channel_mean[n])/channel_std[n]
                features[i, n, :] = (features[i, n, :] - channel_mean) / channel_std
        if (
            loso_features is not None
        ):  ## standardize loso features with the same values as for training data
            for i in range(loso_features.shape[0]):
                for n in range(loso_features.shape[1]):
                    loso_features[i, n, :] = (
                        loso_features[i, n, :] - channel_mean
                    ) / channel_std

    def _min_max_scale(self, features, labels):
        indexes = np.where(labels == 0)[0]  ## changed from 0!
        features_negative = features[indexes]

        channel_min = features_negative.min(axis=0).min(1)
        channel_max = features_negative.max(axis=0).max(1)
        for i in range(features.shape[0]):
            for n in range(features.shape[1]):
                features[i, n, :] = (features[i, n, :] - channel_min[n]) / (
                    channel_max[n] - channel_min[n]
                )
                # features[i,n,:] = (features[i,n,:] - channel_min)/(channel_max - channel_min)

    def _apply_smote(self, features, labels):
        dim_1 = np.array(features).shape[0]
        dim_2 = np.array(features).shape[1]
        dim_3 = np.array(features).shape[2]

        new_dim = dim_1 * dim_2
        new_x_train = features.reshape(new_dim, dim_3)
        new_y_train = []
        for i in range(len(labels)):
            new_y_train.extend([labels[i]] * dim_2)

        new_y_train = np.array(new_y_train)

        # transform the dataset
        oversample = SMOTE(random_state=42)
        x_train, y_train = oversample.fit_resample(new_x_train, new_y_train)
        x_train_smote = x_train.reshape(int(x_train.shape[0] / dim_2), dim_2, dim_3)
        y_train_smote = []
        for i in range(int(x_train.shape[0] / dim_2)):
            # print(i)
            value_list = list(y_train.reshape(int(x_train.shape[0] / dim_2), dim_2)[i])
            # print(list(set(value_list)))
            y_train_smote.extend(list(set(value_list)))
            ## Check: if there is any different value in a list
            if len(set(value_list)) != 1:
                print(
                    "\n\n********* STOP: THERE IS SOMETHING WRONG IN TRAIN ******\n\n"
                )
        y_train_smote = np.array(y_train_smote)
        # print(np.unique(y_train_smote,return_counts=True))
        return x_train_smote, y_train_smote

    def _get_labels_features_edge_weights_seizure(self, patient):
        """Prepare features, labels, time labels and edge wieghts for training and
        optionally validation data."""

        event_tables = self._get_event_tables(
            patient
        )  # extract start and stop of seizure for patient
        patient_path = os.path.join(self.npy_dataset_path, patient)
        recording_list = os.listdir(patient_path)
        for record in recording_list:  # iterate over recordings for a patient
            # if "seizures_" not in record:
            #     ## skip non-seizure files
            #     continue

            recording_path = os.path.join(patient_path, record)
            record = record.replace(
                "seizures_", ""
            )  ## some magic to get it properly working with event tables
            record_id = record.split(".npy")[0]  #  get record id
            start_event_tables = self._get_recording_events(
                event_tables[0], record_id
            )  # get start events
            stop_event_tables = self._get_recording_events(
                event_tables[1], record_id
            )  # get stop events
            data_array = np.load(recording_path)  # load the recording

            plv_edge_weights = np.expand_dims(
                self._get_edge_weights_recording(
                    np.load(os.path.join(self.plv_values_path, patient, record))
                ),
                axis=0,
            )

            features, labels, time_labels = utils.extract_training_data_and_labels(
                data_array,
                start_event_tables,
                stop_event_tables,
                fs=self.sampling_f,
                seizure_lookback=self.seizure_lookback,
                sample_timestep=self.sample_timestep,
                inter_overlap=self.inter_overlap,
                ictal_overlap=self.ictal_overlap,
                buffer_time=self.buffer_time,
            )

            if features is None:
                # print(
                #     f"Skipping the recording {record} patients {patient} cuz features are none"
                # )
                continue
            if len(np.unique(labels)) != 2:
                # print(
                #     f"Skipping the recording {record} patients {patient} cuz no seizure samples"
                # )
                continue

            features = features.squeeze()

            if (self.smote):  
               # print(f"Applying smote on loso patient {patient} features")
                smote_start = time.time()
                features, labels = self._apply_smote(features, labels)
                logging.info(f"Applied one smote in {time.time() - smote_start} for patient {patient}")

            time_labels = np.expand_dims(time_labels.astype(np.int32), 1)
            labels = labels.reshape((labels.shape[0], 1)).astype(np.float32)
            patient_number = torch.full(
                [labels.shape[0]],
                int("".join(x for x in patient if x.isdigit())),
                dtype=torch.float32,
            )
            if patient == self.loso_patient:
                #logging.info(f"Adding recording {record} of patient {patient}")
                try:
                    self._val_features_dict[patient] = np.concatenate(
                        (self._val_features_dict[patient], features), axis=0
                    )
                    self._val_labels_dict[patient] = np.concatenate(
                        (self._val_labels_dict[patient], labels), axis=0
                    )
                    self._val_time_labels_dict[patient] = np.concatenate(
                        (self._val_time_labels_dict[patient], time_labels), axis=0
                    )
                    self._val_edge_weights_dict[patient] = np.concatenate(
                        (
                            self._val_edge_weights_dict[patient],
                            np.repeat(plv_edge_weights, features.shape[0], axis=0),
                        )
                    )

                    self._val_patient_number_dict[patient] = np.concatenate(
                        (self._val_patient_number_dict[patient], patient_number)
                    )
                except:
                    self._val_features_dict[patient] = features
                    self._val_labels_dict[patient] = labels
                    self._val_time_labels_dict[patient] = time_labels
                    self._val_edge_weights_dict[patient] = np.repeat(
                        plv_edge_weights, features.shape[0], axis=0
                    )
                    self._val_patient_number_dict[patient] = patient_number

            else:
                try:
                    self._features_dict[patient] = np.concatenate(
                        (self._features_dict[patient], features), axis=0
                    )
                    self._labels_dict[patient] = np.concatenate(
                        (self._labels_dict[patient], labels), axis=0
                    )
                    self._time_labels_dict[patient] = np.concatenate(
                        (self._time_labels_dict[patient], time_labels), axis=0
                    )
                    self._edge_weights_dict[patient] = np.concatenate(
                        (
                            self._edge_weights_dict[patient],
                            np.repeat(plv_edge_weights, features.shape[0], axis=0),
                        )
                    )

                    self._patient_number_dict[patient] = np.concatenate(
                        (self._patient_number_dict[patient], patient_number)
                    )
                except:
                    #print("Creating initial attributes")
                    self._features_dict[patient] = features
                    self._labels_dict[patient] = labels
                    self._time_labels_dict[patient] = time_labels
                    self._edge_weights_dict[patient] = np.repeat(
                        plv_edge_weights, features.shape[0], axis=0
                    )
                    self._patient_number_dict[patient] = patient_number

    def _get_labels_features_edge_weights_interictal(
        self, samples_recording: int = None
    ):
        patient_list = os.listdir(self.npy_dataset_path)
        interictal_samples = len(np.where(self._labels == 0))
        loso_interictal_samples = len(np.where(self._val_labels == 0))
        for patient in patient_list:
            patient_path = os.path.join(self.npy_dataset_path, patient)
            ## get all non-seizure recordings
            recording_list = [
                recording
                for recording in os.listdir(patient_path)
                if not "seizures_" in recording
            ]
            if not samples_recording:
                if patient == self.loso_patient:
                    samples_per_recording = int(
                        loso_interictal_samples / len(recording_list)
                    )
                else:
                    samples_per_recording = int(
                        interictal_samples / len(recording_list)
                    )
            else:
                samples_per_recording = samples_recording
            for recording in recording_list:
                recording_path = os.path.join(patient_path, recording)
                data_array = np.load(recording_path)

    # TODO define a method to create edges and calculate plv to get weights
    def get_dataset(self) -> DynamicGraphTemporalSignal:
        """Creating graph data iterators. The iterator yelds dynamic, weighted and undirected graphs
        containing self loops. Every node represents one electrode in EEG. The graph is fully connected,
        edge weights are calculated for every EEG recording as PLV between channels (edge weight describes
        the "strength" of connectivity between two channels in a recording). Node features are values of
        channel voltages in time. Features are of shape [nodes,features,timesteps].

        Returns:
            train_dataset {DynamicGraphTemporalSignal} -- Training data iterator.
            valid_dataset {DynamicGraphTemporalSignal} -- Validation data iterator (only if loso_patient is
            specified in class constructor).
        """
        ### TODO rozkminić o co chodzi z tym całym time labels - na razie wartość liczbowa która tam wchodzi
        ### to shape atrybutu time_labels

        self._initialize_dicts()
        patient_list = os.listdir(self.npy_dataset_path)
        start_time = time.time()
        if self.smote:
            for patient in patient_list:
                self._get_labels_features_edge_weights_seizure(patient)
        else:
            Parallel(n_jobs=6, require="sharedmem")(
                delayed(self._get_labels_features_edge_weights_seizure)(patient)
                for patient in patient_list
            )
        
        self._convert_dict_to_array()
        self._get_labels_count()
        if self.balance:
            self._balance_classes()
        
        print(f"Finished processing in {time.time() - start_time} seconds")
        print(f"Features shape {self._features.shape}")

        start_time_preprocessing = time.time()
        self._standardize_data(self._features, self._labels, self._val_features)
        
        self._get_edges()
        #self._get_labels_count()
        if self.hjorth:
            self._features = self._calculate_hjorth_features(self._features)
        self._array_to_tensor()
        if self.downsample and not self.hjorth:
            self._downsample_features()
        if self.fft:
            self._perform_features_fft()

        train_dataset = torch.utils.data.TensorDataset(
            self._features,
            self._edges,
            self._edge_weights,
            self._labels,
            # self._time_labels,
        )
        if self.train_test_split is not None:
            if self.fft or self.hjorth:
                data_list = self._features_to_data_list(
                    self._features,
                    self._edges,
                    self._edge_weights,
                    self._labels,
                    self._time_labels,
                )
                train_data_list, val_data_list = self._split_data_list(data_list)
                label_count = np.unique(
                    [data.y.item() for data in train_data_list], return_counts=True
                )[1]
                self.alpha = label_count[0] / label_count[1]
                loaders = [
                    DataLoader(
                        train_data_list,
                        batch_size=self.batch_size,
                        shuffle=True,
                        drop_last=False,
                    ),
                    DataLoader(
                        val_data_list,
                        batch_size=len(val_data_list),
                        shuffle=False,
                        drop_last=False,
                    ),
                ]

            else:
                train_dataset, val_dataset = torch.utils.data.random_split(
                    train_dataset,
                    [1 - self.train_test_split, self.train_test_split],
                    generator=torch.Generator().manual_seed(42),
                )

                train_dataloader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=self.batch_size,
                    shuffle=True,
                    # num_workers=2,
                    # pin_memory=True,
                    # prefetch_factor=4,
                    drop_last=False,
                )

                val_dataloader = torch.utils.data.DataLoader(
                    val_dataset,
                    batch_size=self.batch_size,
                    shuffle=False,
                    # num_workers=2,
                    # pin_memory=True,
                    # prefetch_factor=4,
                    drop_last=False,
                )
                loaders = [train_dataloader, val_dataloader]
        else:
            if self.fft or self.hjorth:
                train_data_list = self._features_to_data_list(
                    self._features,
                    self._edges,
                    self._edge_weights,
                    self._labels,
                    self._time_labels,
                )
                loaders = [
                    DataLoader(
                        train_data_list,
                        batch_size=self.batch_size,
                        shuffle=True,
                        drop_last=False,
                    )
                ]
            else:
                train_dataloader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=self.batch_size,
                    shuffle=True,
                    # num_workers=2,
                    # pin_memory=True,
                    # prefetch_factor=4,
                    drop_last=False,
                )
                loaders = [train_dataloader]
        if self.loso_patient:
            if self.hjorth:
                self._val_features = self._calculate_hjorth_features(self._val_features)

            if self.fft or self.hjorth:
                loso_data_list = self._features_to_data_list(
                    self._val_features,
                    self._val_edges,
                    self._val_edge_weights,
                    self._val_labels,
                    self._val_time_labels,
                )
                print("Preprocessing time: ", time.time() - start_time_preprocessing)
                return (
                    *loaders,
                    DataLoader(
                        loso_data_list,
                        batch_size=len(loso_data_list),
                        shuffle=False,
                        drop_last=False,
                    ),
                )
            loso_dataset = torch.utils.data.TensorDataset(
                self._val_features,
                self._val_edges,
                self._val_edge_weights,
                self._val_labels,
                #  self._val_time_labels,
            )
            loso_dataloader = torch.utils.data.DataLoader(
                loso_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                # pin_memory=True,
                # num_workers=2,
                # prefetch_factor=4,
                drop_last=False,
            )

            return (*loaders, loso_dataloader)

        return (*loaders,)

In [12]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, timestep,sfreq, n_nodes=18,batch_size=32):
        super(RecurrentGCN, self).__init__()
        self.n_nodes = n_nodes
        self.out_features = 128
        self.recurrent_1 = A3TGCN2(sfreq*timestep,32, add_self_loops=True,improved=False)
        self.recurrent_2 = GCNConv(32,64,add_self_loops=True,improved=False)
        self.recurrent_3 = GCNConv(64,128,add_self_loops=True,improved=False)
        self.fc1 = torch.nn.Linear(n_nodes*128, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 16)
        self.fc4 = torch.nn.Linear(16, 1)
        self.flatten = torch.nn.Flatten(start_dim=0)
        self.dropout = torch.nn.Dropout()
    def forward(self, x, edge_index,edge_weight,batch):
        x = torch.squeeze(x)
        h = self.recurrent_1(x, edge_index=edge_index, edge_weight = edge_weight)
        h = torch.nn.BatchNorm1d(32)(h)
        h = F.leaky_relu(h)
        h = self.recurrent_2(h, edge_index,edge_weight)
        h = torch.nn.BatchNorm1d(64)(h)
        h = F.leaky_relu(h)
        h = self.recurrent_3(h, edge_index,edge_weight)
        h = torch.nn.BatchNorm1d(128)(h)
        h = F.leaky_relu(h)
        h = global_mean_pool(h,batch)
        h = self.dropout(h)
        h = self.fc1(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc2(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc3(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc4(h)
        return h

In [13]:
class GATv2(torch.nn.Module):
    def __init__(self, timestep,sfreq, n_nodes=18,batch_size=32):
        super(GATv2, self).__init__()
        self.n_nodes = n_nodes
        self.out_features = 128
        n_heads = 4
        self.recurrent_1 = GATv2Conv(int((sfreq*timestep/2)+1),32,heads=n_heads,negative_slope=0.01,dropout=0.4, add_self_loops=True,improved=True,edge_dim=1)
        self.recurrent_2 = GATv2Conv(128,64,heads=n_heads,negative_slope=0.01,dropout=0.4, add_self_loops=True,improved=True,edge_dim=1)

        # int((sfreq*timestep/2)+1)
        # for children in list(self.recurrent_1.children()):
        #     for param in list(children.named_parameters()):
        #         if param[0] == 'weight':
        #             nn.init.kaiming_uniform_(param[1], a=0.01)
        #nn.init.kaiming_uniform_(self.recurrent_1.att,a=0.01)
        self.fc1 = torch.nn.Linear(64*n_heads, 128)
        nn.init.kaiming_uniform_(self.fc1.weight,a=0.01)
        self.fc2 = torch.nn.Linear(128, 64)
        nn.init.kaiming_uniform_(self.fc2.weight,a=0.01)
        self.fc3 = torch.nn.Linear(64, 1)
        nn.init.kaiming_uniform_(self.fc3.weight,a=0.01)
        self.fc4 = torch.nn.Linear(128, 1)
        self.connectivity = torch.nn.Linear(sfreq*timestep*n_nodes,324)
        self.connectivity_2 = torch.nn.Linear(sfreq*timestep,324)
        self.batch_norm_1 = torch.nn.BatchNorm1d(32*n_heads)
        self.batch_norm_2 = torch.nn.BatchNorm1d(64*n_heads)
        self.dropout = torch.nn.Dropout()
    def forward(self, x, edge_index, edge_attr,batch):
        h = self.recurrent_1(x, edge_index=edge_index, edge_attr = edge_attr)
        h = self.batch_norm_1(h)
        h = F.leaky_relu(h)
        #h = global_mean_pool(h,batch)
   
        h = self.recurrent_2(h, edge_index=edge_index, edge_attr = edge_attr)
        h = self.batch_norm_2(h)
        h = F.leaky_relu(h)
        h = global_mean_pool(h,batch)
    
        h = self.dropout(h)
        h = self.fc1(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc2(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc3(h)
        # h = F.leaky_relu(h)
        # h = self.dropout(h)
        # h = self.fc4(h)
        return h.squeeze(1)

In [5]:
TIMESTEP = 6
INTER_OVERLAP = 0
ICTAL_OVERLAP = 0
SFREQ = 256
torch_geometric.seed_everything(42)
dataloader = SeizureDataLoader(
    npy_dataset_path=Path('data/npy_data'),
    event_tables_path=Path('data/event_tables'),
    plv_values_path=Path('data/plv_arrays'),
    loso_patient='chb20',
    sampling_f=SFREQ,
    seizure_lookback=600,
    sample_timestep= TIMESTEP,
    inter_overlap=INTER_OVERLAP,
    ictal_overlap=ICTAL_OVERLAP,
    self_loops=False,
    balance=False,
    train_test_split=0.05,
    fft=True,
    hjorth=False,
    teager_keiser=False,
    downsample=60,
    batch_size=64,
    buffer_time=60,
    smote=False,
    )
train_loader,valid_loader, loso_loader =dataloader.get_dataset() #,loso_loader
alpha = list(dataloader._label_counts.values())[0]/list(dataloader._label_counts.values())[1]
alpha

[0. 1.] [15131  1711]
Finished processing in 19.281068325042725 seconds
Features shape (16842, 18, 1536)
Preprocessing time:  12.738569736480713


8.843366452367038

In [15]:
alpha

8.843366452367038

In [32]:
features = dataloader._features
labels = dataloader._labels
indexes = np.where(dataloader._val_labels == 0)[0]
features_negative = features[indexes]
indexes_positive = np.where(dataloader._val_labels == 1)[0]
features_positive = features[indexes_positive]

In [None]:
## compare samples 14323 negative and 2543 positive - for the standarization they were the same amplitudes
## how do I choose criteria for rejecting a sample when it is flat most of the time?
## wavelet coef energy is useful

In [None]:
from matplotlib import pyplot as plt
hjorth_features_negative = mne_features.univariate.compute_variance(features_negative[0].numpy())
hjorth_features_positive = mne_features.univariate.compute_variance(features_positive[1].numpy())
plt.plot(hjorth_features_negative)
plt.plot(hjorth_features_positive)
plt.legend(['negative','positive'])

In [None]:
from matplotlib import pyplot as plt
hjorth_features_negative = mne_features.univariate.compute_hjorth_complexity(features_negative[0].numpy())
hjorth_features_positive = mne_features.univariate.compute_hjorth_complexity(features_positive[2].numpy())
plt.plot(hjorth_features_negative)
plt.plot(hjorth_features_positive)
plt.legend(['negative','positive'])

In [None]:
from matplotlib import pyplot as plt
hjorth_features_negative = mne_features.univariate.compute_hjorth_mobility(features_negative[1].numpy())
hjorth_features_positive = mne_features.univariate.compute_hjorth_mobility(features_positive[2].numpy())
plt.plot(hjorth_features_negative)
plt.plot(hjorth_features_positive)
plt.legend(['negative','positive'])

In [None]:
from matplotlib import pyplot as plt
fig,ax = plt.subplots(18,1,figsize=(20,20))
for i in range(18):
    
    ax[i].scatter(y=features_negative[17,i,1],x =[1])
    ax[i].scatter(y=features_positive[0,i,1],x =[1])
    ax[i].set_title(f'Channel {i+1}')
plt.legend(['negative','positive'])
plt.show()

# fig,ax = plt.subplots(18,1,figsize=(10,10))
# for i in range(18):
#     ax [i].plot(features_positive[865,i,:])
#     ax[i].set_title(f'Channel {i+1}')
# plt.show()

In [None]:
from matplotlib import pyplot as plt
fig,ax = plt.subplots(18,1,figsize=(10,10))
for i in range(18):
    
    ax[i].plot(features_negative[6532,i,:])
    ax[i].plot(features_positive[7000,i,:])
    ax[i].set_title(f'Channel {i+1}')
plt.legend(['negative','positive'])
plt.show()

# fig,ax = plt.subplots(18,1,figsize=(10,10))
# for i in range(18):
#     ax [i].plot(features_positive[865,i,:])
#     ax[i].set_title(f'Channel {i+1}')
# plt.show()

In [None]:
features_loso = dataloader._val_features.numpy()
indexes_loso = np.where(dataloader._val_labels == 0)[0]
features_negative_loso = features_loso[indexes_loso]
indexes_positive_loso = np.where(dataloader._val_labels == 1)[0]
features_positive_loso = features_loso[indexes_positive_loso]

In [None]:
from matplotlib import pyplot as plt
channel_mean_loso = np.mean(features_negative_loso,axis=0).mean(1)
channel_std_loso = np.std(features_negative_loso,axis=0).std(1)
for i in range(features_negative_loso.shape[0]):
    for n in range(features_negative_loso.shape[1]):
        features_negative_loso[i,n,:] = (features_negative_loso[i,n,:] - channel_mean_loso[n])/channel_std_loso[n]
for i in range(features_positive_loso.shape[0]):
    for n in range(features_positive_loso.shape[1]):
        features_positive_loso[i,n,:] = (features_positive_loso[i,n,:] - channel_mean_loso[n])/channel_std_loso[n]
fig,ax = plt.subplots(5,1,figsize=(10,10))
for i in range(5):
    ax[i].plot(features_negative_loso[3,i,:])
    ax[i].set_title(f'Channel {i+1}')
plt.show()

fig,ax = plt.subplots(5,1,figsize=(10,10))
for i in range(5):
    ax[i].plot(features_positive_loso[12,i,:])
    ax[i].set_title(f'Channel {i+1}')
plt.show()

In [None]:
channel_std = np.std(features_negative,axis=0).mean(1)

In [None]:
from matplotlib import pyplot as plt
channel_mean = np.mean(features_negative,axis=0).mean(1)
channel_std = np.std(features_negative,axis=0).std(1)
for i in range(features_negative.shape[0]):
    for n in range(features_negative.shape[1]):
        features_negative[i,n,:] = (features_negative[i,n,:] - channel_mean[n])/channel_std[n]
for i in range(features_positive.shape[0]):
    for n in range(features_positive.shape[1]):
        features_positive[i,n,:] = (features_positive[i,n,:] - channel_mean[n])/channel_std[n]
fig,ax = plt.subplots(5,1,figsize=(10,10))
for i in range(5):
    ax[i].plot(features_positive[1,i,:])
    ax[i].set_title(f'Channel {i+1}')
plt.show()

In [None]:
from torch_geometric.explain import Explainer, AttentionExplainer, ExplainerConfig
explainer = Explainer(
    model=model,
    algorithm=AttentionExplainer(),
    explanation_type='model',
   # node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='graph',
        return_type='raw',  # Model returns log probabilities.
    ),
)

In [None]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, timestep,sfreq, n_nodes=18,batch_size=32):
        super(RecurrentGCN, self).__init__()
        self.n_nodes = n_nodes
        self.out_features = 128
        self.recurrent_1 = GCNConv(3,32, add_self_loops=True,improved=False)
        self.recurrent_2 = GCNConv(32,64,add_self_loops=True,improved=False)
        self.recurrent_3 = GCNConv(64,128,add_self_loops=True,improved=False)
        self.fc1 = torch.nn.Linear(128, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 16)
        self.fc4 = torch.nn.Linear(16, 1)
        self.batch_norm_1 = torch.nn.BatchNorm1d(32)
        self.batch_norm_2 = torch.nn.BatchNorm1d(64)
        self.batch_norm_3 = torch.nn.BatchNorm1d(128)
        self.flatten = torch.nn.Flatten(start_dim=0)
        self.dropout = torch.nn.Dropout()
    def forward(self, x, edge_index,edge_weight,batch):
        x = torch.squeeze(x)
        h = self.recurrent_1(x, edge_index=edge_index, edge_weight = edge_weight)
        h = self.batch_norm_1(h)
        h = F.leaky_relu(h)
        h = self.recurrent_2(h, edge_index,edge_weight)
        h = self.batch_norm_2(h)
        h = F.leaky_relu(h)
        h = self.recurrent_3(h, edge_index,edge_weight)
        h = self.batch_norm_3(h)
        h = F.leaky_relu(h)
        h = global_mean_pool(h,batch)
        h = self.dropout(h)
        h = self.fc1(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc2(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc3(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc4(h)
        return h.squeeze()

In [None]:
class GATv2Lightning(pl.LightningModule):
    def __init__(self, timestep,sfreq,alpha,threshold=0.5, hidden_channels=32,heads=8,negative_slope = 0.01, dropout=0.5):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.recurrent_1 = GATv2Conv(
            sfreq*timestep,hidden_channels,heads=heads,negative_slope=negative_slope,dropout=dropout, add_self_loops=True,improved=True,edge_dim=1)
        self.fc1 = torch.nn.Linear(hidden_channels*heads, 64)
        nn.init.kaiming_uniform_(self.fc1.weight,a=negative_slope)
        self.fc2 = torch.nn.Linear(64, 32)
        nn.init.kaiming_uniform_(self.fc2.weight,a=negative_slope)
        self.fc3 = torch.nn.Linear(32, 1)
        nn.init.kaiming_uniform_(self.fc3.weight,a=negative_slope)
        self.batch_norm = torch.nn.BatchNorm1d(hidden_channels*heads)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation_recurrent = nn.Sequential(
            self.batch_norm,
            nn.LeakyReLU()
        )
        self.classifier = nn.Sequential(
            self.dropout,
            self.fc1,
            nn.LeakyReLU(),
            self.dropout,
            self.fc2,
            nn.LeakyReLU(),
            self.dropout,
            self.fc3 
        )
        self.loss = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], alpha))
        self.sensitivity = BinaryRecall(threshold=threshold)
        self.specificity = BinarySpecificity(threshold=threshold)
        self.auroc = AUROC(task="binary")
    def forward(self, x, edge_index, edge_attr,batch):
        h = self.recurrent_1(x, edge_index, edge_attr)
        h = self.activation_recurrent(h)
        h = global_mean_pool(h,batch)
        h = self.classifier(h)
        return h.squeeze()
    def training_step(self, batch, batch_idx):
        x = batch.x
        x = x.squeeze()
        signal_samples = x.shape[1]
        x = 2 / signal_samples * torch.abs(x)
        x = (x-x.mean(dim=0))/x.std(dim=0)
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr.float()
        y = batch.y
        batches = batch.batch
        y_hat = self(x, edge_index, edge_attr,batches)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True,prog_bar=True)
        self.log('train_sensitivity', self.sensitivity(y_hat, y), on_step=False, on_epoch=True,prog_bar=True)
        self.log('train_specificity', self.specificity(y_hat, y), on_step=False, on_epoch=True,prog_bar=True)
        return loss
    def validation_step(self, batch, batch_idx):
        x = batch.x
        x = x.squeeze()
        signal_samples = x.shape[1]
        x = 2 / signal_samples * torch.abs(x)
        x = (x-x.mean(dim=0))/x.std(dim=0)
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr.float()
        y = batch.y
        batches = batch.batch
        y_hat = self(x, edge_index, edge_attr,batches)
        loss = self.loss(y_hat, y)
        self.log('valid_loss', loss, on_step=False, on_epoch=True,prog_bar=True)
        self.log('valid_sensitivity', self.sensitivity(y_hat, y), on_step=False, on_epoch=True,prog_bar=True)
        self.log('valid_specificity', self.specificity(y_hat, y), on_step=False, on_epoch=True,prog_bar=True)
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001, weight_decay=0.0001)
        return optimizer

            

In [None]:
lightning_model = GATv2Lightning(TIMESTEP,60,alpha,threshold=0.5, hidden_channels=32,heads=8,negative_slope = 0.01, dropout=0.5)

In [None]:
trainer = pl.Trainer(max_epochs=10, callbacks=[pl.callbacks.EarlyStopping(monitor='valid_loss', patience=5, mode='min')],precision=16,gradient_clip_val=0.5)

In [None]:
trainer.fit(lightning_model, train_loader, loso_loader)

In [None]:
explanation = explainer(transform_x(sample.x), sample.edge_index ,edge_attr=sample.edge_attr.float(),batch=sample.batch)
explanation.visualize_graph()

In [30]:
array = np.load("/home/szymon/code/sano/sano_eeg/data/npy_data_full/chb01/chb01_01.npy")
total_samples = array.shape[1]
fs = 256
timestep = 6
overlap = 5
samples_per_recording = 71*fs
random_start_time = np.random.randint(0,total_samples-samples_per_recording)
interictal_period = array[
                :, random_start_time  : random_start_time+samples_per_recording
            ]
interictal_period = (
            np.expand_dims(interictal_period.transpose(), axis=2)
            .swapaxes(0, 2)
            .swapaxes(0, 1)
        )  ##reshape for preprocessing
final_array = utils.prepare_timestep_array(interictal_period, timestep*fs, overlap*fs)

In [11]:
## normal loop
torch_geometric.seed_everything(42)
import torchvision
early_stopping = utils.EarlyStopping(patience=3, verbose=True)
device = torch.device("cpu")
model = GATv2(TIMESTEP,60,batch_size=32).to(device) #Gatv2
loss_fn =  nn.BCEWithLogitsLoss(pos_weight=torch.full([1], alpha))
#loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(),lr=0.001,weight_decay=0.0001)
recall = BinaryRecall(threshold=0.5)
specificity = BinarySpecificity(threshold=0.5)
auroc = AUROC(task="binary")
roc = ROC('binary')
model.train()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=2)

for epoch in tqdm(range(25)):
        try:
                del preds, ground_truth
        except:
                pass
        epoch_loss = 0.0
        epoch_loss_valid = 0.0
        model.train()
        sample_counter = 0
        batch_counter = 0
        print(get_lr(optimizer))
        for time, batch in enumerate(train_loader): ## TODO - this thing is still operating with no edge weights!!!
                ## find a way to compute plv per batch fast (is it even possible?)
        
                x = batch.x
        
                edge_index = batch.edge_index
                edge_attr = batch.edge_attr.float()

                y = batch.y
                batch_idx = batch.batch
                #time_to_seizure = batch.time.float()
               
               # y = torch.tensor(np.array([1 if y_train == 0 else 0 for y_train in y]),dtype=torch.float32)
            
               # x = x.squeeze()
              
                #signal_samples = x.shape[1]
                x = torch.square(torch.abs(x))
                #x = (x-x.mean(dim=0))/x.std(dim=0)
                y_hat = model(x, edge_index,None,batch_idx)
              
           
                # y_hat =torch.stack(
                #         [model(x=x[n], edge_index=edge_index[n], edge_weight=edge_attr[n].float()) 
                #          for n in range(x.shape[0])])
                ##loss
        
                loss = loss_fn(y_hat,y)
                #loss = torchvision.ops.sigmoid_focal_loss(y_hat,y,alpha=alpha,gamma=3,reduction='mean')
                epoch_loss += loss
                ## get preds & gorund truth
                try:
                 preds = torch.cat([preds,y_hat.detach()],dim=0)
                 ground_truth = torch.cat([ground_truth,y],dim=0)
            
                except:
                 preds= y_hat.detach()
                 ground_truth = y
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        ## calculate acc
        
        train_auroc = auroc(preds,ground_truth)
        train_sensitivity = recall(preds,ground_truth)
        train_specificity = specificity(preds,ground_truth)
        del preds, ground_truth
        print(f'Epoch: {epoch}', f'Epoch loss: {epoch_loss.detach().numpy()/(time+1)}')
        print(f'Epoch sensitivity: {train_sensitivity}')
        print(f'Epoch specificity: {train_specificity}')
        print(f'Epoch AUROC: {train_auroc} ')
        model.eval()
        with torch.no_grad():
                try:
                        del preds_valid, ground_truth_valid
                except:
                        pass
                for time_valid, batch_valid in enumerate(valid_loader):
                        x = batch_valid.x
                        edge_index = batch_valid.edge_index
                        edge_attr = batch_valid.edge_attr.float()
                        y_val = batch_valid.y
                        batch_idx = batch_valid.batch
                        #inverse_y_val = torch.tensor(np.array([1 if y == 0 else 0 for y in y_val]),dtype=torch.float32)
                        #time_to_seizure_val = batch_valid.time.float()
                        x = x.squeeze()
                        #signal_samples = x.shape[1]
                        x = torch.square(torch.abs(x))
                        # x = (x-x.mean(dim=0))/x.std(dim=0)
                        
                        y_hat_val = model(x, edge_index,None,batch_idx)
                        loss_valid = loss_fn(y_hat_val,y_val)
                       # loss_valid = torchvision.ops.sigmoid_focal_loss(y_hat_val,inverse_y_val,alpha=alpha,gamma=3,reduction='mean')
                        epoch_loss_valid += loss_valid
                        try:
                         preds_valid = torch.cat([preds_valid,y_hat_val],dim=0)
                         ground_truth_valid = torch.cat([ground_truth_valid,y_val],dim=0)
                        except:
                         preds_valid= y_hat_val
                         ground_truth_valid = y_val
        #scheduler.step(epoch_loss_valid)
        early_stopping(epoch_loss_valid.numpy()/(time_valid+1), model)
        val_auroc = auroc(preds_valid,ground_truth_valid)
        val_sensitivity = recall(preds_valid,ground_truth_valid)
        val_specificity = specificity(preds_valid,ground_truth_valid)
        del preds_valid, ground_truth_valid
        print(f'Epoch val_loss: {epoch_loss_valid.detach().numpy()/(time_valid+1)}')
        print(f'Epoch val_sensitivity: {val_sensitivity}')
        print(f'Epoch val specificity: {val_specificity}')
        print(f'Epoch val AUROC: {val_auroc} ')
        if early_stopping.early_stop:
                print("Early stopping")
                model.load_state_dict(torch.load('checkpoint.pt'))
                break


  0%|          | 0/25 [00:00<?, ?it/s]

0.001
Epoch: 0 Epoch loss: 1.3020436564289717
Epoch sensitivity: 0.5598683953285217
Epoch specificity: 0.7376834750175476
Epoch AUROC: 0.6794337630271912 


  4%|▍         | 1/25 [00:28<11:23, 28.49s/it]

Validation loss decreased (inf --> 1.109545).  Saving model ...
Epoch val_loss: 1.109545111656189
Epoch val_sensitivity: 0.29629629850387573
Epoch val specificity: 0.9724770784378052
Epoch val AUROC: 0.8393120169639587 
0.001
Epoch: 1 Epoch loss: 1.06699447327876
Epoch sensitivity: 0.5960526466369629
Epoch specificity: 0.8096878528594971
Epoch AUROC: 0.7644992470741272 


  8%|▊         | 2/25 [00:55<10:29, 27.37s/it]

Validation loss decreased (1.109545 --> 1.057326).  Saving model ...
Epoch val_loss: 1.057326316833496
Epoch val_sensitivity: 0.40740740299224854
Epoch val specificity: 0.956749677658081
Epoch val AUROC: 0.8571752309799194 
0.001
Epoch: 2 Epoch loss: 0.9829776353570094
Epoch sensitivity: 0.6861842274665833
Epoch specificity: 0.7976986169815063
Epoch AUROC: 0.8110201358795166 


 12%|█▏        | 3/25 [01:19<09:28, 25.85s/it]

Validation loss decreased (1.057326 --> 1.000369).  Saving model ...
Epoch val_loss: 1.0003694295883179
Epoch val_sensitivity: 0.43209877610206604
Epoch val specificity: 0.942332923412323
Epoch val AUROC: 0.8670939207077026 
0.001
Epoch: 3 Epoch loss: 0.9308767736670506
Epoch sensitivity: 0.7348684072494507
Epoch specificity: 0.7796458601951599
Epoch AUROC: 0.8327381014823914 


 16%|█▌        | 4/25 [01:43<08:50, 25.24s/it]

EarlyStopping counter: 1 out of 3
Epoch val_loss: 1.03385329246521
Epoch val_sensitivity: 0.43209877610206604
Epoch val specificity: 0.9541284441947937
Epoch val AUROC: 0.8741484880447388 
0.001
Epoch: 4 Epoch loss: 0.8795564203148344
Epoch sensitivity: 0.7703947424888611
Epoch specificity: 0.78460693359375
Epoch AUROC: 0.8547898530960083 


 20%|██        | 5/25 [02:08<08:22, 25.12s/it]

EarlyStopping counter: 2 out of 3
Epoch val_loss: 1.043623447418213
Epoch val_sensitivity: 0.48148149251937866
Epoch val specificity: 0.9436435103416443
Epoch val AUROC: 0.871414065361023 
0.001
Epoch: 5 Epoch loss: 0.8371732795380976
Epoch sensitivity: 0.7881578803062439
Epoch specificity: 0.7791634798049927
Epoch AUROC: 0.86629718542099 


 24%|██▍       | 6/25 [02:33<07:55, 25.04s/it]

Validation loss decreased (1.000369 --> 0.915884).  Saving model ...
Epoch val_loss: 0.9158839583396912
Epoch val_sensitivity: 0.604938268661499
Epoch val specificity: 0.9213630557060242
Epoch val AUROC: 0.8827565312385559 
0.001
Epoch: 6 Epoch loss: 0.8236094926932893
Epoch sensitivity: 0.7980263233184814
Epoch specificity: 0.780059278011322
Epoch AUROC: 0.8715394735336304 


 28%|██▊       | 7/25 [02:59<07:37, 25.41s/it]

EarlyStopping counter: 1 out of 3
Epoch val_loss: 1.0733695030212402
Epoch val_sensitivity: 0.5061728358268738
Epoch val specificity: 0.9344692230224609
Epoch val AUROC: 0.8773442506790161 
0.001
Epoch: 7 Epoch loss: 0.8187342183998382
Epoch sensitivity: 0.8092105388641357
Epoch specificity: 0.7745469808578491
Epoch AUROC: 0.8741692304611206 


 32%|███▏      | 8/25 [03:25<07:13, 25.50s/it]

EarlyStopping counter: 2 out of 3
Epoch val_loss: 0.9698489308357239
Epoch val_sensitivity: 0.5432098507881165
Epoch val specificity: 0.9187418222427368
Epoch val AUROC: 0.8741647005081177 
0.001
Epoch: 8 Epoch loss: 0.8099378608612425
Epoch sensitivity: 0.82039475440979
Epoch specificity: 0.790532648563385
Epoch AUROC: 0.8799048662185669 


 32%|███▏      | 8/25 [03:49<08:07, 28.65s/it]

EarlyStopping counter: 3 out of 3
Epoch val_loss: 1.0783798694610596
Epoch val_sensitivity: 0.43209877610206604
Epoch val specificity: 0.956749677658081
Epoch val AUROC: 0.8696988821029663 
Early stopping





In [13]:
model.eval()
recall = BinaryRecall(threshold=0.5)
specificity = BinarySpecificity(threshold=0.5)
with torch.no_grad():
    try:
        del preds_valid, ground_truth_valid
    except:
        pass
    epoch_loss_loso = 0.0
    for time_loso, batch_loso in enumerate(loso_loader):
            x = batch_loso.x.to(device)
            edge_index = batch_loso.edge_index.to(device)
            edge_attr = batch_loso.edge_attr.float().to(device)
            y_loso = batch_loso.y.to(device)
            batch_idx = batch_loso.batch.to(device)
            #y_loso = torch.tensor(np.array([1 if y == 0 else 0 for y in y_loso]))
            #time_to_seizure_val = batch_valid.time.float()
            x = x.squeeze()
            # x = batch_valid[0].to(device)
            # edge_index = batch_valid[1].to(device)
            # y_val = batch_valid[3].squeeze().to(device)
           # x = (x-x.mean(dim=0))/x.std(dim=0)
            x = torch.square(torch.abs(x))
            
            
            y_hat_loso = model(x, edge_index,None,batch_idx)
            loss_loso = loss_fn(y_hat_loso,y_loso)
            #loss_loso = torchvision.ops.sigmoid_focal_loss(y_hat,y,alpha=0.65,gamma=4,reduction='mean')
            epoch_loss_loso += loss_loso
            try:
                preds_loso = torch.cat([preds_loso,y_hat_loso],dim=0)
                ground_truth_loso= torch.cat([ground_truth_loso,y_loso],dim=0)
            except:
                preds_loso= y_hat_loso
                ground_truth_loso = y_loso
    loso_auroc = auroc(preds_loso,ground_truth_loso)
    loso_sensitivity = recall(preds_loso,ground_truth_loso)
    loso_specificity = specificity(preds_loso,ground_truth_loso)
    del preds_loso, ground_truth_loso

    print(f'Loso_loss: {epoch_loss_loso.cpu().numpy()/(time_loso+1)}')
    print(f'Loso_sensitivity: {loso_sensitivity}')
    print(f'Loso_specificity: {loso_specificity}')
    print(f'Loso_AUROC: {loso_auroc} ')

Loso_loss: 1.655221700668335
Loso_sensitivity: 0.5496688485145569
Loso_specificity: 0.9800000190734863
Loso_AUROC: 0.9359867572784424 
