In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
from dataclasses import dataclass
import utils
from torch_geometric_temporal import  DynamicGraphTemporalSignal,StaticGraphTemporalSignal, temporal_signal_split, DynamicGraphTemporalSignalBatch
import networkx as nx
from torch_geometric.utils import from_networkx
import scipy
import sklearn
import random
from tqdm import tqdm
from torch_geometric.data import Data
from torch_geometric.loader import DynamicBatchSampler,DataLoader
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.ops import sigmoid_focal_loss
from torch_geometric_temporal.nn.recurrent import DCRNN,  GConvGRU, A3TGCN, TGCN2, TGCN, A3TGCN2
from torch_geometric_temporal.nn.attention import STConv
from torchmetrics.classification import BinaryRecall,BinarySpecificity, AUROC, ROC
from torch_geometric.nn import global_mean_pool
import pytorch_lightning as pl
from torch_geometric.nn import GCNConv,BatchNorm,GATv2Conv
from sklearn.model_selection import KFold,StratifiedKFold, StratifiedShuffleSplit
from mne_features.univariate import compute_kurtosis, compute_hjorth_complexity, compute_hjorth_mobility
import torchaudio

2023-02-13 14:10:34.720997: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-13 14:10:34.938554: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-02-13 14:10:34.948939: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-02-13 14:10:34.948956: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore 

In [None]:
torch.manual_seed(42)
random.seed(42)

In [None]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [None]:
# TODO think about using kwargs argument here to specify args for dataloader
@dataclass
class SeizureDataLoader:
    npy_dataset_path: Path
    event_tables_path: Path
    plv_values_path: Path
    loso_patient: str = None
    sampling_f: int = 256
    seizure_lookback: int = 600
    sample_timestep: int = 5
    inter_overlap: int = 0
    ictal_overlap: int = 0
    self_loops: bool = True
    balance: bool = True
    train_test_split: float = None
    fft: bool = False
    hjorth: bool = False
    downsample: int = None
    buffer_time: int = 15
    batch_size : int = 32
    """Class to prepare dataloaders for eeg seizure perdiction from stored files.

    Attributes:
        npy_dataset_path {Path} -- Path to folder with dataset preprocessed into .npy files.
        event_tables_path {Path} -- Path to folder with .csv files containing seizure events information for every patient.
        loso_patient {str} -- Name of patient to be selected for LOSO valdiation, specified in format "chb{patient_number}"",
        eg. "chb16". (default: {None}).
        samplin_f {int} -- Sampling frequency of the loaded eeg data. (default: {256}).
        seizure_lookback {int} -- Time horizon to sample pre-seizure data (length of period before seizure) in seconds. 
        (default: {600}).
        sample_timestep {int} -- Amounts of seconds analyzed in a single sample. (default: {5}).
        overlap {int} -- Amount of seconds overlap between samples. (default: {0}).
        self_loops {bool} -- Wheather to add self loops to nodes of the graph. (default: {True}).
        shuffle {bool} --  Wheather to shuffle training samples.


    """
    assert (fft and hjorth) == False, "When fft is True, hjorth should be False"

    def _get_event_tables(self, patient_name: str) -> tuple[dict, dict]:
        """Read events for given patient into start and stop times lists from .csv extracted files."""

        event_table_list = os.listdir(self.event_tables_path)
        patient_event_tables = [
            os.path.join(self.event_tables_path, ev_table)
            for ev_table in event_table_list
            if patient_name in ev_table
        ]
        patient_event_tables = sorted(patient_event_tables)
        patient_start_table = patient_event_tables[
            0
        ]  ## done terribly, but it has to be so for win/linux compat
        patient_stop_table = patient_event_tables[1]
        start_events_dict = pd.read_csv(patient_start_table).to_dict("index")
        stop_events_dict = pd.read_csv(patient_stop_table).to_dict("index")
        return start_events_dict, stop_events_dict

    def _get_recording_events(self, events_dict, recording) -> list[int]:
        """Read seizure times into list from event_dict"""
        recording_list = list(events_dict[recording + ".edf"].values())
        recording_events = [int(x) for x in recording_list if not np.isnan(x)]
        return recording_events

    def _get_graph(self, n_nodes: int) -> nx.Graph:
        """Creates Networx fully connected graph with self loops"""
        graph = nx.complete_graph(n_nodes)
        self_loops = [[node, node] for node in graph.nodes()]
        graph.add_edges_from(self_loops)
        return graph

    def _get_edge_weights_recording(self, plv_values: np.ndarray) -> np.ndarray:
        """Method that takes plv values for given recording and assigns them
        as edge attributes to a fc graph."""
        graph = self._get_graph(plv_values.shape[0])
        garph_dict = {}
        for edge in graph.edges():
            e_start, e_end = edge
            garph_dict[edge] = {"plv": plv_values[e_start, e_end]}
        nx.set_edge_attributes(graph, garph_dict)
        edge_weights = from_networkx(graph).plv.numpy()
        return edge_weights

    def _get_edges(self):
        """Method to assign edge attributes. Has to be called AFTER get_dataset() method."""
        graph = self._get_graph(self._features.shape[1])
        edges = np.expand_dims(from_networkx(graph).edge_index.numpy(), axis=0)
        edges_per_sample_train = np.repeat(
            edges, repeats=self._features.shape[0], axis=0
        )
        self._edges = torch.tensor(edges_per_sample_train)
        if self.loso_patient is not None:
            edges_per_sample_val = np.repeat(
                edges, repeats=self._val_features.shape[0], axis=0
            )
            self._val_edges = torch.tensor(edges_per_sample_val)

    def _array_to_tensor(self):
        """Method converting features, edges and weights to torch.tensors"""
        self._features = torch.tensor(self._features, dtype=torch.float32)
        self._labels = torch.tensor(self._labels)
        self._time_labels = torch.tensor(self._time_labels)
        self._edge_weights = torch.tensor(self._edge_weights)

    def _val_array_to_tensor(self):
        self._val_features = torch.tensor(self._val_features, dtype=torch.float32)
        self._val_labels = torch.tensor(self._val_labels)
        self._val_time_labels = torch.tensor(self._val_time_labels)
        self._val_edge_weights = torch.tensor(self._val_edge_weights)

    def _get_labels_count(self):
        labels, counts = np.unique(self._labels, return_counts=True)
        self._label_counts = {}
        for n, label in enumerate(labels):
            self._label_counts[int(label)] = counts[n]

    def _get_val_labels_count(self):
        labels, counts = np.unique(self._val_labels, return_counts=True)
        self._val_label_counts = {}
        for n, label in enumerate(labels):
            self._val_label_counts[int(label)] = counts[n]

    def _perform_features_train_fft(self):
        self._features = torch.fft.fft(self._features)

    def _perform_features_val_fft(self):
        self._val_features = torch.fft.fft(self._val_features)

    def _downsample_features_train(self):
        resampler = torchaudio.transforms.Resample(self.sampling_f, self.downsample)
        self._features = resampler(self._features)

    def _downsample_features_val(self):
        resampler = torchaudio.transforms.Resample(self.sampling_f, self.downsample)
        self._val_features = resampler(self._val_features)

    def _calculate_hjorth_features_train(self):

        new_features = [
            np.concatenate(
                [
                    compute_hjorth_mobility(feature),
                    compute_hjorth_complexity(feature),
                ],
                axis=1,
            )
            for feature in self._features
        ]
        self._features = np.array(new_features)

    def _calculate_hjorth_features_val(self):

        new_features = [
            np.concatenate(
                [
                    compute_hjorth_mobility(feature),
                    compute_hjorth_complexity(feature),
                ],
                axis=1,
            )
            for feature in self._val_features
        ]
        self._val_features = np.array(new_features)
    def _features_to_data_list(self,features,edges,edge_weights,labels, time_label):
        data_list = [
                    Data(
                        x=features[i],
                        edge_index=edges[i],
                        edge_attr=edge_weights[i],
                        y=labels[i],
                        time=time_label[i],
                    )
                    for i in range(len(features))
                ]
        return data_list
    def _split_data_list(self,data_list):
        class_labels = [data.y.item() for data in data_list]
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=self.train_test_split, random_state=42)
        train_indices, val_indices = next(splitter.split(data_list, class_labels))
        data_list_train = [data_list[i] for i in train_indices]
        dataset_list_val = [data_list[i] for i in val_indices]
        return data_list_train, dataset_list_val
    def _balance_classes(self):
        negative_label = self._label_counts[0]
        positive_label = self._label_counts[1]

        imbalance = negative_label - positive_label
        negative_indices = np.where(self._labels == 0)[0]
        indices_to_discard = np.random.choice(
            negative_indices, size=imbalance, replace=False
        )

        self._features = np.delete(self._features, obj=indices_to_discard, axis=0)
        self._labels = np.delete(self._labels, obj=indices_to_discard, axis=0)
        self._time_labels = np.delete(self._time_labels, obj=indices_to_discard, axis=0)
        self._edge_weights = np.delete(
            self._edge_weights, obj=indices_to_discard, axis=0
        )

    def _get_labels_features_edge_weights(self):
        """Prepare features, labels, time labels and edge wieghts for training and
        optionally validation data."""
        patient_list = os.listdir(self.npy_dataset_path)
        for patient in patient_list:  # iterate over patient names
            event_tables = self._get_event_tables(
                patient
            )  # extract start and stop of seizure for patient
            patient_path = os.path.join(self.npy_dataset_path, patient)
            recording_list = os.listdir(patient_path)
            for record in recording_list:  # iterate over recordings for a patient
                recording_path = os.path.join(patient_path, record)
                record_id = record.split(".npy")[0]  #  get record id
                start_event_tables = self._get_recording_events(
                    event_tables[0], record_id
                )  # get start events
                stop_event_tables = self._get_recording_events(
                    event_tables[1], record_id
                )  # get stop events
                data_array = np.load(recording_path)  # load the recording

                plv_edge_weights = np.expand_dims(
                    self._get_edge_weights_recording(
                        np.load(os.path.join(self.plv_values_path, patient, record))
                    ),
                    axis=0,
                )

                ##TODO add a gateway to reject seizure periods shorter than lookback
                # extract timeseries and labels from the array
                features, labels, time_labels = utils.extract_training_data_and_labels(
                    data_array,
                    start_event_tables,
                    stop_event_tables,
                    fs=self.sampling_f,
                    seizure_lookback=self.seizure_lookback,
                    sample_timestep=self.sample_timestep,
                    inter_overlap=self.inter_overlap,
                    ictal_overlap=self.ictal_overlap,
                    buffer_time=self.buffer_time,
                )

                if features is None:
                    continue
                time_labels = np.expand_dims(time_labels.astype(np.int32), 1)
                labels = labels.reshape((labels.shape[0], 1)).astype(np.float32)
                """SCALING FEATURES INTO uV!!!"""
                # features = features*(10**6)
                """SCALING FEATURES INTO uV!!!"""
                if patient == self.loso_patient:
                    try:
                        self._val_features = np.concatenate(
                            (self._val_features, features)
                        )
                        self._val_labels = np.concatenate((self._val_labels, labels))
                        self._val_time_labels = np.concatenate(
                            (self._val_time_labels, time_labels)
                        )
                        self._val_edge_weights = np.concatenate(
                            (
                                self._val_edge_weights,
                                np.repeat(plv_edge_weights, features.shape[0], axis=0),
                            )
                        )
                    except:
                        self._val_features = features
                        self._val_labels = labels
                        self._val_time_labels = time_labels
                        self._val_edge_weights = np.repeat(
                            plv_edge_weights, features.shape[0], axis=0
                        )
                else:
                    try:
                        self._features = np.concatenate((self._features, features))
                        self._labels = np.concatenate((self._labels, labels))
                        self._time_labels = np.concatenate(
                            (self._time_labels, time_labels)
                        )
                        self._edge_weights = np.concatenate(
                            (
                                self._edge_weights,
                                np.repeat(plv_edge_weights, features.shape[0], axis=0),
                            )
                        )

                    except:
                        print("Creating initial attributes")
                        self._features = features
                        self._labels = labels
                        self._time_labels = time_labels
                        self._edge_weights = np.repeat(
                            plv_edge_weights, features.shape[0], axis=0
                        )

    # TODO define a method to create edges and calculate plv to get weights
    def get_dataset(self) -> DynamicGraphTemporalSignal:

        """Creating graph data iterators. The iterator yelds dynamic, weighted and undirected graphs
        containing self loops. Every node represents one electrode in EEG. The graph is fully connected,
        edge weights are calculated for every EEG recording as PLV between channels (edge weight describes
        the "strength" of connectivity between two channels in a recording). Node features are values of
        channel voltages in time. Features are of shape [nodes,features,timesteps].

        Returns:
            train_dataset {DynamicGraphTemporalSignal} -- Training data iterator.
            valid_dataset {DynamicGraphTemporalSignal} -- Validation data iterator (only if loso_patient is
            specified in class constructor).
        """
        ### TODO rozkminić o co chodzi z tym całym time labels - na razie wartość liczbowa która tam wchodzi
        ### to shape atrybutu time_labels

        self._get_labels_features_edge_weights()
        self.train_features_min_max = [self._features.min(), self._features.max()]
        if self.balance:
            self._get_labels_count()
            self._balance_classes()
        self._get_edges()
        self._get_labels_count()
        if self.hjorth:
            self._calculate_hjorth_features_train()
        self._array_to_tensor()
        if self.downsample:
            self._downsample_features_train()
        if self.fft:
            self._perform_features_train_fft()
        train_dataset = torch.utils.data.TensorDataset(
            self._features,
            self._edges,
            self._edge_weights,
            self._labels,
            self._time_labels,
        )
        if self.train_test_split is not None:
            if self.fft or self.hjorth:
                data_list = self._features_to_data_list(
                    self._features,self._edges,self._edge_weights, self._labels, self._time_labels
                )
                train_data_list, val_data_list = self._split_data_list(data_list)
                label_count= np.unique([data.y.item() for data in train_data_list],return_counts=True)[1]
                self.alpha = label_count[0]/label_count[1]
                loaders = [
                    DataLoader(train_data_list, batch_size=self.batch_size, shuffle=True,drop_last=False),
                    DataLoader(val_data_list, batch_size=len(val_data_list), shuffle=False,drop_last=False)
                           ]
                
            else:    
                train_dataset, val_dataset = torch.utils.data.random_split(
                    train_dataset,
                    [1 - self.train_test_split, self.train_test_split],
                    generator=torch.Generator().manual_seed(42),
                )

                train_dataloader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=self.batch_size,
                    shuffle=True,
                    num_workers=2,
                    pin_memory=True,
                    prefetch_factor=4,
                    drop_last=False,
                )

                val_dataloader = torch.utils.data.DataLoader(
                    val_dataset,
                    batch_size=self.batch_size,
                    shuffle=False,
                    num_workers=2,
                    pin_memory=True,
                    prefetch_factor=4,
                    drop_last=False,
                )
                loaders = [train_dataloader, val_dataloader]
        else:
            if self.fft or self.hjorth:
                train_data_list = self._features_to_data_list(
                    self._features,self._edges,self._edge_weights, self._labels, self._time_labels
                )
                loaders = [DataLoader(train_data_list, batch_size=self.batch_size, shuffle=True,drop_last=False)]
            else:
                train_dataloader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=self.batch_size,
                    shuffle=True,
                    num_workers=2,
                    pin_memory=True,
                    prefetch_factor=4,
                    drop_last=False,
                )
                loaders = [train_dataloader]
        if self.loso_patient:
            self.val_features_min_max = [
                self._val_features.min(),
                self._val_features.max(),
            ]
            self._get_val_labels_count()
            if self.hjorth:
                self._calculate_hjorth_features_val()
            self._val_array_to_tensor()
            if self.downsample:
                self._downsample_features_val()
            if self.fft:
                self._perform_features_val_fft()
            if self.fft or self.hjorth:
                loso_data_list = self._features_to_data_list(
                    self._val_features,self._val_edges,self._val_edge_weights, self._val_labels, self._val_time_labels
                )
                return (*loaders, DataLoader(loso_data_list, batch_size=len(loso_data_list), shuffle=False,drop_last=False))
            loso_dataset = torch.utils.data.TensorDataset(
                self._val_features,
                self._val_edges,
                self._val_edge_weights,
                self._val_labels,
                self._val_time_labels,
            )
            loso_dataloader = torch.utils.data.DataLoader(
                loso_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=2,
                prefetch_factor=4,
                drop_last=False,
            )
            return (*loaders, loso_dataloader)

        return (*loaders,)

In [None]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, timestep,sfreq, n_nodes=18,batch_size=32):
        super(RecurrentGCN, self).__init__()
        self.n_nodes = n_nodes
        self.out_features = 128
        self.recurrent_1 = GCNConv(sfreq*timestep,32, add_self_loops=True,improved=False)
        self.recurrent_2 = GCNConv(32,64,add_self_loops=True,improved=False)
        self.recurrent_3 = GCNConv(64,128,add_self_loops=True,improved=False)
        self.fc1 = torch.nn.Linear(n_nodes*128, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 16)
        self.fc4 = torch.nn.Linear(16, 1)
        self.flatten = torch.nn.Flatten(start_dim=0)
        self.dropout = torch.nn.Dropout()
    def forward(self, x, edge_index,edge_weight):
        x = torch.squeeze(x)
        h = self.recurrent_1(x, edge_index=edge_index, edge_weight = edge_weight)
        h = torch.nn.BatchNorm1d(32)(h)
        h = F.leaky_relu(h)
        h = self.recurrent_2(h, edge_index,edge_weight)
        h = torch.nn.BatchNorm1d(64)(h)
        h = F.leaky_relu(h)
        h = self.recurrent_3(h, edge_index,edge_weight)
        h = torch.nn.BatchNorm1d(128)(h)
        h = F.leaky_relu(h)
        h = self.flatten(h)
        
        h = self.dropout(h)
        h = self.fc1(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc2(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc3(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc4(h)
        return h

In [None]:
class GATv2(torch.nn.Module):
    def __init__(self, timestep,sfreq, n_nodes=18,batch_size=32):
        super(GATv2, self).__init__()
        self.n_nodes = n_nodes
        self.out_features = 128
        self.recurrent_1 = GATv2Conv(sfreq*timestep,32,heads=6, add_self_loops=True,improved=False,edge_dim=1)

        self.fc1 = torch.nn.Linear(192, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 1)
        self.fc4 = torch.nn.Linear(128, 1)
        self.flatten = torch.nn.Flatten(start_dim=0)
        self.dropout = torch.nn.Dropout()
    def forward(self, x, edge_index,edge_weight, batch):
        h = self.recurrent_1(x, edge_index=edge_index, edge_attr = edge_weight)
        h = torch.nn.BatchNorm1d(192)(h)
        h = F.leaky_relu(h)
        h = global_mean_pool(h,batch)
        h = self.dropout(h)
        h = self.fc1(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc2(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc3(h)
        # h = F.leaky_relu(h)
        # h = self.dropout(h)
        # h = self.fc4(h)
        return h.squeeze()

In [None]:
TIMESTEP = 3
INTER_OVERLAP = 0
ICTAL_OVERLAP = 0
SFREQ = 256

dataloader = SeizureDataLoader(
    npy_dataset_path=Path('data/npy_data'),
    event_tables_path=Path('data/event_tables'),
    plv_values_path=Path('data/plv_arrays'),
    loso_patient='chb16',
    sampling_f=SFREQ,
    seizure_lookback=600,
    sample_timestep= TIMESTEP,
    inter_overlap=INTER_OVERLAP,
    ictal_overlap=ICTAL_OVERLAP,
    self_loops=False,
    balance=False,
  #  train_test_split=0.2,
    fft=False,
    hjorth=True,
    downsample=60,
    batch_size=64
    )
train_loader,loso_loader=dataloader.get_dataset()
alpha = list(dataloader._label_counts.values())[0]/list(dataloader._label_counts.values())[1]


Creating initial attributes
Creating initial attributes


In [None]:
class LightningGATv2(pl.LightningModule):
    def __init__(self, timestep,sfreq,alpha):
        super().__init__()
        self.model = nn.Sequential(
            GATv2Conv(sfreq*timestep,32,heads=6, add_self_loops=True,improved=False),
            nn.BatchNorm1d(192),
            nn.Flatten(start_dim=0),
            nn.Dropout(),
            nn.Linear(3456, 1024),
            nn.Dropout(),
            nn.Linear(1024, 512),
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.Dropout(),
            nn.Linear(128, 1)
        )   
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.alpha = alpha
        self.save_hyperparameters()
        self.recall = BinaryRecall(threshold=0.5)
        self.specificity = BinarySpecificity(threshold=0.6)
        self.auroc = AUROC(task="binary")
    def forward(self, x, edge_index,edge_weight):
        x = torch.squeeze(x)
        return self.model(x, edge_index,edge_weight)
    def training_step(self, batch, batch_idx):
        x, edge_index, edge_weight, y = batch[0:4]
        y_hat = self(x, edge_index,edge_weight)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

In [None]:
## normal loop

device = torch.device("cpu")
model = GATv2(TIMESTEP,60,batch_size=32).to(device)
loss_fn =  nn.BCEWithLogitsLoss(pos_weight=torch.full([1], alpha))
#loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
recall = BinaryRecall(threshold=0.5)
specificity = BinarySpecificity(threshold=0.6)
auroc = AUROC(task="binary")
roc = ROC('binary')
model.train()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=2)

for epoch in tqdm(range(13)):
        try:
                del preds, ground_truth
        except:
                pass
        epoch_loss = 0.0
        epoch_loss_valid = 0.0
        model.train()
        sample_counter = 0
        batch_counter = 0
        print(get_lr(optimizer))
        for time, batch in enumerate(train_loader): ## TODO - this thing is still operating with no edge weights!!!
                ## find a way to compute plv per batch fast (is it even possible?)
        
                x = batch.x
                edge_index = batch.edge_index
                edge_attr = batch.edge_attr.float()
                y = batch.y
                batch_idx = batch.batch
                time_to_seizure = batch.time.float()
                x = x.squeeze()
                signal_samples = x.shape[1]
                x = 2 / signal_samples * torch.abs(x)
                x = (x-x.mean(dim=0))/x.std(dim=0)
                y_hat = model(x, edge_index,edge_attr,batch_idx)
                
           
                # y_hat =torch.stack(
                #         [model(x=x[n], edge_index=edge_index[n], edge_weight=edge_attr[n].float()) 
                #          for n in range(x.shape[0])])
                ##loss
        
                loss = loss_fn(y_hat,y)
                #loss = torchvision.ops.sigmoid_focal_loss(y_hat,y,alpha=alpha*0.1,gamma=2,reduction='mean')
                epoch_loss += loss
                ## get preds & gorund truth
                try:
                 preds = torch.cat([preds,y_hat.detach()],dim=0)
                 ground_truth = torch.cat([ground_truth,y],dim=0)
            
                except:
                 preds= y_hat.detach()
                 ground_truth = y
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        ## calculate acc

        train_auroc = auroc(preds,ground_truth)
        train_sensitivity = recall(preds,ground_truth)
        train_specificity = specificity(preds,ground_truth)
        del preds, ground_truth
        print(f'Epoch: {epoch}', f'Epoch loss: {epoch_loss.detach().numpy()/(time+1)}')
        print(f'Epoch sensitivity: {train_sensitivity}')
        print(f'Epoch specificity: {train_specificity}')
        print(f'Epoch AUROC: {train_auroc} ')
        model.eval()
        with torch.no_grad():
                try:
                        del preds_valid, ground_truth_valid
                except:
                        pass
                for time_valid, batch_valid in enumerate(loso_loader):
                        x = batch_valid.x
                        edge_index = batch_valid.edge_index
                        edge_attr = batch_valid.edge_attr.float()
                        y_val = batch_valid.y
                        batch_idx = batch_valid.batch
                        x = x.squeeze()
                        time_to_seizure_val = batch_valid.time.float()
                        signal_samples = x.shape[1]
                        x = 2 / signal_samples * torch.abs(x)
                        x = (x-x.mean(dim=0))/x.std(dim=0)
                        
                        y_hat_val = model(x, edge_index,edge_attr,batch_idx)
                        loss_valid = loss_fn(y_hat_val,y_val)
                        #loss_valid = torchvision.ops.sigmoid_focal_loss(y_hat,y,alpha=alpha*0.1,gamma=2,reduction='mean')
                        epoch_loss_valid += loss_valid
                        try:
                         preds_valid = torch.cat([preds_valid,y_hat_val],dim=0)
                         ground_truth_valid = torch.cat([ground_truth_valid,y_val],dim=0)
                        except:
                         preds_valid= y_hat_val
                         ground_truth_valid = y_val
        scheduler.step(epoch_loss_valid)
        val_auroc = auroc(preds_valid,ground_truth_valid)
        val_sensitivity = recall(preds_valid,ground_truth_valid)
        val_specificity = specificity(preds_valid,ground_truth_valid)
        del preds_valid, ground_truth_valid
        print(f'Epoch val_loss: {epoch_loss_valid.detach().numpy()/(time_valid+1)}')
        print(f'Epoch val_sensitivity: {val_sensitivity}')
        print(f'Epoch val specificity: {val_specificity}')
        print(f'Epoch val AUROC: {val_auroc} ')

  0%|          | 0/13 [00:00<?, ?it/s]

0.001
Epoch: 0 Epoch loss: 0.8945258950287441
Epoch sensitivity: 0.8144013285636902
Epoch specificity: 0.8548429608345032
Epoch AUROC: 0.8456469774246216 


  8%|▊         | 1/13 [00:24<04:59, 24.93s/it]

Epoch val_loss: 0.5601881146430969
Epoch val_sensitivity: 0.8571428656578064
Epoch val specificity: 0.7856636047363281
Epoch val AUROC: 0.8961438536643982 
0.001
Epoch: 1 Epoch loss: 0.7481328424417748
Epoch sensitivity: 0.844264566898346
Epoch specificity: 0.8718659281730652
Epoch AUROC: 0.8967494368553162 


 15%|█▌        | 2/13 [00:50<04:35, 25.08s/it]

Epoch val_loss: 0.5388560891151428
Epoch val_sensitivity: 0.8571428656578064
Epoch val specificity: 0.8168914318084717
Epoch val AUROC: 0.8875595331192017 
0.001
Epoch: 2 Epoch loss: 0.7110874391951651
Epoch sensitivity: 0.8361707925796509
Epoch specificity: 0.885688841342926
Epoch AUROC: 0.905806303024292 


 23%|██▎       | 3/13 [01:17<04:23, 26.34s/it]

Epoch val_loss: 0.48474112153053284
Epoch val_sensitivity: 0.9047619104385376
Epoch val specificity: 0.8225691914558411
Epoch val AUROC: 0.9212207198143005 
0.001
Epoch: 3 Epoch loss: 0.6817307958063089
Epoch sensitivity: 0.8462182283401489
Epoch specificity: 0.8891198039054871
Epoch AUROC: 0.9135712385177612 


 31%|███       | 4/13 [01:42<03:50, 25.59s/it]

Epoch val_loss: 0.513518750667572
Epoch val_sensitivity: 0.9047619104385376
Epoch val specificity: 0.8147622346878052
Epoch val AUROC: 0.8679577708244324 
0.001
Epoch: 4 Epoch loss: 0.6685941876105542
Epoch sensitivity: 0.8590566515922546
Epoch specificity: 0.8911651968955994
Epoch AUROC: 0.9171880483627319 


 38%|███▊      | 5/13 [02:06<03:21, 25.18s/it]

Epoch val_loss: 0.4794256389141083
Epoch val_sensitivity: 0.8095238208770752
Epoch val specificity: 0.8374733924865723
Epoch val AUROC: 0.876001238822937 
0.001
Epoch: 5 Epoch loss: 0.6265284124410377
Epoch sensitivity: 0.8568238615989685
Epoch specificity: 0.902711808681488
Epoch AUROC: 0.926659345626831 


 46%|████▌     | 6/13 [02:34<03:01, 25.89s/it]

Epoch val_loss: 0.5233948826789856
Epoch val_sensitivity: 0.8571428656578064
Epoch val specificity: 0.8225691914558411
Epoch val AUROC: 0.8948595523834229 
0.001
Epoch: 6 Epoch loss: 0.6135775584094929
Epoch sensitivity: 0.8713368773460388
Epoch specificity: 0.9034375548362732
Epoch AUROC: 0.930705189704895 


 54%|█████▍    | 7/13 [03:01<02:39, 26.52s/it]

Epoch val_loss: 0.4945216476917267
Epoch val_sensitivity: 0.9523809552192688
Epoch val specificity: 0.8289567232131958
Epoch val AUROC: 0.9135151505470276 
0.001
Epoch: 7 Epoch loss: 0.5914310095445166
Epoch sensitivity: 0.8660340309143066
Epoch specificity: 0.9041633605957031
Epoch AUROC: 0.935143232345581 


 62%|██████▏   | 8/13 [03:31<02:16, 27.39s/it]

Epoch val_loss: 0.5303516387939453
Epoch val_sensitivity: 0.8571428656578064
Epoch val specificity: 0.8183108568191528
Epoch val AUROC: 0.8924937844276428 
0.0001
Epoch: 8 Epoch loss: 0.5638482435694281
Epoch sensitivity: 0.8869662284851074
Epoch specificity: 0.9064397215843201
Epoch AUROC: 0.9415299892425537 


 69%|██████▉   | 9/13 [04:00<01:51, 27.94s/it]

Epoch val_loss: 0.4699624180793762
Epoch val_sensitivity: 0.761904776096344
Epoch val specificity: 0.8431511521339417
Epoch val AUROC: 0.8873229026794434 
0.0001
Epoch: 9 Epoch loss: 0.5548006453604069
Epoch sensitivity: 0.8827797770500183
Epoch specificity: 0.910596489906311
Epoch AUROC: 0.9428689479827881 


 77%|███████▋  | 10/13 [04:27<01:23, 27.84s/it]

Epoch val_loss: 0.4627740681171417
Epoch val_sensitivity: 0.9047619104385376
Epoch val specificity: 0.8459900617599487
Epoch val AUROC: 0.8996925354003906 
0.0001
Epoch: 10 Epoch loss: 0.5387183423312205
Epoch sensitivity: 0.8794306516647339
Epoch specificity: 0.915808916091919
Epoch AUROC: 0.9457902908325195 


 85%|████████▍ | 11/13 [04:58<00:57, 28.61s/it]

Epoch val_loss: 0.4848044514656067
Epoch val_sensitivity: 0.8571428656578064
Epoch val specificity: 0.8431511521339417
Epoch val AUROC: 0.8963466286659241 
0.0001
Epoch: 11 Epoch loss: 0.5415818124447229
Epoch sensitivity: 0.8883616924285889
Epoch specificity: 0.9121140241622925
Epoch AUROC: 0.94566810131073 


 92%|█████████▏| 12/13 [05:26<00:28, 28.38s/it]

Epoch val_loss: 0.46443280577659607
Epoch val_sensitivity: 0.8095238208770752
Epoch val specificity: 0.859474778175354
Epoch val AUROC: 0.8764067888259888 
0.0001
Epoch: 12 Epoch loss: 0.5278212637271521
Epoch sensitivity: 0.886687159538269
Epoch specificity: 0.9194708466529846
Epoch AUROC: 0.9480636119842529 


100%|██████████| 13/13 [05:55<00:00, 27.34s/it]

Epoch val_loss: 0.4895547032356262
Epoch val_sensitivity: 0.8571428656578064
Epoch val specificity: 0.853087306022644
Epoch val AUROC: 0.8872553110122681 





In [None]:
model.eval()
epoch_loss_valid = 0.0
with torch.no_grad():
        for time_valid, batch_valid in enumerate([loso_loader]):
                x = batch.x
                edge_index = batch.edge_index
                edge_attr = batch.edge_attr.float()
                y_val = batch.y
                batch_idx = batch.batch
                x = x.squeeze()
                signal_samples = x.shape[1]
                x = 2 / signal_samples * torch.abs(x)
                x = (x-x.mean(dim=0))/x.std(dim=0)
                
                y_hat_val = model(x, edge_index,edge_attr,batch_idx)
                loss_valid = loss_fn(y_hat_val,y_val)
                #loss_valid = torchvision.ops.sigmoid_focal_loss(y_hat,y,alpha=alpha*0.1,gamma=2,reduction='mean')
                epoch_loss_valid += loss_valid
                try:
                    preds_valid = torch.cat([preds_valid,y_hat_val],dim=0)
                    ground_truth_valid = torch.cat([ground_truth_valid,y_val],dim=0)
                except:
                    preds_valid= y_hat_val
                    ground_truth_valid = y_val
#scheduler.step(epoch_loss_valid)
val_auroc = auroc(preds_valid,ground_truth_valid)
val_sensitivity = recall(preds_valid,ground_truth_valid)
val_specificity = specificity(preds_valid,ground_truth_valid)
del preds_valid, ground_truth_valid
print(f'Epoch: {epoch}',f'Epoch val_sensitivity: {val_sensitivity}', f'Epoch val_loss: {epoch_loss_valid.detach().numpy()/(time_valid+1)}')
print(f'Epoch val specificity: {train_specificity}')
print(f'Epoch val AUROC: {val_auroc} ')
del epoch_loss_valid

In [None]:
time_valid+1