In [1]:
import mne
import mne.io

In [2]:
import matplotlib

matplotlib.use('Agg')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os, sys
import natsort

# %matplotlib inline

In [3]:
# Root folders under which every patient has its own sub-tree.
ecog_root_dir = "/home/pfilipia/inria/chu_nice_inria/patients_ecog"
connections_root_dir = "/user/pfilipia/home/inria/chu_nice_inria/patients_dmri"


def get_ecog_data_dir(patient_id):
    """Return ``<ecog_root_dir>/patientNN`` for the given patient number.

    The patient number is zero-padded to two digits.
    """
    patient_folder = "patient%02d" % patient_id
    return os.path.join(ecog_root_dir, patient_folder)


def get_connections_file(patient_id):
    """Return the path of the patient's connections CSV below ``connections_root_dir``.

    The patient number (zero-padded to two digits) appears twice in the path:
    once in the top-level folder and once in the ``sub-patientNN`` folder.
    """
    relative_path = (
        "patient%02d/bids/sub-patient%02d/ses-presurgical/connectivity/connections_common_avg_seed5k_after_shift.csv"
        % (patient_id, patient_id)
    )
    return os.path.join(connections_root_dir, relative_path)
    

In [4]:
def load_data(data_path, timestamp, channel_ids, channel_names):
    """Load one recording session together with its stimulation-event table.

    Two files sharing the prefix ``<data_path>/<timestamp>`` are read:
    ``<prefix>_brainvision.vhdr`` (the recording) and ``<prefix>_events.csv``
    (the events; its first line is the header).

    Only the signal rows selected by ``channel_ids`` are kept. They are wrapped
    in a new ``RawArray`` whose channels are named ``channel_names`` (all typed
    'eeg') and which keeps the sampling frequency of the file. The result is
    band-pass filtered between 0.5 Hz and 1000 Hz.

    Returns:
        tuple: ``(filtered RawArray, events DataFrame)``.
    """
    file_prefix = data_path + "/" + timestamp
    header_path = "%s_brainvision.vhdr" % file_prefix
    events_path = "%s_events.csv" % file_prefix

    source_recording = mne.io.read_raw_brainvision(vhdr_fname=header_path, montage=None)
    selected_signals = source_recording.get_data()[channel_ids]

    selected_info = mne.create_info(
        ch_names=channel_names,
        sfreq=source_recording.info['sfreq'],
        ch_types='eeg',
    )

    filtered_recording = mne.io.RawArray(selected_signals, selected_info)
    filtered_recording.load_data()
    filtered_recording.filter(l_freq=0.5, h_freq=1000)

    events_table = pd.read_csv(events_path, header=0)

    return filtered_recording, events_table

In [5]:
def common_average(raw_time_series):
    """Return a copy of the recording re-referenced to the common average.

    The mean along the first axis of ``get_data()`` (the channel axis in MNE's
    ``(n_channels, n_times)`` layout) is subtracted from every channel. The
    result is wrapped in a new ``RawArray`` that reuses the input's ``info``.
    """
    signals = raw_time_series.get_data()
    mean_over_channels = signals.mean(axis=0)

    return mne.io.RawArray(signals - mean_over_channels, raw_time_series.info)

In [6]:
def common_reference(raw_time_series, reference_electrode_id):
    """Return a copy of the recording referenced to a single electrode.

    The signal stored at row ``reference_electrode_id`` is subtracted from
    every channel (the reference channel itself therefore becomes all zeros).
    The result is wrapped in a new ``RawArray`` that reuses the input's ``info``.
    """
    signals = raw_time_series.get_data()
    reference_signal = signals[reference_electrode_id]

    return mne.io.RawArray(signals - reference_signal, raw_time_series.info)

In [7]:
def epoch_stimulation_site(stimulation_site_data, time_series_raw, offset):
    """Blank the stimulation artifacts of one site and cut the signal into epochs.

    Steps:
      1. For every stimulation onset (``np.arange(time_begin, time_end,
         time_interval)`` for each row), the ``sfreq / 50`` samples (20 ms)
         ending ``offset`` samples after the onset are set to zero.
      2. The blanked signal is band-pass filtered between 0.5 Hz and 30 Hz.
      3. One event is created per onset, skipping the first onset of every row
         and stopping before ``time_end - time_interval``.
      4. Epochs spanning ``[0, time_interval]`` are built around those events.

    Args:
        stimulation_site_data: DataFrame with the columns ``time_begin``,
            ``time_end`` and ``time_interval``. The times are handed to
            ``time_as_index``, so they are expected in seconds.
        time_series_raw: MNE Raw object holding the recording. It is not
            modified.
        offset: shift, in samples, of the end of the blanked window relative
            to the stimulation onset.

    Returns:
        mne.Epochs with the single event id ``stimulation=1``. The epoch length
        is the ``time_interval`` of the LAST row of ``stimulation_site_data``.

    Raises:
        ValueError: if ``stimulation_site_data`` has no rows (the epoch length
            would be undefined).
    """
    if len(stimulation_site_data) == 0:
        raise ValueError("stimulation_site_data contains no rows; cannot build epochs")

    # Private (n_times, n_channels) copy: the same Raw object is reused for
    # every stimulation site, so the blanking below must never leak into it,
    # whether or not get_data() happens to return a view.
    time_series_data = time_series_raw.get_data().T.copy()
    artifact_window_length = int(np.round(time_series_raw.info['sfreq'] / 50))

    for _, stimulation_site_row in stimulation_site_data.iterrows():

        onsets = np.arange(
            stimulation_site_row['time_begin'],
            stimulation_site_row['time_end'],
            stimulation_site_row['time_interval']
        )

        for t_start in onsets:
            window_stop = time_series_raw.time_as_index(t_start)[0] + offset
            window_start = window_stop - artifact_window_length
            # Clamp at 0: a negative start would be interpreted as "counted
            # from the end" and silently blank nothing (or the wrong samples).
            time_series_data[max(window_start, 0) : max(window_stop, 0), :] = 0
            # Alternative kept from the original cell (disabled): overwrite the
            # window with the artifact_window_length samples preceding it
            # instead of zeros.

    raw_without_artifact = mne.io.RawArray(time_series_data.T, time_series_raw.info)
    # Disabled alternatives from the original cell:
    #   raw_without_artifact.notch_filter(freqs=(50, 100), notch_widths=9)
    #   raw_without_artifact.filter(l_freq=0.5, h_freq=1000)
    raw_without_artifact.filter(l_freq=0.5, h_freq=30)

    # Collect the event rows in a list and convert once, instead of growing an
    # array with np.vstack on every iteration.
    event_rows = []

    for _, stimulation_site_row in stimulation_site_data.iterrows():

        t_interval = stimulation_site_row['time_interval']
        t_event_1_beg = stimulation_site_row['time_begin'] + t_interval
        t_event_1_end = stimulation_site_row['time_end'] - t_interval

        for t_start in np.arange(t_event_1_beg, t_event_1_end, t_interval):
            # time_as_index returns an array; take its single element so that
            # every event row is a flat [sample, 0, event_id] triple.
            event_rows.append([raw_without_artifact.time_as_index(t_start)[0], 0, 1])

    # reshape keeps the (0, 3) shape when no event was generated.
    events = np.array(event_rows, dtype=int).reshape(-1, 3)

    event_id = dict(stimulation=1)

    t_min = 0
    t_max = t_interval  # interval of the last row, as in the original code

    # A custom baseline, (time_interval - 0.005, time_interval), was left
    # disabled in the original cell.
    return mne.Epochs(raw_without_artifact, events, event_id, t_min, t_max)

In [8]:
def get_electrode_id(electrode_names, electrode_name):
    """Return the position of an electrode inside ``electrode_names``.

    ``electrode_name`` is either a label (looked up as is) or an ``int`` ``k``,
    which is looked up as ``'e<k>'``. ``list.index`` raises ``ValueError`` when
    the label is not in the list.
    """
    label = 'e' + str(electrode_name) if isinstance(electrode_name, int) else electrode_name

    return electrode_names.index(label)
    

def get_electrode_coords(electrode_locations, electrode_name):
    """Return ``(row, column)`` of an electrode in the layout grid.

    ``electrode_locations`` is a 2-D array of electrode numbers. A string such
    as ``'e12'`` is converted to its number by removing every ``'e'``; any
    other value is converted with ``int``. The first cell holding that number
    (in row-major order) is returned. An ``IndexError`` is raised when the
    number does not occur in the grid.
    """
    digits = electrode_name.replace('e', '') if isinstance(electrode_name, str) else electrode_name
    electrode_number = int(digits)

    hits = np.where(electrode_locations == electrode_number)

    return hits[0][0], hits[1][0]


def get_subplot_data(electrode_locations, electrode_names):
    """Map every electrode to a cell of a subplot grid shaped like the layout.

    Returns:
        tuple: ``(rows, cols, subplot_ids)`` where ``rows`` and ``cols`` are the
        shape of ``electrode_locations`` and ``subplot_ids[i]`` is the 1-based,
        row-major subplot number of ``electrode_names[i]`` (stored as float,
        because the array is created with ``np.zeros``).
    """
    subplot_rows, subplot_cols = electrode_locations.shape[0], electrode_locations.shape[1]

    subplot_ids = np.zeros(len(electrode_names))

    for position, name in enumerate(electrode_names):
        grid_row, grid_col = get_electrode_coords(electrode_locations, name)
        subplot_ids[position] = grid_row * subplot_cols + grid_col + 1

    return subplot_rows, subplot_cols, subplot_ids


def in_proximity_of_stimulation_site(
    ref_electrode_id, stimulation_site_electrodes, 
    electrode_names, electrode_locations, radius = 1
):
    """Tell whether an electrode lies next to any electrode of a stimulation site.

    Two electrodes are "in proximity" when neither their row difference nor
    their column difference in ``electrode_locations`` exceeds ``radius``
    (so with the default radius of 1 the 8 surrounding grid cells, and the
    electrode itself, qualify).

    Args:
        ref_electrode_id: index into ``electrode_names`` of the electrode to test.
        stimulation_site_electrodes: iterable of electrode labels/numbers of the
            stimulation site. Entries that cannot be parsed or that are not
            present in the grid are ignored.
        electrode_names: list of electrode labels.
        electrode_locations: 2-D array with the electrode layout.
        radius: maximum allowed row/column distance.

    Returns:
        bool: True as soon as one stimulation-site electrode is close enough.

    Raises:
        IndexError: if the reference electrode itself is not in the grid
            (only the stimulation-site lookups are best-effort).
    """
    ref_electrode_coords = np.array(get_electrode_coords(
        electrode_locations, electrode_names[ref_electrode_id]
    ))

    for stimulation_site_electrode in stimulation_site_electrodes:

        try:
            stimulation_site_electrode_coords = np.array(get_electrode_coords(
                electrode_locations, stimulation_site_electrode
            ))
        except (IndexError, ValueError, TypeError):
            # IndexError: electrode not present in the grid;
            # ValueError / TypeError: label cannot be converted to a number
            # (e.g. an empty string). Anything else, including
            # KeyboardInterrupt, is deliberately allowed to propagate.
            continue

        if np.max(np.abs(ref_electrode_coords - stimulation_site_electrode_coords)) <= radius:
            return True

    return False


In [9]:
# def plot_local_extrema(data, plot_ylim, offset, p1_lower_bound = 30, p1_upper_bound = 400):
    
#     n1 = 0
#     n2 = 0
#     p1 = 0
    
#     arg_min = np.argmin(data[p1_lower_bound : p1_upper_bound])

#     if arg_min > 0 and arg_min < p1_upper_bound - 1:
        
#         p1 = p1_lower_bound + arg_min
#         plt.plot([p1, p1], plot_ylim, 'k:')
        
#         lhs_arg_max = np.argmax(data[0 : p1])
        
#         if lhs_arg_max > offset and lhs_arg_max < p1 - 1:
            
#             n1 = lhs_arg_max
#             plt.plot([n1, n1], plot_ylim, 'k:')
            
#         rhs_arg_max = np.argmax(data[p1 : ])
        
#         if rhs_arg_max > p1:
            
#             n2 = p1 + rhs_arg_max
#             plt.plot([n2, n2], plot_ylim, 'k:')
            
#     return n1, p1, n2
        
    
def plot_local_extrema(data, plot_ylim, offset, p1_lower_bound = 20, p1_upper_bound = 150, min_diff = 0.00002):
    """Locate (and mark on the current axes) the P1 minimum of an averaged response.

    The minimum of ``data[p1_lower_bound:p1_upper_bound]`` is searched. It is
    considered only when it lies strictly inside that window. A dotted vertical
    line is then drawn at its position. The minimum is accepted as P1 when the
    samples ``p1_lower_bound`` positions to its left and to its right are both
    more than ``min_diff`` above it; two further dotted lines mark those
    positions in that case.

    Args:
        data: 1-D array with the averaged signal of one electrode.
        plot_ylim: ``[y_min, y_max]`` extent of the vertical marker lines.
        offset: unused; kept so existing callers keep working.
        p1_lower_bound: first sample of the search window; also the distance
            (in samples) of the two comparison points from the minimum.
        p1_upper_bound: end (exclusive) of the search window.
        min_diff: minimum required rise on both sides of the minimum.

    Returns:
        tuple: ``(n1, p1, n2)`` sample indices. ``n1`` and ``n2`` are always 0
        in this implementation; ``p1`` is 0 when no minimum was accepted.
    """
    n1 = 0
    n2 = 0
    p1 = 0

    search_window = data[p1_lower_bound : p1_upper_bound]
    window_length = len(search_window)

    # Signal shorter than the lower bound: nothing to search (np.argmin would
    # raise on an empty window).
    if window_length == 0:
        return n1, p1, n2

    arg_min = np.argmin(search_window)

    # arg_min is relative to the window, so the "not on the edge" test has to
    # use the window length, not the absolute upper bound.
    if arg_min > 0 and arg_min < window_length - 1:

        local_min = p1_lower_bound + arg_min
        plt.plot([local_min, local_min], plot_ylim, 'k:')

        l_bound = local_min - p1_lower_bound
        r_bound = local_min + p1_lower_bound

        # The right comparison point may fall beyond the signal; the minimum
        # cannot be confirmed in that case and is left unaccepted.
        if r_bound < len(data):

            r_diff = data[r_bound] - data[local_min]
            l_diff = data[l_bound] - data[local_min]

            if l_diff > min_diff and r_diff > min_diff:

                p1 = local_min
                plt.plot([l_bound, l_bound], plot_ylim, 'k:')
                plt.plot([r_bound, r_bound], plot_ylim, 'k:')

    return n1, p1, n2
            

In [10]:
def plot_averages(stimulation_site_data, epochs, electrode_locations, electrode_names, offset, data_path, prefix):
    """Plot the per-electrode responses of one stimulation site and export the P1 peaks.

    One subplot per electrode is drawn at the grid position given by
    ``electrode_locations``: every epoch in grey and the epoch average in red
    (electrode listed in the site's ``nearest_electrodes``) or blue (otherwise).
    ``plot_local_extrema`` is run on each average to find P1.

    Side effects:
      * ``<data_path>/spreadsheets/stim_without_artifact_<site>_<freq>Hz_<prefix>_band30Hz.csv``
        gets one line per electrode: name, n1/p1/n2 delays in ms and the
        averaged signal at those three samples (n1 and n2 are whatever
        ``plot_local_extrema`` returns, currently always sample 0).
      * ``<data_path>/images/stim_without_artifact_<site>_<freq>Hz_<prefix>_band30Hz_average.png``
        receives the figure.
      Both folders are created when missing.

    Args:
        stimulation_site_data: DataFrame rows of one stimulation site; the first
            row supplies ``stimulation_site``, ``frequency`` and
            ``nearest_electrodes`` (comma-separated labels).
        epochs: MNE Epochs of that site.
        electrode_locations: 2-D array with the electrode layout.
        electrode_names: electrode labels, in channel order.
        offset: first sample shown on the x axis (also forwarded to
            ``plot_local_extrema``).
        data_path: folder receiving the ``spreadsheets`` and ``images`` folders.
        prefix: tag inserted into the output file names.

    Returns:
        tuple: ``(p1_delays, p1_values)``, two arrays indexed like
        ``electrode_names``: the P1 delay in ms (0 when none was accepted) and
        the averaged signal at the P1 sample (at sample 0 when none was accepted).
    """
#     filter_name = "notch45-55_95-105Hz"
#     filter_name = "no-notch"
    filter_name = "band30Hz"

    subplot_rows, subplot_cols, subplot_ids = get_subplot_data(electrode_locations, electrode_names)

    epochs.load_data()
    epoch_data = epochs.get_data()

    avg_data = np.mean(epoch_data, axis=0)
    samples_per_ms = 0.001 * epochs.info['sfreq']

    plot_xlim = [offset, np.minimum(500, avg_data.shape[1])]
    plot_ylim = [-0.0008, 0.0003]
    plot_samples = np.arange(plot_xlim[0], plot_xlim[1])

    # Everything taken from the first row is loop-invariant: read it once.
    site_info = stimulation_site_data.iloc[0]
    nearest_electrodes = site_info['nearest_electrodes'].replace(' ', '').split(',')
    output_stem = "stim_without_artifact_%s_%sHz_%s_%s" % (
        site_info['stimulation_site'],
        site_info['frequency'],
        prefix,
        filter_name
    )

    spreadsheets_path = "%s/spreadsheets" % data_path
    if not os.path.isdir(spreadsheets_path):
        os.mkdir(spreadsheets_path)

    p1_delays = np.zeros(len(electrode_names))
    p1_values = np.zeros_like(p1_delays)

    figure = plt.figure(figsize=(2.5 * subplot_cols, 2 * subplot_rows))

    try:
        # The context manager closes the CSV even when plotting fails midway.
        with open("%s/%s.csv" % (spreadsheets_path, output_stem), 'w') as f_out:

            for electrode_name in natsort.natsorted(electrode_names):
                electrode_id = get_electrode_id(electrode_names, electrode_name)

                # subplot_ids is a float array; plt.subplot expects an integer index.
                subplot_id = int(subplot_ids[electrode_id])
                plt.subplot(subplot_rows, subplot_cols, subplot_id)

                for j in range(epoch_data.shape[0]):
                    plt.plot(
                        plot_samples,
                        epoch_data[j, electrode_id, plot_xlim[0] : plot_xlim[1]],
                        color=(0.75, 0.75, 0.75)
                    )

                data = avg_data[electrode_id]

                n1_key, p1_key, n2_key = plot_local_extrema(data, plot_ylim, offset)

                # Sample index -> milliseconds.
                n1_ms = np.round(n1_key / samples_per_ms)
                p1_ms = np.round(p1_key / samples_per_ms)
                n2_ms = np.round(n2_key / samples_per_ms)

                p1_delays[electrode_id] = p1_ms
                p1_values[electrode_id] = data[p1_key]

                average_color = 'red' if electrode_name in nearest_electrodes else 'blue'
                plt.plot(plot_samples, data[plot_xlim[0] : plot_xlim[1]], color=average_color, linewidth=2)

                # Only the lowest subplot of each column keeps its x ticks ...
                if subplot_id + subplot_cols in subplot_ids:
                    plt.xticks([])
                else:
                    plt.xticks(
                        np.arange(0, data.shape[0], 0.1 * epochs.info['sfreq']),
                        np.arange(0, data.shape[0] / samples_per_ms, 100, dtype=int)
                    )

                # ... and only subplots without a left neighbour keep their y ticks.
                if subplot_id - 1 in subplot_ids:
                    plt.yticks([])
                else:
                    plt.yticks([-0.0005, 0])

                plt.xlim(plot_xlim)
                plt.ylim(plot_ylim)

#                 plt.title("%s | %d, %d, %d" % (
#                     electrode_name, n1_ms, p1_ms, n2_ms
#                 ))
                plt.title("%s | %d" % (electrode_name, p1_ms))

                f_out.write("%s, %d, %d, %d, %f, %f, %f\n" % (
                    electrode_name,
                    n1_ms, p1_ms, n2_ms,
                    data[n1_key], data[p1_key], data[n2_key]
                ))

        images_path = "%s/images" % data_path
        if not os.path.isdir(images_path):
            os.mkdir(images_path)

        figure.savefig("%s/%s_average.png" % (images_path, output_stem))
    finally:
        # Release the figure even on failure; this function is called once
        # per stimulation site and figures would otherwise pile up.
        plt.close(figure)

    return p1_delays, p1_values
    

In [11]:
def plot_epochs(epochs, offset):
    """Debugging helper: draw every epoch of every channel, then stop execution.

    One 20x5 figure is created per channel. Each epoch is drawn in grey and its
    first ``offset`` samples are drawn again in red on top. The y axis is fixed
    to [-0.0010, 0.0010].

    The function ends with ``sys.exit(1)``, so it never returns to its caller.
    """
    epochs.load_data()
    epoch_data = epochs.get_data()

    grey = (0.75, 0.75, 0.75)
    red = (1.00, 0.00, 0.00)

    for channel in range(epochs.info['nchan']):
        plt.figure(figsize=(20,5))

        for single_epoch in epoch_data[:, channel]:
            plt.plot(single_epoch, color=grey)
            plt.plot(single_epoch[:offset], color=red)

        plt.ylim([-0.0010, 0.0010])

    sys.exit(1)
    

In [12]:
def is_valid_p1(p1_values, electrode_id, nearest_electrodes, electrode_names, electrode_locations):
    """Decide whether the P1 entry of one electrode is plausible.

    ``p1_values`` holds one number per electrode (the caller in this notebook
    passes the P1 delays). The rules, applied in order:

      1. An entry that is not strictly positive (including NaN) is invalid.
      2. A positive entry of an electrode in the proximity of the stimulation
         site is valid.
      3. Any other entry is valid only when it is strictly greater than the
         smallest non-NaN entry found among the electrodes in the proximity of
         the stimulation site. When there is no such entry nothing is valid.
    """
    candidate = p1_values[electrode_id]

    # Rule 1
    if not candidate > 0:
        return False

    # Rule 2
    if in_proximity_of_stimulation_site(electrode_id, nearest_electrodes, electrode_names, electrode_locations):
        return True

    # Rule 3: smallest entry among the electrodes around the stimulation site.
    smallest_nearby = np.inf

    for index, value in enumerate(p1_values):
        if (
            not np.isnan(value)
            and value < smallest_nearby
            and in_proximity_of_stimulation_site(index, nearest_electrodes, electrode_names, electrode_locations)
        ):
            smallest_nearby = value

    return bool(candidate > smallest_nearby)
    

In [13]:
def run(data_path, timestamp, connections_file, channel_ids, electrode_names, electrode_locations, offset):
    """Process one recording session and store the P1 results in the connections CSV.

    Pipeline:
      1. Silence MNE logging (level 'CRITICAL') and load the recording and its
         events with ``load_data``.
      2. Re-reference the recording to the common average.
      3. For every stimulation site (natural sort order): epoch the signal and
         let ``plot_averages`` produce the figure, the spreadsheet and the P1
         delay/value of every electrode.
      4. For every stimulation site and every electrode (both in natural sort
         order): keep ``[delay, value]`` when ``is_valid_p1`` accepts the delay,
         otherwise ``[nan, nan]``. Delays equal to 0 are treated as missing.
      5. Write the collected pairs into the ``p1_delay`` / ``p1_value`` columns
         of ``connections_file``, which is overwritten in place. The file is
         expected to hold one row per (site, electrode) pair in that same order.
    """
    mne.set_log_level('CRITICAL')

    raw_time_series, events_data = load_data(
        data_path, timestamp, channel_ids, electrode_names
    )

    stimulation_sites = natsort.natsorted(events_data['stimulation_site'].unique())

    def rows_of_site(site):
        # All event rows belonging to one stimulation site.
        return events_data.loc[events_data['stimulation_site'] == site]

    p1_delays_matrix = np.zeros([len(stimulation_sites), len(electrode_names)])
    p1_values_matrix = np.zeros_like(p1_delays_matrix)

    common_avg_time_series = common_average(raw_time_series)

    # Pass 1: detect P1 for every stimulation site.
    for site_index, site in enumerate(stimulation_sites):

        site_rows = rows_of_site(site)

        site_epochs = epoch_stimulation_site(site_rows, common_avg_time_series, offset)
        site_delays, site_values = plot_averages(
            site_rows, site_epochs,
            electrode_locations, electrode_names, offset,
            data_path, 'common_avg'
        )

        p1_delays_matrix[site_index, :] = site_delays
        p1_values_matrix[site_index, :] = site_values

    # Pass 2: validate the detections and flatten them in output order.
    output_p1_values = []

    for site_index, site in enumerate(stimulation_sites):

        nearest_electrodes = rows_of_site(site).iloc[0]['nearest_electrodes'].replace(' ', '').split(',')

        site_delays = p1_delays_matrix[site_index, :]
        site_delays[site_delays == 0] = np.nan  # 0 means "no P1 detected"

        site_values = p1_values_matrix[site_index, :]

        for electrode_name in natsort.natsorted(electrode_names):

            electrode_id = get_electrode_id(electrode_names, electrode_name)

            accepted = is_valid_p1(
                site_delays, electrode_id, nearest_electrodes, electrode_names, electrode_locations
            )
            output_p1_values.append(
                [site_delays[electrode_id], site_values[electrode_id]] if accepted else [np.nan, np.nan]
            )

    connections_table = pd.read_csv(connections_file)
    connections_table[['p1_delay', 'p1_value']] = np.array(output_p1_values)
    connections_table.to_csv(connections_file, index=False)
            

# Patient #1

In [14]:
patient_id = 1
timestamp = "20180507_1240"

# Signal rows 0..13 of the recording, labelled e1..e14 in that order.
channel_ids = np.arange(14)
electrode_names = ['e%d' % number for number in range(1, 15)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [1, 5,  0,  0,  0,  0,  0, 0],
    [2, 6,  0,  0,  0,  0,  0, 0],
    [3, 7,  0,  0,  0,  0,  0, 0],
    [4, 8,  0,  0,  0,  0,  0, 0],
    [0, 0, 14, 13, 12, 11, 10, 9],
])

offset = 8  # the original cell also notes 20 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

# Patient #3

In [15]:
patient_id = 3
timestamp = "20180801_1354"

# Signal rows used from the recording and, in the same order, their labels.
channel_ids = [0, 1, 2, 3, 8, 9, 10, 11, 12, 13]
electrode_names = ['e%d' % number for number in (1, 2, 3, 4, 5, 6, 8, 9, 10, 7)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [1, 0, 0, 0, 0,  0],
    [2, 0, 0, 0, 0,  0],
    [3, 0, 0, 0, 0,  0],
    [4, 0, 0, 0, 0,  0],
    [0, 0, 0, 0, 0,  0],
    [5, 6, 7, 8, 9, 10],
])

offset = 8  # the original cell also notes 25 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

# Patient #4

In [16]:
patient_id = 4
timestamp = "20180806_1240"

# Signal rows 0..13 of the recording and, in the same order, their labels.
channel_ids = np.arange(14)
electrode_names = ['e%d' % number for number in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 11)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [ 0,  0,  0,  0,  0, 9, 0, 0],
    [ 0,  0,  0,  0, 10, 0, 0, 0],
    [ 0,  0,  0, 11,  0, 0, 0, 0],
    [ 0,  0, 12,  0,  0, 0, 0, 0],
    [ 0, 13,  0,  0,  0, 0, 0, 0],
    [14,  0,  0,  0,  4, 3, 2, 1],
    [ 0,  0,  0,  0,  8, 7, 6, 5],
])

offset = 8  # the original cell also notes 25 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

# Patient #6

In [17]:
patient_id = 6
timestamp = "20181126_1225"

# Signal rows 0..13 of the recording and, in the same order, their labels.
channel_ids = np.arange(14)
electrode_names = ['e%d' % number for number in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 11)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [ 0,  0,  0,  8,  4, 0],
    [ 0,  0,  0,  7,  3, 0],
    [ 0,  0,  0,  6,  2, 0],
    [ 0,  0,  0,  5,  1, 0],
    [14, 13, 12, 11, 10, 9],
])

offset = 8  # the original cell also notes 20 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

# Patient #7

In [18]:
patient_id = 7
timestamp = "20180917_1327"

# Signal rows 0..13 of the recording and, in the same order, their labels.
channel_ids = np.arange(14)
electrode_names = ['e%d' % number for number in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 11)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [9, 10, 11,  0,  0,  0],
    [0,  0,  0, 12, 13, 14],
    [0,  0,  0,  0,  0,  0],
    [0,  5,  6,  7,  8,  0],
    [0,  1,  2,  3,  4,  0],
])

offset = 8  # the original cell also notes 20 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

# Patient #8

In [23]:
patient_id = 8
timestamp = "20181120_1207"

# Signal rows 0..7 of the recording, labelled e1..e8 in that order.
channel_ids = np.arange(8)
electrode_names = ['e%d' % number for number in range(1, 9)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 4, 3, 2, 1, 0],
    [0, 8, 7, 6, 5, 0],
    [0, 0, 0, 0, 0, 0],
])

offset = 8  # the original cell also notes 20 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

ValueError: Invalid file path or buffer object type: <class 'NoneType'>

# Patient #10

In [20]:
patient_id = 10
timestamp = "20190408_1248"

# Signal rows 0..13 of the recording and, in the same order, their labels.
channel_ids = np.arange(14)
electrode_names = ['e%d' % number for number in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 11)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [8, 4,  0,  0,  0,  0,  0, 0],
    [7, 3,  0,  0,  0,  0,  0, 0],
    [6, 2,  0,  0,  0,  0,  0, 0],
    [5, 1,  0,  0,  0,  0,  0, 0],
    [0, 0, 14, 13, 12, 11, 10, 9],
])

offset = 8  # the original cell also notes 20 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

# Patient #11

In [24]:
patient_id = 11
timestamp = "20190902_1426"

# Signal rows 0..13 of the recording and, in the same order, their labels.
channel_ids = np.arange(14)
electrode_names = ['e%d' % number for number in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 11)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [8, 4,  0,  0,  0,  0,  0, 0],
    [7, 3,  0,  0,  0,  0,  0, 0],
    [6, 2,  0,  0,  0,  0,  0, 0],
    [5, 1,  0,  0,  0,  0,  0, 0],
    [0, 0, 14, 13, 12, 11, 10, 9],
])

offset = 8  # the original cell also notes 20 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)

# Patient #12

In [22]:
patient_id = 12
timestamp = "20190909_1148"

# Signal rows 0..13 of the recording and, in the same order, their labels.
channel_ids = np.arange(14)
electrode_names = ['e%d' % number for number in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 11)]

# Electrode layout grid: k is the cell of electrode e<k>, 0 is an empty cell.
electrode_locations = np.array([
    [0,  9, 0, 0, 0],
    [0, 10, 5, 1, 0],
    [0, 11, 6, 2, 0],
    [0, 12, 7, 3, 0],
    [0, 13, 8, 4, 0],
    [0, 14, 0, 0, 0],
])

offset = 8  # the original cell also notes 20 next to this value

run(
    data_path=get_ecog_data_dir(patient_id),
    timestamp=timestamp,
    connections_file=get_connections_file(patient_id),
    channel_ids=channel_ids,
    electrode_names=electrode_names,
    electrode_locations=electrode_locations,
    offset=offset,
)