In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
from dataclasses import dataclass
import utils
from torch_geometric_temporal import  DynamicGraphTemporalSignal,StaticGraphTemporalSignal, temporal_signal_split, DynamicGraphTemporalSignalBatch
import networkx as nx
from torch_geometric.utils import from_networkx
import scipy
import sklearn
from tqdm import tqdm
from torch_geometric.data import Data
from torch_geometric.loader import DynamicBatchSampler,DataLoader
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.ops import sigmoid_focal_loss
from torch_geometric_temporal.nn.recurrent import DCRNN,  GConvGRU, A3TGCN, TGCN2, TGCN
from torch_geometric_temporal.nn.attention import STConv
from torchmetrics.classification import BinaryRecall, AUROC, ROC
from torch_geometric.nn import global_mean_pool
import timeit
import pytorch_lightning as pl
from torch_geometric.nn import GCNConv,BatchNorm
from sklearn.model_selection import KFold,StratifiedKFold

In [2]:
# TODO think about using kwargs argument here to specify args for dataloader
@dataclass
class SeizureDataLoader:
    npy_dataset_path :Path
    event_tables_path : Path
    plv_values_path : Path
    loso_patient : str = None
    sampling_f : int = 256
    seizure_lookback: int = 600
    sample_timestep: int = 5
    inter_overlap: int = 0
    ictal_overlap: int = 0
    self_loops : bool = True
    balance : bool = True
    train_test_split:  float = None
    
    """Class to prepare dataloaders for eeg seizure perdiction from stored files.

    Attributes:
        npy_dataset_path {Path} -- Path to folder with dataset preprocessed into .npy files.
        event_tables_path {Path} -- Path to folder with .csv files containing seizure events information for every patient.
        loso_patient {str} -- Name of patient to be selected for LOSO valdiation, specified in format "chb{patient_number}"",
        eg. "chb16". (default: {None}).
        samplin_f {int} -- Sampling frequency of the loaded eeg data. (default: {256}).
        seizure_lookback {int} -- Time horizon to sample pre-seizure data (length of period before seizure) in seconds. 
        (default: {600}).
        sample_timestep {int} -- Amounts of seconds analyzed in a single sample. (default: {5}).
        overlap {int} -- Amount of seconds overlap between samples. (default: {0}).
        self_loops {bool} -- Wheather to add self loops to nodes of the graph. (default: {True}).
        shuffle {bool} --  Wheather to shuffle training samples.


    """
    def _get_event_tables(self,patient_name : str) -> tuple[dict,dict]:
        """Read events for given patient into start and stop times lists from .csv extracted files."""

        event_table_list = os.listdir(self.event_tables_path)
        patient_event_tables = [os.path.join(self.event_tables_path,ev_table)
        for ev_table in event_table_list if patient_name in ev_table]
        patient_event_tables = sorted(patient_event_tables)
        patient_start_table = patient_event_tables[0] ## done terribly, but it has to be so for win/linux compat
        patient_stop_table = patient_event_tables[1]
        start_events_dict = pd.read_csv(patient_start_table).to_dict('index')
        stop_events_dict = pd.read_csv(patient_stop_table).to_dict('index')
        return start_events_dict,stop_events_dict
        
    def _get_recording_events(self,events_dict,recording) -> list[int]:
        """Read seizure times into list from event_dict"""
        recording_list = list(events_dict[recording+'.edf'].values())
        recording_events = [int(x) for x in recording_list if not np.isnan(x)]
        return recording_events


    def _get_graph(self,n_nodes: int) -> nx.Graph :
        """Creates Networx fully connected graph with self loops"""
        graph = nx.complete_graph(n_nodes)
        self_loops = [[node,node]for node in graph.nodes()]
        graph.add_edges_from(self_loops)
        return graph
    
    def _get_edge_weights_recording(self,plv_values: np.ndarray) ->np.ndarray:
        """Method that takes plv values for given recording and assigns them 
        as edge attributes to a fc graph."""
        graph = self._get_graph(plv_values.shape[0])
        garph_dict = {}
        for edge in graph.edges():
            e_start,e_end = edge
            garph_dict[edge] = {'plv':plv_values[e_start,e_end]}
        nx.set_edge_attributes(graph, garph_dict)
        edge_weights = from_networkx(graph).plv.numpy()
        return edge_weights
    
    def _get_edges(self):
        """Method to assign edge attributes. Has to be called AFTER get_dataset() method."""
        graph = self._get_graph(self._features.shape[1])
        edges = np.expand_dims(from_networkx(graph).edge_index.numpy(),axis=0)
        edges_per_sample_train = np.repeat(edges,repeats =self._features.shape[0],axis=0)
        self._edges = torch.tensor(edges_per_sample_train)
        if self.loso_patient is not None:
            edges_per_sample_val = np.repeat(edges,repeats =self._val_features.shape[0],axis=0)
            self._val_edges = torch.tensor(edges_per_sample_val)
       
    def _array_to_tensor(self):
        """Method converting features, edges and weights to torch.tensors"""
        self._features = torch.tensor(self._features)
        self._labels = torch.tensor(self._labels)
        self._time_labels = torch.tensor(self._time_labels)
        self._edge_weights = torch.tensor(self._edge_weights)
     
    
    def _val_array_to_tensor(self):
        self._val_features = torch.tensor(self._val_features)
        self._val_labels = torch.tensor(self._val_labels)
        self._val_time_labels = torch.tensor(self._val_time_labels)
        self._val_edge_weights = torch.tensor(self._val_edge_weights)
     
        
    def _get_labels_count(self):
        labels,counts = np.unique(self._labels,return_counts=True)
        self._label_counts = {}
        for n, label in enumerate(labels):
            self._label_counts[int(label)] = counts[n]
        
    def _get_val_labels_count(self):
        labels,counts = np.unique(self._val_labels,return_counts=True)
        self._val_label_counts = {}
        for n, label in enumerate(labels):
            self._val_label_counts[int(label)] = counts[n]
        
    def _balance_classes(self):
        negative_label = self._label_counts[0]
        positive_label = self._label_counts[1]
    
        imbalance = negative_label - positive_label
        negative_indices = np.where(self._labels == 0)[0]
        indices_to_discard = np.random.choice(negative_indices,size = imbalance,replace=False)

        self._features = np.delete(self._features,obj=indices_to_discard,axis=0)
        self._labels = np.delete(self._labels,obj=indices_to_discard,axis=0)
        self._time_labels = np.delete(self._time_labels,obj=indices_to_discard,axis=0)
        self._edge_weights = np.delete(self._edge_weights,obj=indices_to_discard,axis=0)

        
    def _get_labels_features_edge_weights(self):
        """Prepare features, labels, time labels and edge wieghts for training and 
        optionally validation data."""
        patient_list = os.listdir(self.npy_dataset_path)
        for patient in patient_list: # iterate over patient names
            event_tables = self._get_event_tables(patient) # extract start and stop of seizure for patient 
            patient_path = os.path.join(self.npy_dataset_path,patient)
            recording_list = os.listdir(patient_path)
            for record in recording_list: # iterate over recordings for a patient
                recording_path = os.path.join(patient_path,record)
                record_id = record.split('.npy')[0] #  get record id
                start_event_tables = self._get_recording_events(event_tables[0],record_id) # get start events
                stop_event_tables = self._get_recording_events(event_tables[1],record_id) # get stop events
                data_array = np.load(recording_path) # load the recording

                plv_edge_weights = np.expand_dims(
                    self._get_edge_weights_recording(
                        np.load(os.path.join(self.plv_values_path,patient,record))
          
                ),
                axis = 0
                )

                ##TODO add a gateway to reject seizure periods shorter than lookback
                # extract timeseries and labels from the array
                features,labels,time_labels = utils.extract_training_data_and_labels(
                    data_array,
                    start_event_tables,
                    stop_event_tables,
                    fs = self.sampling_f,
                    seizure_lookback = self.seizure_lookback,
                    sample_timestep = self.sample_timestep,
                    inter_overlap = self.inter_overlap,
                    ictal_overlap = self.ictal_overlap
                )
                
                if  features is None:
                    continue
                time_labels = np.expand_dims(time_labels.astype(np.int32),1)
                labels = labels.reshape((labels.shape[0],1)).astype(np.float32)
            
                if patient == self.loso_patient:
                    try:
                        self._val_features = np.concatenate((self._val_features, features))
                        self._val_labels = np.concatenate((self._val_labels, labels))
                        self._val_time_labels = np.concatenate((self._val_time_labels , time_labels))
                        self._val_edge_weights = np.concatenate((
                            self._val_edge_weights,
                            np.repeat(plv_edge_weights,features.shape[0],axis=0)
                            ))
                    except:
                        self._val_features = features
                        self._val_labels = labels
                        self._val_time_labels = time_labels
                        self._val_edge_weights = np.repeat(plv_edge_weights,features.shape[0],axis=0)
                else:
                    try:
                        self._features = np.concatenate((self._features, features))
                        self._labels = np.concatenate((self._labels, labels))
                        self._time_labels = np.concatenate((self._time_labels , time_labels))
                        self._edge_weights = np.concatenate((
                            self._edge_weights,
                            np.repeat(plv_edge_weights,features.shape[0],axis=0)
                        ))
                       
                    except:
                        print("Creating initial attributes")
                        self._features = features
                        self._labels = labels
                        self._time_labels = time_labels
                        self._edge_weights = np.repeat(plv_edge_weights,features.shape[0],axis=0)
                
        
    # TODO define a method to create edges and calculate plv to get weights
    def get_dataset(self) -> DynamicGraphTemporalSignal:

        """Creating graph data iterators. The iterator yelds dynamic, weighted and undirected graphs
        containing self loops. Every node represents one electrode in EEG. The graph is fully connected,
        edge weights are calculated for every EEG recording as PLV between channels (edge weight describes 
        the "strength" of connectivity between two channels in a recording). Node features are values of 
        channel voltages in time. Features are of shape [nodes,features,timesteps].

        Returns:
            train_dataset {DynamicGraphTemporalSignal} -- Training data iterator.
            valid_dataset {DynamicGraphTemporalSignal} -- Validation data iterator (only if loso_patient is
            specified in class constructor).
        """
        ### TODO rozkminić o co chodzi z tym całym time labels - na razie wartość liczbowa która tam wchodzi
        ### to shape atrybutu time_labels
        
        self._get_labels_features_edge_weights()
        if self.balance:
            self._get_labels_count()
            self._balance_classes()
        self._get_edges()
        self._get_labels_count()
        
        self._array_to_tensor()
        
        train_dataset = torch.utils.data.TensorDataset(
        self._features, self._edges, self._edge_weights, self._labels,  self._time_labels
        )
        if self.train_test_split is not None:
            train_dataset, val_dataset = torch.utils.data.random_split(
                train_dataset,[1-self.train_test_split,self.train_test_split]
            )
            
            train_dataloader = torch.utils.data.DataLoader(
                train_dataset, batch_size = 16,shuffle = True,num_workers=2,pin_memory = True, prefetch_factor = 30,
                drop_last=True
            )
            
            val_dataloader = torch.utils.data.DataLoader(
                val_dataset, batch_size = 16,shuffle = False,num_workers=2,pin_memory = True, prefetch_factor = 30,
                drop_last=True
            )
            loaders = [train_dataloader,val_dataloader]
        else:
            train_dataloader = torch.utils.data.DataLoader(
                train_dataset, batch_size = 16,shuffle = True,num_workers=2,pin_memory = True, prefetch_factor = 30,
                drop_last=True
            )
            loaders = [train_dataloader]
        if self.loso_patient:
            self._get_val_labels_count()
            self._val_array_to_tensor()
            loso_dataset = torch.utils.data.TensorDataset(
                self._val_features, self._val_edges, self._val_edge_weights, self._val_labels, self._val_time_labels
            )
            loso_dataloader = torch.utils.data.DataLoader(
                loso_dataset, batch_size = 16,shuffle = False,pin_memory = True,num_workers=2, prefetch_factor = 30,
                drop_last=True
            )
            return (*loaders,loso_dataloader)

        return (*loaders,)
        
                

        

In [3]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, timestep,sfreq, n_nodes=18,batch_size=32):
        super(RecurrentGCN, self).__init__()
        self.n_nodes = n_nodes
        self.out_features = 128
        self.recurrent_1 = TGCN(timestep*sfreq,32, add_self_loops=True,improved=True)
        self.recurrent_2 = TGCN(32,64,add_self_loops=True,improved=True)
        self.recurrent_3 = TGCN(64,128,add_self_loops=True,improved=True)
        self.fc1 = torch.nn.Linear(n_nodes*128, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 16)
        self.fc4 = torch.nn.Linear(16, 1)
        self.flatten = torch.nn.Flatten(start_dim=0)
        self.dropout = torch.nn.Dropout()
    def forward(self, x, edge_index,edge_weight):
        x = torch.squeeze(x)
  
       # x = F.normalize(x,dim=1)
        h = self.recurrent_1(x, edge_index=edge_index, edge_weight = edge_weight)
        
      #  h = torch.nn.BatchNorm1d(18)(h)
        h = F.leaky_relu(h)
        h = self.recurrent_2(h, edge_index,edge_weight)
      #  h = torch.nn.BatchNorm1d(18)(h)
        h = F.leaky_relu(h)
        h = self.recurrent_3(h, edge_index,edge_weight)
      #  h = torch.nn.BatchNorm1d(18)(h)
        h = F.leaky_relu(h)
        h = self.flatten(h)
        #print(h.shape)
        #h = global_mean_pool(h,torch.zeros(self.n_nodes,dtype=torch.int64)).squeeze()
        
        h = self.dropout(h)
        h = self.fc1(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc2(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc3(h)
        h = F.leaky_relu(h)
        h = self.dropout(h)
        h = self.fc4(h)
        return h

In [4]:
TIMESTEP = 10
INTER_OVERLAP = 0
ICTAL_OVERLAP = 0
SFREQ = 256

dataloader = SeizureDataLoader(
    npy_dataset_path=Path('data/npy_data'),
    event_tables_path=Path('data/event_tables'),
    plv_values_path=Path('data/plv_arrays'),
    loso_patient='chb16',
    sampling_f=SFREQ,
    seizure_lookback=600,
    sample_timestep= TIMESTEP,
    inter_overlap=INTER_OVERLAP,
    ictal_overlap=ICTAL_OVERLAP,
    self_loops=False,
    balance=False,
    train_test_split=0.2
    
    )
train_loader,val_loader,loso_dataloader=dataloader.get_dataset()
del dataloader

Len Seizure features: 4
Creating initial features for this recording at iteration 0
Len Seizure features: 4
Entering here
<class 'numpy.ndarray'>
Creating initial attributes
Len Seizure features: 4
Creating initial features for this recording at iteration 0
<class 'numpy.ndarray'>
Len Seizure features: 4
Creating initial features for this recording at iteration 0
<class 'numpy.ndarray'>
Len Seizure features: 4
Creating initial features for this recording at iteration 0
<class 'numpy.ndarray'>
Len Seizure features: 4
Creating initial features for this recording at iteration 0
<class 'numpy.ndarray'>
Len Seizure features: 4
Creating initial features for this recording at iteration 0
Len Seizure features: 4
Entering here
Len Seizure features: 4
Entering here
<class 'numpy.ndarray'>
Len Seizure features: 4
Creating initial features for this recording at iteration 0
<class 'numpy.ndarray'>
Len Seizure features: 4
Creating initial features for this recording at iteration 0
<class 'numpy.ndar

In [None]:
from sklearn import metrics
def get_accuracy(y_true, y_prob):
    """Binary accuracy calculation"""
    y_prob = np.array(y_prob)
    y_prob = np.where(y_prob <= 0.0, 0, y_prob)
    y_prob = np.where(y_prob > 0.0, 1, y_prob)

    accuracy = metrics.accuracy_score(y_true, y_prob)
    return accuracy

In [None]:
## TODO 
## 1. Implement kfold cross validation .get_dataset() will return also a kfold cross validation split?
## 2. Think of time series augmentation techniques
## 3. Try to add a feature balancing between ictal and iterictal periods.
## 4.Try to run the algorithm on list of patients shown in the articles.
## 5. !Why there is no patient 15 and 20 in the dataset? !!

In [None]:
## kfold loop
from torch.utils.data import SubsetRandomSampler
k=5
splits=KFold(n_splits=k,shuffle=True,random_state=42)
# mean = dataloader._features.squeeze().mean(dim=0)
# std = dataloader._features.squeeze().std(dim=0)

for fold, (train_idx,val_idx) in enumerate(splits.split(np.arange(len(train_loader)))):
        device = torch.device("cpu")
        model = RecurrentGCN(TIMESTEP,SFREQ,batch_size=16).to(device)
        #pos_weight=torch.full([1], 1.1
        loss_fn =  nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 8.64))
        scaler = sklearn.preprocessing.StandardScaler()
        optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
        recall = BinaryRecall(threshold=0.5)
        auroc = AUROC(task="binary")
        roc = ROC('binary')
        model.train()
        
        
        train_sampler = SubsetRandomSampler(train_idx)
        test_sampler = SubsetRandomSampler(val_idx)
        train_loader_fold = DataLoader(train_loader.dataset, batch_size=16, sampler=train_sampler,drop_last = True)
        test_loader_fold = DataLoader(train_loader.dataset, batch_size=16, sampler=test_sampler,drop_last = True)
        print(f'Fold {fold+1}')
        for epoch in tqdm(range(30)):

                preds = []
                ground_truth = []
                epoch_loss = 0.0
                epoch_loss_valid = 0.0
                preds_valid = []
                ground_truth_valid = []
                model.train()
                sample_counter = 0
                batch_counter = 0
                for time, batch in enumerate(train_loader_fold):
                        x, edge_index, edge_attr,y = batch[0:4]
                        x = x.squeeze()
                        #mean = torch.mean(x,dim=0)
                        #std = torch.std(x,dim=0)
                        x = (x-mean)/std
                        y_hat = model(x=x.float(), edge_index=edge_index[0])
                        
                        ##loss
                        loss = loss_fn(y_hat,y)
                        
                        epoch_loss += loss
                        ## get preds & gorund truth
                        preds.append(y_hat.detach().numpy())
                        ground_truth.append(y.numpy())
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                ## calculate acc

                train_auroc = auroc(torch.FloatTensor(preds),torch.FloatTensor(ground_truth))
                sensitivity = recall(torch.FloatTensor(preds),torch.FloatTensor(ground_truth))
                print(f'Epoch: {epoch}',f'Epoch sensitivity: {sensitivity}', f'Epoch loss: {epoch_loss.detach().numpy()/time+1}')
                print(f'Epoch AUROC: {train_auroc} ')
                model.eval()
                with torch.no_grad():
                        for time_valid, batch_valid in enumerate(test_loader_fold):
                                x, edge_index, edge_attr,y_val = batch_valid[0:4]
                                #print(y_val)
                                x = x.squeeze()
                                #mean = torch.mean(x,dim=0)
                                #std = torch.std(x,dim=0)
                                x = (x-mean)/std
                                y_hat_val = model(x=x.float(), edge_index=edge_index[0])
                                loss_valid = loss_fn(y_hat_val,y_val)

                                #loss_valid = sigmoid_focal_loss(y_hat,snapshot.y,alpha=0.8,gamma=1).squeeze()

                                epoch_loss_valid += loss_valid
                                preds_valid.append(y_hat_val.detach().numpy())
                                ground_truth_valid.append(y_val.numpy())

                val_auroc = auroc(torch.FloatTensor(preds_valid),torch.FloatTensor(ground_truth_valid))
                val_sensitivity = recall(torch.FloatTensor(preds_valid),torch.FloatTensor(ground_truth_valid))
                print(f'Epoch: {epoch}',f'Epoch val_sensitivity: {val_sensitivity}', f'Epoch val_loss: {epoch_loss_valid.detach().numpy()/time_valid+1}')
                print(f'Epoch val AUROC: {val_auroc} ')

In [5]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [6]:
## normal loop
device = torch.device("cpu")
model = RecurrentGCN(TIMESTEP,SFREQ,batch_size=16).to(device)
#pos_weight=torch.full([1], 1.1
loss_fn =  nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 8.65))
optimizer = torch.optim.Adam(model.parameters(),lr=0.0001)
recall = BinaryRecall(threshold=0.5)
auroc = AUROC(task="binary")
roc = ROC('binary')
model.train()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=2)

for epoch in tqdm(range(20)):

        preds = []
        ground_truth = []
        epoch_loss = 0.0
        epoch_loss_valid = 0.0
        preds_valid = []
        ground_truth_valid = []
        model.train()
        sample_counter = 0
        batch_counter = 0
        print(get_lr(optimizer))
        for time, batch in enumerate(train_loader): ## TODO - this thing is still operating with no edge weights!!!
                ## find a way to compute plv per batch fast (is it even possible?)
        
                x, edge_index, edge_attr,y = batch[0:4]
                
                signal_samples = x.shape[3]
                x = 2 / signal_samples * torch.abs(torch.fft.fft(x))
                
                x = x.squeeze()
                x = (x-x.mean(dim=0))/x.std(dim=0)
                # mean = torch.mean(x,dim=0)
                # std = torch.std(x,dim=0)
                #x = (x-mean)/std
           
                y_hat =torch.stack(
                        [model(x=x[n].float(), edge_index=edge_index[n], edge_weight=edge_attr[n].float()) 
                         for n in range(x.shape[0])])
                ##loss
        
                loss = loss_fn(y_hat,y)
                
                epoch_loss += loss
                ## get preds & gorund truth
                preds.append(y_hat.detach())
                ground_truth.append(y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        ## calculate acc

        train_auroc = auroc(torch.stack(preds),torch.stack(ground_truth))
        sensitivity = recall(torch.stack(preds),torch.stack(ground_truth))
        print(f'Epoch: {epoch}',f'Epoch sensitivity: {sensitivity}', f'Epoch loss: {epoch_loss.detach().numpy()/time+1}')
        print(f'Epoch AUROC: {train_auroc} ')
        model.eval()
        with torch.no_grad():
                for time_valid, batch_valid in enumerate(val_loader):
                        x, edge_index, edge_attr,y_val = batch_valid[0:4]
                        signal_samples = x.shape[3]
                        x = 2 / signal_samples * torch.abs(torch.fft.fft(x))
                        x = x.squeeze()
                        # mean = torch.mean(x,dim=0)
                        # std = torch.std(x,dim=0)
                        x = (x-x.mean(dim=0))/x.std(dim=0)
                        #x = (x-mean)/std
                        y_hat_val = torch.stack(
                                [model(x=x[n].float(), edge_index=edge_index[n], edge_weight=edge_attr[n].float()) 
                                 for n in range(x.shape[0])])
                        loss_valid = loss_fn(y_hat_val,y_val)

                        #loss_valid = sigmoid_focal_loss(y_hat,snapshot.y,alpha=0.8,gamma=1).squeeze()

                        epoch_loss_valid += loss_valid
                        preds_valid.append(y_hat_val.detach())
                        ground_truth_valid.append(y_val)
        scheduler.step(epoch_loss_valid)
        val_auroc = auroc(torch.stack(preds_valid),torch.stack(ground_truth_valid))
        val_sensitivity = recall(torch.stack(preds_valid),torch.stack(ground_truth_valid))
        print(f'Epoch: {epoch}',f'Epoch val_sensitivity: {val_sensitivity}', f'Epoch val_loss: {epoch_loss_valid.detach().numpy()/time_valid+1}')
        print(f'Epoch val AUROC: {val_auroc} ')

  0%|          | 0/20 [00:00<?, ?it/s]

0.0001
Epoch: 0 Epoch sensitivity: 0.07099143415689468 Epoch loss: 2.203608827007828
Epoch AUROC: 0.6201717853546143 


  5%|▌         | 1/20 [01:14<23:35, 74.48s/it]

Epoch: 0 Epoch val_sensitivity: 0.6634615659713745 Epoch val_loss: 2.149944850376674
Epoch val AUROC: 0.7853262424468994 
0.0001
Epoch: 1 Epoch sensitivity: 0.5544675588607788 Epoch loss: 2.1143140322593075
Epoch AUROC: 0.720425546169281 


 10%|█         | 2/20 [02:33<23:11, 77.30s/it]

Epoch: 1 Epoch val_sensitivity: 0.7115384340286255 Epoch val_loss: 2.0353512234157987
Epoch val AUROC: 0.8009394407272339 
0.0001
Epoch: 2 Epoch sensitivity: 0.665036678314209 Epoch loss: 2.0462087898329635
Epoch AUROC: 0.7677388787269592 


 15%|█▌        | 3/20 [03:52<22:08, 78.12s/it]

Epoch: 2 Epoch val_sensitivity: 0.754807710647583 Epoch val_loss: 1.9467063782707092
Epoch val AUROC: 0.8257949352264404 
0.0001
Epoch: 3 Epoch sensitivity: 0.6719706058502197 Epoch loss: 2.007756220048231
Epoch AUROC: 0.7935322523117065 


 20%|██        | 4/20 [05:12<20:57, 78.60s/it]

Epoch: 3 Epoch val_sensitivity: 0.8990384340286255 Epoch val_loss: 1.9654719034830728
Epoch val AUROC: 0.8395722508430481 
0.0001
Epoch: 4 Epoch sensitivity: 0.702570378780365 Epoch loss: 1.990576508952786
Epoch AUROC: 0.8030022978782654 


 25%|██▌       | 5/20 [06:32<19:48, 79.23s/it]

Epoch: 4 Epoch val_sensitivity: 0.8653846383094788 Epoch val_loss: 2.0219748360770087
Epoch val AUROC: 0.8186941146850586 
0.0001


 25%|██▌       | 5/20 [07:19<21:58, 87.89s/it]


KeyboardInterrupt: 

In [None]:
model.eval()
preds_test = []
ground_truth_test = []
loss_test = 0.0
with torch.no_grad():
        for time_test, batch_test in enumerate(loso_dataloader):
                x, edge_index, edge_attr,y_test = batch_test[0:4]
                signal_samples = x.shape[3]
                x = 2 / signal_samples * torch.abs(torch.fft.fft(x))
                x = x.squeeze()
                # mean = torch.mean(x,dim=0)
                # std = torch.std(x,dim=0)
                x = (x-x.mean(dim=0))/x.std(dim=0)
                y_hat_test = torch.stack(
                                [model(x=x[n].float(), edge_index=edge_index[n], edge_weight=edge_attr[n].float()) 
                                 for n in range(x.shape[0])])
                loss_test = loss_fn(y_hat_test,y_test)

                #loss_valid = sigmoid_focal_loss(y_hat,snapshot.y,alpha=0.8,gamma=1).squeeze()

                loss_test+= loss_test
                preds_test.append(y_hat_val.detach())
                ground_truth_test.append(y_val)

test_auroc = auroc(torch.stack(preds_test),torch.stack(ground_truth_test))
test_sensitivity = recall(torch.stack(preds_test),torch.stack(ground_truth_test))
print(f'Test sensitivity: {test_sensitivity}', f'Test loss: {loss_test.detach().numpy()/time_test+1}')
print(f'Epoch val AUROC: {test_auroc} ')

In [None]:
fpr, tpr, thresholds = roc(torch.stack(preds_valid),torch.stack(ground_truth_valid))

In [None]:
from matplotlib import pyplot as plt
plt.plot(fpr,tpr)

In [None]:
np.equal(np.where(np.array(ground_truth) == 1)[0],np.where(np.array(preds) >0)[0] )