In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
from dataclasses import dataclass
import utils
from torch_geometric_temporal import  DynamicGraphTemporalSignal,StaticGraphTemporalSignal
import networkx as nx
from torch_geometric.utils import from_networkx
import scipy
import sklearn
import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN
import torch_geometric.transforms as T

In [None]:
def plv_connectivity(sensors,data):
    """
    Compute the phase locking value (PLV) between every pair of channels.

    The instantaneous phase of each channel is taken from its analytic
    signal (Hilbert transform along the last axis).  For channels ``i`` and
    ``k`` the PLV is ``|mean_t exp(1j * (phase_i[t] - phase_k[t]))|``, a value
    in ``[0, 1]`` where 1 means a perfectly constant phase difference.

    Parameters
    ----------
    sensors : int
        Number of sensors to include; the first ``sensors`` rows of ``data``
        are used.
    data : 2-D array of float, shape (channels, samples)
        EEG data.

    Returns
    -------
    connectivity_matrix : ndarray of float, shape (sensors, sensors)
        Symmetric PLV connectivity matrix with ones on the diagonal.

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional.
    IndexError
        If ``sensors`` exceeds the number of rows in ``data``.
    """
    # Imported here because a bare ``import scipy`` does not guarantee that
    # the ``scipy.signal`` submodule is loaded on every SciPy version.
    from scipy.signal import hilbert

    print("PLV in process.....")

    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"data must be 2-D (channels, samples), got shape {data.shape}")
    if sensors > data.shape[0]:
        raise IndexError(f"sensors={sensors} exceeds the {data.shape[0]} channels in data")

    data_points = data.shape[-1]

    # np.angle is the four-quadrant phase in (-pi, pi].  arctan(imag / real)
    # would fold it into (-pi/2, pi/2) and divide by zero on zero samples,
    # which under-estimates the PLV of signals locked at a non-zero lag.
    phase = np.angle(hilbert(data))

    # Row i holds exp(1j * phase_i[t]); the product with the conjugate
    # transpose sums exp(1j * (phase_i - phase_k)) over time for all pairs.
    unit_phasors = np.exp(1j * phase[:sensors])
    connectivity_matrix = np.abs(unit_phasors @ unit_phasors.conj().T) / data_points

    print("PLV done!")
    return connectivity_matrix

In [None]:

@dataclass
class SeizureDataLoader:
    """Builds train / validation temporal graph datasets from EEG recordings.

    ``npy_dataset_path`` is expected to contain one directory per patient,
    each holding ``<record_id>.npy`` recordings.  ``event_tables_path`` holds
    two CSV tables per patient (seizure start and stop events) whose file
    names contain the patient directory name.

    Attributes:
        npy_dataset_path: Root directory of the per-patient ``.npy`` recordings.
        event_tables_path: Directory with the seizure event CSV tables.
        loso_patient: Patient directory name held out as the validation set
            (leave-one-subject-out).  ``None`` means no validation set.
        sampling_f: Sampling frequency forwarded to
            ``utils.extract_training_data_and_labels`` as ``fs``.
        seizure_lookback: Forwarded to ``utils.extract_training_data_and_labels``
            (presumably seconds before seizure onset - confirm in ``utils``).
        sample_timestep: Forwarded to ``utils.extract_training_data_and_labels``
            (presumably the window length in seconds - confirm in ``utils``).
        overlap: Forwarded to ``utils.extract_training_data_and_labels``.
        self_loops: Whether the fully connected graph also gets one self loop
            per node.
        shuffle: Whether the training samples are shuffled after loading.
    """
    npy_dataset_path: str
    event_tables_path: str
    loso_patient: str = None
    sampling_f: int = 256
    seizure_lookback: int = 600
    sample_timestep: int = 5
    overlap: int = 0
    # No trailing comma here: `True,` would make the default the tuple (True,).
    self_loops: bool = True
    # Annotated so that it is a real dataclass field settable via __init__.
    shuffle: bool = True

    def _get_event_tables(self, patient_name):
        """Load the (start, stop) seizure event tables of one patient.

        Raises:
            ValueError: if the patient does not have exactly two event tables.
        """
        patient_tables = [
            os.path.join(self.event_tables_path, table_name)
            for table_name in os.listdir(self.event_tables_path)
            if patient_name in table_name
        ]
        if len(patient_tables) != 2:
            raise ValueError(
                f"Expected exactly 2 event tables (start, stop) for patient "
                f"{patient_name!r}, found {len(patient_tables)}"
            )
        # NOTE(review): which table is "start" and which is "stop" relies on the
        # order returned by os.listdir, which is arbitrary - consider selecting
        # the tables by file name instead.
        patient_start_table, patient_stop_table = patient_tables
        start_events_dict = self._load_csv_table_events(patient_start_table)
        stop_events_dict = self._load_csv_table_events(patient_stop_table)
        return start_events_dict, stop_events_dict

    def _get_recording_events(self, events_dict, recording):
        """Return the non-NaN event values of ``<recording>.edf`` as a list of ints."""
        recording_list = list(events_dict[recording + '.edf'].values())
        return [int(x) for x in recording_list if not np.isnan(x)]

    def _load_csv_table_events(self, table_path):
        """Read an event CSV into a ``{row_index: {column: value}}`` dictionary."""
        return pd.read_csv(table_path).to_dict('index')

    def _get_graph(self, n_nodes):
        """Create a fully connected networkx graph, with self loops if enabled."""
        graph = nx.complete_graph(n_nodes)
        if self.self_loops:
            graph.add_edges_from([(node, node) for node in graph.nodes()])
        return graph

    def _get_edge_weights_recording(self, plv_values):
        """Assign PLV values as edge attributes of a fully connected graph and
        return them as a flat array in the edge order of ``from_networkx``."""
        graph = self._get_graph(plv_values.shape[0])
        edge_attributes = {
            (e_start, e_end): {'plv': plv_values[e_start, e_end]}
            for e_start, e_end in graph.edges()
        }
        nx.set_edge_attributes(graph, edge_attributes)
        return from_networkx(graph).plv.numpy()

    def _get_edges(self):
        """Build the per-sample edge index arrays.

        Has to be called AFTER ``_get_labels_and_features()``, because the number
        of nodes and samples is read from the loaded features.
        """
        graph = self._get_graph(self._features.shape[1])
        edges = np.expand_dims(from_networkx(graph).edge_index.numpy(), axis=0)
        self._edges = np.repeat(edges, repeats=self._features.shape[0], axis=0)
        val_features = getattr(self, '_val_features', None)
        if val_features is None:
            self._val_edges = None
        else:
            self._val_edges = np.repeat(edges, repeats=val_features.shape[0], axis=0)

    @staticmethod
    def _stack_chunks(chunks):
        """Concatenate a list of per-recording tuples field by field along axis 0."""
        return tuple(np.concatenate(parts) for parts in zip(*chunks))

    def _get_labels_and_features(self):
        """Load every recording and populate the train / validation arrays.

        Sets ``_features``, ``_labels``, ``_time_labels`` and ``_edge_weights``,
        plus their ``_val_*`` counterparts (``None`` when no recording belongs
        to ``loso_patient``).

        Raises:
            ValueError: if no training recording was found.
        """
        # Chunks are collected in lists and concatenated once.  This replaces
        # the former try / bare-except initialisation, which silently threw away
        # everything loaded so far on ANY error (e.g. a shape mismatch) and
        # duplicated the data when the method was run twice.
        train_chunks, val_chunks = [], []

        for patient in os.listdir(self.npy_dataset_path):  # iterate over patient names
            start_events, stop_events = self._get_event_tables(patient)
            patient_path = os.path.join(self.npy_dataset_path, patient)
            for record in os.listdir(patient_path):  # iterate over recordings of a patient
                record_id = record.split('.npy')[0]
                start_event_tables = self._get_recording_events(start_events, record_id)
                stop_event_tables = self._get_recording_events(stop_events, record_id)
                data_array = np.load(os.path.join(patient_path, record))

                # NOTE(review): placeholder - edge weights are random values for a
                # hard-coded 18-node graph instead of the real PLV matrix, i.e.
                # plv_connectivity(data_array.shape[0], data_array).
                plv_values = np.random.uniform(size=(18, 18))
                recording_edge_weights = self._get_edge_weights_recording(plv_values)

                ##TODO add a gateway to reject seizure periods shorter than lookback
                # extract timeseries and labels from the array
                features, labels, time_labels = utils.extract_training_data_and_labels(
                    data_array,
                    start_event_tables,
                    stop_event_tables,
                    fs=self.sampling_f,
                    seizure_lookback=self.seizure_lookback,
                    sample_timestep=self.sample_timestep,
                    overlap=self.overlap,
                )

                time_labels = time_labels.astype(np.int32)
                labels = labels.reshape((labels.shape[0], 1)).astype(np.float32)
                # every sample of a recording shares that recording's edge weights
                edge_weights = np.repeat(
                    np.expand_dims(recording_edge_weights, axis=0),
                    features.shape[0],
                    axis=0,
                )

                target = val_chunks if patient == self.loso_patient else train_chunks
                target.append((features, labels, time_labels, edge_weights))

        if not train_chunks:
            raise ValueError(
                f"No training recordings found under {self.npy_dataset_path!r}"
            )

        (self._features, self._labels,
         self._time_labels, self._edge_weights) = self._stack_chunks(train_chunks)

        if val_chunks:
            (self._val_features, self._val_labels,
             self._val_time_labels, self._val_edge_weights) = self._stack_chunks(val_chunks)
        else:
            self._val_features = self._val_labels = None
            self._val_time_labels = self._val_edge_weights = None

        if self.shuffle:
            # one joint permutation keeps the four arrays aligned
            (self._features, self._labels,
             self._time_labels, self._edge_weights) = sklearn.utils.shuffle(
                self._features, self._labels, self._time_labels, self._edge_weights
            )

    # TODO define a method to create edges and calculate plv to get weights
    def get_dataset(self) -> "tuple[DynamicGraphTemporalSignal, DynamicGraphTemporalSignal | None]":
        """Load the data and wrap it into temporal graph signal iterators.

        Every snapshot is a fully connected graph over the recording channels
        carrying one feature window, its label and its edge weights.

        Returns:
            ``(train_dataset, val_dataset)``.  ``val_dataset`` holds the samples
            of ``loso_patient`` and is ``None`` when no recording belongs to
            that patient (e.g. ``loso_patient`` is ``None``).
        """
        ### TODO figure out what the time labels are for - for now the value
        ### passed in is simply the time_labels attribute
        self._get_labels_and_features()
        self._get_edges()

        train_dataset = DynamicGraphTemporalSignal(
            self._edges, self._edge_weights, self._features, self._labels,
            time_labels=self._time_labels,
        )

        val_dataset = None
        if self._val_features is not None:
            val_dataset = DynamicGraphTemporalSignal(
                self._val_edges, self._val_edge_weights, self._val_features,
                self._val_labels, time_labels=self._val_time_labels,
            )

        return train_dataset, val_dataset
                

        

In [None]:
# Hold out patient 'chb16' as the validation subject.
dataloader = SeizureDataLoader(
    npy_dataset_path=Path('npy_data'),
    event_tables_path=Path('event_tables'),
    loso_patient='chb16',
)

In [None]:
# get_dataset() returns a (train, validation) pair of DynamicGraphTemporalSignal
# objects; despite the *_loader names these are iterable datasets of graph
# snapshots, not torch DataLoaders.
train_loader,val_loader=dataloader.get_dataset()

In [None]:
# Pick the second training snapshot for inspection.
snap = train_loader[1]

In [None]:
# Shape of the feature tensor carried by the inspected snapshot.
snap.x.shape

In [None]:
# Preview of the snapshot features normalised along dim 2 (F.normalize uses the
# L2 norm by default). The result is only displayed; `snap` is not modified.
# NOTE(review): the model normalises along dim=1 after squeezing - confirm which axis is intended.
F.normalize(snap.x,dim=2)

In [None]:
# First label entry of every training snapshot.
targets = [
    graph_snapshot.y.numpy()[0]
    for graph_snapshot in train_loader
]

In [None]:
# Distinct label values of the training snapshots and how often each occurs.
np.unique(targets,return_counts=True)

In [None]:
# Print every validation snapshot (one line of output per snapshot).
for snapshot in val_loader:
    print(snapshot)

In [None]:
class RecurrentGCN(torch.nn.Module):
    """Single DCRNN layer followed by a linear read-out over all nodes.

    The model emits one raw logit per graph snapshot (no sigmoid is applied),
    so it is meant to be trained with a logits-based loss.

    Args:
        node_features: Number of input features per node.
        n_nodes: Number of graph nodes; fixes the input size of the read-out
            layer (``out_features * n_nodes``).
        out_features: Hidden size produced by the DCRNN layer for every node.
        filter_size: Filter size ``K`` passed to the DCRNN layer.
    """

    def __init__(self, node_features, n_nodes=18, out_features=32, filter_size=1):
        super().__init__()
        self.recurrent_1 = DCRNN(node_features, out_features, filter_size)
        self.fc1 = torch.nn.Linear(out_features * n_nodes, 1)
        # start_dim=0 flattens the whole (nodes, out_features) activation into a
        # single vector, i.e. one snapshot is processed per forward call.
        self.flatten = torch.nn.Flatten(start_dim=0)

    def forward(self, x, edge_index, edge_weight):
        """Return the logit (tensor of shape ``(1,)``) for one snapshot.

        ``x`` is squeezed to drop singleton dimensions and is expected to be
        ``(n_nodes, node_features)`` afterwards - assumption, confirm against
        the dataset.
        """
        x = torch.squeeze(x)
        # L2-normalise the feature vector of every node
        x = F.normalize(x, dim=1)
        h = self.recurrent_1(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.flatten(h)
        return self.fc1(h)

In [None]:
# Device the model is moved to.
# NOTE(review): the snapshots fed to the model are never moved to `device`, so
# switching this to CUDA would also require moving the data.
device = torch.device("cpu")

In [None]:
from sklearn import metrics
def get_accuracy(y_true, y_prob):
    """Binary accuracy of thresholded scores.

    Despite its name, ``y_prob`` is thresholded at 0.0 (values above zero
    become class 1, values at or below zero become class 0), which matches
    raw logits rather than probabilities.

    Args:
        y_true: Ground-truth labels, anything ``accuracy_score`` accepts.
        y_prob: Scores to threshold; converted with ``np.array``.

    Returns:
        The accuracy computed by ``sklearn.metrics.accuracy_score``.
    """
    scores = np.array(y_prob)
    hard_labels = np.where(
        scores > 0.0,
        1,
        np.where(scores <= 0.0, 0, scores),
    )
    return metrics.accuracy_score(y_true, hard_labels)

In [None]:
from tqdm import tqdm
# 1280 input features per node: presumably one 5 s window sampled at 256 Hz
# (sample_timestep * sampling_f of the data loader) - confirm.
model = RecurrentGCN(1280).to(device)

# Positive samples are weighted 15x in the loss. A float tensor is used
# because torch.full([1], 15) would create an integer tensor.
loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([15.0], device=device))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def run_epoch(model, dataset, loss_fn, optimizer=None):
    """Run one pass over ``dataset`` and return ``(total_loss, logits, targets)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every snapshot. Without one the model is put in eval mode and gradient
    tracking is disabled.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        dataset: Iterable of snapshots exposing ``x``, ``edge_index``,
            ``edge_attr`` and ``y``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        optimizer: Optimizer used for the update step, or ``None`` to evaluate only.

    Returns:
        total_loss: Sum of the per-snapshot losses as a Python float.
        logits: List with the model output of every snapshot as a numpy array.
        targets: List with the label of every snapshot as a numpy array.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    logits = []
    targets = []
    with torch.set_grad_enabled(training):
        for snapshot in dataset:
            y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
            loss = loss_fn(y_hat, snapshot.y)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() stores a plain float; summing the loss tensors themselves
            # would keep the autograd graph of every step alive for the epoch.
            total_loss += loss.item()
            logits.append(y_hat.detach().numpy())
            targets.append(snapshot.y.numpy())
    return total_loss, logits, targets


for epoch in tqdm(range(5)):
    epoch_loss, preds, ground_truth = run_epoch(model, train_loader, loss_fn, optimizer)
    acc = get_accuracy(ground_truth, preds)
    print(f'Epoch: {epoch}', f'Epoch accuracy: {acc}', f'Epoch loss: {epoch_loss}')

    epoch_loss_valid, preds_valid, ground_truth_valid = run_epoch(model, val_loader, loss_fn)
    acc_valid = get_accuracy(ground_truth_valid, preds_valid)
    print(f'Epoch: {epoch}', f'Epoch val_accuracy: {acc_valid}', f'Epoch loss: {epoch_loss_valid}')

In [None]:
# Indices of the training snapshots (last epoch) whose label equals 1.
np.where(np.array(ground_truth) == 1)[0] 

In [None]:
# For every positive training snapshot (label == 1), report whether it was also
# predicted positive (logit > 0). Comparing the two index arrays element-wise
# with np.equal fails to broadcast whenever the number of predicted positives
# differs from the number of true positives; a membership test does not.
true_positive_idx = np.where(np.array(ground_truth) == 1)[0]
predicted_positive_idx = np.where(np.array(preds) > 0)[0]
np.isin(true_positive_idx, predicted_positive_idx)

In [None]:
# Indices of the training snapshots (last epoch) whose model output is above 0,
# i.e. the ones get_accuracy counts as class 1.
np.where(np.array(preds) >0)[0]