In [24]:
import pandas as pd
import numpy as np
from pathlib import Path
import re
import os
from dataclasses import dataclass
import utils
from torch_geometric_temporal import StaticGraphTemporalSignal, DynamicGraphTemporalSignal
import networkx as nx
from torch_geometric.utils import from_networkx,to_networkx
import scipy
import sklearn

In [2]:
def plv_connectivity(sensors,data):
    """
    Compute the phase-locking value (PLV) between every pair of sensors.

    Parameters
    ----------
    sensors : int
        Number of sensors (rows of ``data``) to include in the matrix.
    data : array of float, shape (n_sensors, n_samples)
        EEG data, one row per sensor and one column per time sample.

    Returns
    -------
    connectivity_matrix : ndarray of float, shape (sensors, sensors)
        PLV connectivity matrix. Entry ``[i, k]`` is
        ``|sum_t exp(1j * (phase_i[t] - phase_k[t]))| / n_samples``.

    Raises
    ------
    ValueError
        If ``data`` is not 2-D or has fewer than ``sensors`` rows.
    """
    # Imported locally: a bare `import scipy` does not guarantee that the
    # `scipy.signal` submodule has been loaded.
    from scipy.signal import hilbert

    print("PLV in process.....")

    data = np.asarray(data, dtype=float)
    if data.ndim != 2 or data.shape[0] < sensors:
        raise ValueError(
            f"data must be 2-D with at least {sensors} rows, got shape {data.shape}"
        )
    data_points = data.shape[-1]

    # Instantaneous phase of the analytic signal. np.angle is a four-quadrant
    # arctan2, so the phase covers (-pi, pi]. The previous arctan(imag/real)
    # folded the phase into (-pi/2, pi/2), which made constant phase lags
    # look unlocked, and it divided by zero wherever a sample was exactly 0.
    phase = np.angle(hilbert(data, axis=-1))[:sensors]

    # Unit phasors per sensor and sample. The matrix product with the
    # conjugate transpose sums exp(1j*(phase_i - phase_k)) over time for all
    # sensor pairs at once.
    phasors = np.exp(1j * phase)
    connectivity_matrix = np.abs(phasors @ phasors.conj().T) / data_points

    print("PLV done!")
    return connectivity_matrix

In [37]:

@dataclass
class SeizureDataLoader:
    """Builds train/validation temporal graph datasets from .npy recordings.

    Recordings are read from ``npy_dataset_path/<patient>/<record>.npy`` and
    seizure start/stop events from the CSV tables in ``event_tables_path``.
    Samples of the patient named by ``loso_patient`` form the validation
    split; all other patients form the training split.
    """
    npy_dataset_path: str  # directory with one sub-directory of .npy files per patient
    event_tables_path: str  # directory with the per-patient start/stop CSV tables
    loso_patient: str = None  # patient held out for validation; None means no validation split
    sampling_f: int = 256  # forwarded to utils.extract_training_data_and_labels as `fs`
    seizure_lookback: int = 600  # forwarded to utils.extract_training_data_and_labels
    sample_timestep: int = 5  # forwarded to utils.extract_training_data_and_labels
    overlap: int = 0  # forwarded to utils.extract_training_data_and_labels
    self_loops: bool = True  # add a self-loop to every node of the graph
    shuffle: bool = True  # shuffle the training samples after loading

    def _get_event_tables(self, patient_name):
        """Load the start and stop event tables of one patient.

        The tables are picked by the ``_start`` / ``_stop`` part of their file
        name rather than by directory listing order, which is arbitrary.

        Returns
        -------
        (start_events_dict, stop_events_dict)

        Raises
        ------
        ValueError
            If there is not exactly one start and one stop table for the patient.
        """
        patient_tables = [
            ev_table for ev_table in os.listdir(self.event_tables_path)
            if patient_name in ev_table
        ]
        start_tables = [name for name in patient_tables if '_start' in name]
        stop_tables = [name for name in patient_tables if '_stop' in name]
        if len(start_tables) != 1 or len(stop_tables) != 1:
            raise ValueError(
                f"Expected exactly one start and one stop event table for "
                f"{patient_name!r}, found {patient_tables}"
            )
        start_events_dict = self._load_csv_table_events(
            os.path.join(self.event_tables_path, start_tables[0])
        )
        stop_events_dict = self._load_csv_table_events(
            os.path.join(self.event_tables_path, stop_tables[0])
        )
        return start_events_dict, stop_events_dict

    def _get_recording_events(self, events_dict, recording):
        """Return the non-NaN event times of ``recording`` as a list of ints.

        ``recording`` is the record id without extension; the table is looked
        up under ``recording + '.edf'``.
        """
        recording_list = list(events_dict[recording + '.edf'].values())
        return [int(x) for x in recording_list if not np.isnan(x)]

    def _load_csv_table_events(self, table_path):
        """Read a CSV event table into a ``{row index: {column: value}}`` dict."""
        return pd.read_csv(table_path).to_dict('index')

    def _get_graph(self, n_nodes):
        """Create a fully connected networkx graph with ``n_nodes`` nodes.

        Self-loops are added when ``self.self_loops`` is truthy (the default).
        """
        graph = nx.complete_graph(n_nodes)
        if self.self_loops:
            graph.add_edges_from([(node, node) for node in graph.nodes()])
        return graph

    def _get_edge_weights_recording(self, plv_values):
        """Assign ``plv_values[i, j]`` as the 'plv' attribute of every edge (i, j)
        of the graph and return the attribute values in the edge order
        produced by ``from_networkx``."""
        graph = self._get_graph(plv_values.shape[0])
        edge_attributes = {
            (e_start, e_end): {'plv': plv_values[e_start, e_end]}
            for e_start, e_end in graph.edges()
        }
        nx.set_edge_attributes(graph, edge_attributes)
        return from_networkx(graph).plv.numpy()

    def _get_edges(self):
        """Build one copy of the edge index per sample.

        Has to be called AFTER ``_get_labels_and_features()``, because it reads
        the number of nodes and samples from the loaded features. Results are
        stored in ``self._edges`` and ``self._val_edges`` (``None`` when there
        is no validation split).
        """
        graph = self._get_graph(self._features.shape[1])
        edges = np.expand_dims(from_networkx(graph).edge_index.numpy(), axis=0)
        self._edges = np.repeat(edges, repeats=self._features.shape[0], axis=0)
        val_features = getattr(self, '_val_features', None)
        if val_features is None:
            self._val_edges = None
        else:
            self._val_edges = np.repeat(edges, repeats=val_features.shape[0], axis=0)

    def _get_labels_and_features(self):
        """Load every recording and fill the train / validation arrays.

        The arrays are rebuilt from scratch on every call, so calling
        ``get_dataset()`` twice does not duplicate samples. Validation
        attributes are set to ``None`` when no recording belongs to
        ``loso_patient``.

        Raises
        ------
        ValueError
            If no training samples were found.
        """
        keys = ('features', 'labels', 'time_labels', 'edge_weights')
        train_parts = {key: [] for key in keys}
        val_parts = {key: [] for key in keys}

        # Sorted listings make the sample order deterministic when shuffle is off.
        for patient in sorted(os.listdir(self.npy_dataset_path)):
            patient_path = os.path.join(self.npy_dataset_path, patient)
            if not os.path.isdir(patient_path):
                continue  # stray files next to the patient folders
            # start and stop of seizures for this patient
            start_events, stop_events = self._get_event_tables(patient)
            parts = val_parts if patient == self.loso_patient else train_parts

            for record in sorted(os.listdir(patient_path)):
                if not record.endswith('.npy'):
                    continue
                record_id = record.split('.npy')[0]
                start_event_tables = self._get_recording_events(start_events, record_id)
                stop_event_tables = self._get_recording_events(stop_events, record_id)
                data_array = np.load(os.path.join(patient_path, record))

                ##TODO add a gateway to reject seizure periods shorter than lookback
                # extract timeseries and labels from the array
                features, labels, time_labels = utils.extract_training_data_and_labels(
                    data_array,
                    start_event_tables,
                    stop_event_tables,
                    fs=self.sampling_f,
                    seizure_lookback=self.seizure_lookback,
                    sample_timestep=self.sample_timestep,
                    overlap=self.overlap,
                )
                time_labels = time_labels.astype(np.int32)

                # The node count follows the features so the weights always
                # match the edges built in _get_edges().
                n_nodes = features.shape[1]
                # NOTE(review): edge weights are still random placeholders.
                # TODO replace with plv_connectivity(data_array.shape[0], data_array)
                plv_edge_weights = np.expand_dims(
                    self._get_edge_weights_recording(
                        np.random.uniform(size=(n_nodes, n_nodes))
                    ),
                    axis=0,
                )

                parts['features'].append(features)
                parts['labels'].append(labels)
                parts['time_labels'].append(time_labels)
                # one copy of the recording's edge weights per extracted sample
                parts['edge_weights'].append(
                    np.repeat(plv_edge_weights, features.shape[0], axis=0)
                )

        if not train_parts['features']:
            raise ValueError(
                f"No training samples found in {self.npy_dataset_path!r}"
            )
        self._features = np.concatenate(train_parts['features'])
        self._labels = np.concatenate(train_parts['labels'])
        self._time_labels = np.concatenate(train_parts['time_labels'])
        self._edge_weights = np.concatenate(train_parts['edge_weights'])

        if val_parts['features']:
            self._val_features = np.concatenate(val_parts['features'])
            self._val_labels = np.concatenate(val_parts['labels'])
            self._val_time_labels = np.concatenate(val_parts['time_labels'])
            self._val_edge_weights = np.concatenate(val_parts['edge_weights'])
        else:
            self._val_features = None
            self._val_labels = None
            self._val_time_labels = None
            self._val_edge_weights = None

        if self.shuffle:
            (self._features,
             self._labels,
             self._time_labels,
             self._edge_weights) = sklearn.utils.shuffle(
                self._features, self._labels, self._time_labels, self._edge_weights
            )

    # TODO define a method to create edges and calculate plv to get weights
    def get_dataset(self) -> "tuple[DynamicGraphTemporalSignal, DynamicGraphTemporalSignal]":
        """Load the data and wrap it in temporal graph signal iterators.

        Every sample is paired with the same edge index (a fully connected
        graph, with self-loops when ``self_loops`` is set) and with the edge
        weights of the recording it was extracted from.

        Return types:
            * **train_dataset** *(DynamicGraphTemporalSignal)* - samples of all
              patients except ``loso_patient``.
            * **val_dataset** *(DynamicGraphTemporalSignal or None)* - samples of
              ``loso_patient``; ``None`` when that patient has no samples.
        """
        self._get_labels_and_features()
        self._get_edges()
        train_dataset = DynamicGraphTemporalSignal(
            self._edges, self._edge_weights, self._features, self._labels,
            time_labels=self._time_labels
        )

        val_dataset = None
        if self._val_features is not None:
            val_dataset = DynamicGraphTemporalSignal(
                self._val_edges, self._val_edge_weights, self._val_features,
                self._val_labels, time_labels=self._val_time_labels
            )

        return train_dataset, val_dataset
                

        

In [38]:
# Hold out patient 'chb16' for validation; every other patient is used for training.
dataloader = SeizureDataLoader(
    npy_dataset_path=Path('npy_data'),
    event_tables_path=Path('event_tables'),
    loso_patient='chb16',
)

In [39]:
# get_dataset() returns a (train, validation) pair of DynamicGraphTemporalSignal iterators.
train_loader,val_loader=dataloader.get_dataset()

Creating initial attributes


In [40]:
# Print every training snapshot to check tensor shapes.
# NOTE(review): this prints one line per sample; consider stopping after the first few.
for data in  train_loader:
    print(data)

Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[390])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[30])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[350])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[140])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[110])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[110])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[60])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[145])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[215])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[205])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[324], y=[0], time_labels=[540])
Data(x=[18, 1, 1280], edge_index=[2, 324], edge_attr=[32

In [6]:
# NOTE(review): _get_edges() as defined in this notebook stores its result on the
# instance and has no return statement, so the array displayed for this cell
# presumably comes from an earlier version of the method - re-run to confirm.
dataloader._get_edges()

array([[[ 0,  0,  0, ..., 17, 17, 17],
        [ 1,  2,  3, ..., 15, 16, 17]],

       [[ 0,  0,  0, ..., 17, 17, 17],
        [ 1,  2,  3, ..., 15, 16, 17]],

       [[ 0,  0,  0, ..., 17, 17, 17],
        [ 1,  2,  3, ..., 15, 16, 17]],

       ...,

       [[ 0,  0,  0, ..., 17, 17, 17],
        [ 1,  2,  3, ..., 15, 16, 17]],

       [[ 0,  0,  0, ..., 17, 17, 17],
        [ 1,  2,  3, ..., 15, 16, 17]],

       [[ 0,  0,  0, ..., 17, 17, 17],
        [ 1,  2,  3, ..., 15, 16, 17]]], dtype=int64)

In [None]:
# _get_event_tables returns a (start_events_dict, stop_events_dict) tuple,
# so `df` is a tuple of dicts here, not a DataFrame.
df = dataloader._get_event_tables('chb16')

In [None]:
# Seizure start events of recording 'chb16_10' (df[0] is the start-event table).
list_to_process =dataloader._get_recording_events(df[0], 'chb16_10')

In [None]:
# Display the extracted events.
list_to_process

In [None]:
# NOTE(review): _get_recording_events already drops NaN entries and casts to int,
# so this filter is a no-op on its output.
[x for x in list_to_process if not np.isnan(x)]

In [None]:
# NOTE(review): when `df` is the tuple returned by dataloader._get_event_tables('chb16'),
# it has no .to_dict and this cell raises AttributeError; it presumably relies on an
# earlier DataFrame-valued `df` from a deleted cell - confirm or remove.
df.to_dict('index')

In [None]:
# Inspect the raw summary file of patient chb16.
# The path is built from components so it works on every OS; the previous
# backslash-separated literal only resolved on Windows and contained the
# invalid escape sequence "\c".
path_to_file = Path("raw_dataset") / "chb16" / "chb16-summary.txt"
# Context manager so the file handle is closed after reading.
with open(path_to_file,'r') as summary_file:
    summary_lines = summary_file.readlines()
summary_lines

In [None]:
# Example recording file name used to try out the record-id extraction below.
string = 'chb10_27.edf'

In [None]:
# Strip the '.edf' extension to obtain the record id.
string.split('.edf')[0]

In [None]:
# Directories passed to save_timeseries_array in the next cell
# (presumably source recordings and destination for the .npy arrays - confirm).
ds_path = Path('preprocessed_data')
target_path = Path('npy_data')

In [None]:
# NOTE(review): save_timeseries_array is neither defined nor imported in the visible
# part of this notebook, so this cell fails on a fresh kernel unless it is provided
# elsewhere - TODO import it explicitly (possibly from `utils`; confirm).
save_timeseries_array(ds_path,target_path)

In [None]:
def get_patient_annotations(path_to_file : Path, savedir : Path):
    """Parse a patient summary text file into start/stop seizure event tables.

    For every "File Name" entry followed by a "Number of Seizures in File"
    line with a positive count, the first integer after ``': '`` on each of
    the following "Start Time" / "End Time" lines is collected. Two CSV files
    are written to ``savedir``: ``<patient_id>_start.csv`` and
    ``<patient_id>_stop.csv``, with one row per recording that has seizures
    and one ``Seizure <n>`` column per event. ``patient_id`` is the part of
    the last seen file name before the first underscore.

    Parameters
    ----------
    path_to_file : Path
        Path to the summary text file.
    savedir : Path
        Output directory; created (including parents) if it does not exist.

    Raises
    ------
    ValueError
        If the summary contains no "File Name" line, or a seizure count
        cannot be parsed as an integer.
    """
    number_pattern = re.compile(r'\d+')

    # Context manager so the summary file handle is always closed.
    with open(path_to_file, 'r') as raw_txt:
        raw_txt_lines = raw_txt.readlines()

    event_dict_start = dict()
    event_dict_stop = dict()
    current_file_name = None
    for n, line in enumerate(raw_txt_lines):
        if "File Name" in line:
            # strip() rather than [:-1]: a final line without a trailing
            # newline would otherwise lose its last character.
            current_file_name = line.split(': ')[1].strip()
        if "Number of Seizures in File" in line:
            # Parse everything after the last colon. Reading only the last two
            # characters of the line turned a count of 10 into 0.
            num_of_seizures = int(line.split(':')[-1].strip())
            # Each seizure contributes one start line and one end line
            # directly below the count; the slice is empty for a count of 0.
            for event in raw_txt_lines[n + 1:n + num_of_seizures * 2 + 1]:
                if "Start Time" in event:
                    target = event_dict_start
                elif "End Time" in event:
                    target = event_dict_stop
                else:
                    continue
                time_value = int(number_pattern.search(event.split(': ')[1]).group())
                target.setdefault(current_file_name, []).append(time_value)

    if current_file_name is None:
        raise ValueError(f"No 'File Name' entries found in {path_to_file}")

    # The widest start-event row fixes the number of "Seizure n" columns.
    n_columns = len(pd.DataFrame.from_dict(event_dict_start, orient='index').columns)
    col_list = [f'Seizure {n}' for n in range(1, n_columns + 1)]
    df_start = pd.DataFrame.from_dict(event_dict_start, orient='index', columns=col_list)
    df_end = pd.DataFrame.from_dict(event_dict_stop, orient='index', columns=col_list)

    patient_id = current_file_name.split('_')[0]
    os.makedirs(savedir, exist_ok=True)
    dst_dir_start = os.path.join(savedir, f"{patient_id}_start.csv")
    dst_dir_stop = os.path.join(savedir, f"{patient_id}_stop.csv")
    # index_label=False leaves the index column unnamed in the header, so
    # pd.read_csv picks the recording names up as the index again.
    df_start.to_csv(dst_dir_start, index_label=False)
    df_end.to_csv(dst_dir_stop, index_label=False)

In [None]:
def get_annotation_files(dataset_path, savedir=Path("event_tables")):
    """Generate event tables for every patient folder in ``dataset_path``.

    Each sub-directory of ``dataset_path`` is searched for files whose name
    contains "summary"; every match is passed to ``get_patient_annotations``.

    Parameters
    ----------
    dataset_path : path-like
        Directory containing one folder per patient.
    savedir : path-like, optional
        Directory the event tables are written to. Defaults to
        ``Path("event_tables")``.

    Returns
    -------
    None
        The tables are written to disk; nothing is returned.
    """
    # Sorted listings make the processing order deterministic.
    for folder in sorted(os.listdir(dataset_path)):
        patient_folder_path = os.path.join(dataset_path, folder)
        if not os.path.isdir(patient_folder_path):
            continue
        for filename in sorted(os.listdir(patient_folder_path)):
            if "summary" in filename:
                annotation_path = os.path.join(patient_folder_path, filename)
                get_patient_annotations(annotation_path, savedir)


In [None]:
# Writes the event tables to disk as a side effect.
# NOTE(review): get_annotation_files has no return statement, so
# `annotation_files` is always None.
annotation_files = get_annotation_files(Path("raw_dataset"))
