In [None]:
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import pycartool.io
import umap
%matplotlib qt

from features import *
from my_io import *

## Read EEG data

In [None]:
def read_epiliptic_events(file, sfreq):
    """Load a tab-separated marker file and return it as ``mne.Annotations``.

    The first line of ``file`` is skipped; the remaining rows are read as
    ``start``, ``stop`` and ``label`` columns. ``start`` and ``stop`` are
    divided by ``sfreq`` to obtain onsets and durations, and each label is
    truncated at its first underscore.
    """
    markers = pd.read_csv(file, sep="\t", skiprows=1, names=['start', 'stop', 'label'])
    onsets = markers['start'] / sfreq
    offsets = markers['stop'] / sfreq
    durations = offsets - onsets
    # Keep only the part of the label that precedes the first '_'.
    descriptions = markers['label'].map(lambda text: text.split('_')[0])
    return mne.Annotations(onsets, durations, descriptions)
# Example marker file path. The functions in this cell do not read this global:
# they receive their path as an argument or use a local loop variable.
file = r'V:\switchdrive\Brainhack\KMR11\d17\Epileptic_events.mrk'

def _read_matching_annotations(base_path, file_names, matches, reader, sfreq):
    """Merge the annotations of every file in ``file_names`` accepted by ``matches``.

    ``matches`` receives the lower-cased file name; ``reader`` is called with
    the full path and ``sfreq`` and must return an ``mne.Annotations``.
    """
    # Start from a zero-length 'null' marker at t=0 so that the result is an
    # Annotations object even when no file matches (same seed as before the
    # refactoring, so downstream event ids are unchanged).
    merged = mne.Annotations(0, 0, 'null')
    for file_name in file_names:
        if matches(file_name.lower()):
            print(file_name)
            merged += reader(os.path.join(base_path, file_name), sfreq)
    return merged


def read_file(fname):
    """Read a ``.sef`` recording and attach the annotations stored next to it.

    Parameters
    ----------
    fname : str
        Path of the recording, read with ``pycartool.io.read_sef``.

    Returns
    -------
    raw
        The object returned by ``pycartool.io.read_sef`` after
        ``set_annotations`` has been called with, in this order, the
        epileptic, bad and background annotations found in the same folder:

        * files whose name starts with ``bad``        -> ``read_bad_file``
        * files whose name starts with ``epileptic``  -> ``read_epiliptic_events``
        * files whose name ends with ``bck.mrk``      -> ``read_background_events_file``

        (name matching is case-insensitive).
    """
    # A bare file name has an empty dirname, which os.listdir() rejects;
    # fall back to the current directory in that case.
    base_path = os.path.dirname(fname) or os.curdir
    raw = pycartool.io.read_sef(fname)
    sfreq = raw.info['sfreq']

    # List the folder once; sorted so files are read in a reproducible order.
    file_names = sorted(os.listdir(base_path))

    bad_annotations = _read_matching_annotations(
        base_path, file_names, lambda name: name.startswith('bad'),
        read_bad_file, sfreq)
    epileptic_annotations = _read_matching_annotations(
        base_path, file_names, lambda name: name.startswith('epileptic'),
        read_epiliptic_events, sfreq)
    background_annotations = _read_matching_annotations(
        base_path, file_names, lambda name: name.endswith('bck.mrk'),
        read_background_events_file, sfreq)

    raw.set_annotations(epileptic_annotations + bad_annotations + background_annotations)
    return raw

In [None]:
files = list()

# Root folder of one subject; every sub-folder is scanned for .sef recordings
# (presumably one sub-folder per recording day).
subject_folder = r'V:\switchdrive\Brainhack\KMR11'

# os.listdir() returns entries in arbitrary order; sort both levels so the
# recordings are always processed (and their feature rows stacked) in the same order.
for day_name in sorted(os.listdir(subject_folder)):
    day_folder = os.path.join(subject_folder, day_name)
    if not os.path.isdir(day_folder):
        continue
    for file_name in sorted(os.listdir(day_folder)):
        if file_name.endswith('.sef'):
            files.append(os.path.join(day_folder, file_name))

In [None]:
# (low, high) pairs; each pair is passed to `raw.filter(low, high)` in the
# cells below before features are computed or the signal is plotted.
bands = [
    (1, 30),
    (200, 240),
]

In [None]:
epoch_duration = 1
all_features = []
datas = list()
channel = ['e11']


def _band_feature_blocks(data, sfreq, band):
    """Compute every feature family for one band-filtered epoch array.

    Parameters
    ----------
    data : array
        Result of ``epochs.get_data()`` for the filtered signal; its first
        axis is treated as the epoch axis.
    sfreq : float
        Sampling frequency, forwarded to the extractors that accept it.
    band : tuple
        The entry of ``bands`` being processed; only used as column-name prefix.

    Returns
    -------
    blocks : list of arrays
        One array per feature family, in a fixed order.
    names : list of str
        ``'{band}_{family}_feature_{i}'`` for every column of every block.
    """
    def per_epoch(values):
        # Collapse all trailing axes so that each epoch becomes a single row.
        return values.reshape((values.shape[0], -1))

    # The extractor functions are not defined in this notebook (presumably they
    # come from `from features import *`). The first three are used as
    # returned; the others are flattened to one row per epoch.
    families = [
        ('activity', activity(data)),
        ('mobility', mobility(data)),
        ('complexity', complexity(data)),
        ('time', per_epoch(extract_time_feat(data))),
        ('frequency', per_epoch(extract_freq_feat(data, sfreq=sfreq))),
        ('information', per_epoch(extract_information_feat(data, sfreq=sfreq))),
        ('dwt', per_epoch(extract_dwt_feat(data))),
    ]
    blocks, names = [], []
    for family, values in families:
        blocks.append(values)
        names += [f'{band}_{family}_feature_{i}' for i in range(values.shape[-1])]
    return blocks, names


for file in files:
    try:
        raw = read_file(file)
        raw.pick(channel)
    except Exception as e:
        # Best effort: report the recording that could not be loaded and go on.
        print(file , e)
        continue

    # Recordings live in <subject>/<day>/<recording>.sef; take both names from
    # the path components instead of fixed positions in the split string.
    recording_dir = os.path.dirname(file)
    day = os.path.basename(recording_dir)
    subject = os.path.basename(os.path.dirname(recording_dir))

    features = []
    column_names = []

    events, events_id = mne.events_from_annotations(raw)
    for band in bands:
        raw_ = raw.copy().filter(band[0], band[1])
        epochs = mne.Epochs(raw_, events, event_id=events_id, tmin=0, tmax=epoch_duration, baseline=None,
                            on_missing='ignore', event_repeated='drop')
        data = epochs.get_data()
        band_blocks, band_names = _band_feature_blocks(data, epochs.info['sfreq'], band)
        features += band_blocks
        column_names += band_names

    # Per-epoch metadata, taken from the epochs of the last processed band.
    label_of_code = {code: label for label, code in events_id.items()}
    events_ = np.array([label_of_code[event] for event in epochs.events[:,2]]).reshape(-1,1)
    ts = epochs.events[:,0].reshape(-1,1) / raw.info['sfreq']
    days = np.array([day] * len(epochs)).reshape(-1,1)
    subjects = np.array([subject] * len(epochs)).reshape(-1,1)
    features += [events_, ts, days, subjects]
    column_names += ['event_name', 'start', 'day', 'subject']

    # Unfiltered epochs, kept so that the raw trace can be shown next to each point.
    # Uses the same window length as the feature epochs above.
    epochs = mne.Epochs(raw, events, event_id=events_id, tmin=0, tmax=epoch_duration, baseline=None,
                        on_missing='ignore', event_repeated='drop')
    data = epochs.get_data()
    datas.append(data)

    # NOTE: stacking numeric blocks next to the text columns makes NumPy coerce
    # every cell to a string; feature columns must be cast back to float before use.
    features = np.hstack(features)
    all_features.append(features)

if not all_features:
    # np.vstack([]) would fail with an unhelpful message.
    raise ValueError(f'No recording could be processed out of {len(files)} candidate file(s)')

df = pd.DataFrame(np.vstack(all_features), columns=column_names)

In [None]:
# Display the column labels of the assembled feature table.
df.columns

In [None]:
# Rebuild the table from the per-file arrays so that re-running this cell
# always starts from the same frame, then add an integer code per event name.
df = pd.DataFrame(np.vstack(all_features), columns=column_names)
df['code'] = pd.Categorical(df.event_name).codes

# Columns whose name contains "feature" are model inputs; the rest is metadata.
is_feature = df.columns.str.contains("feature", regex=False)
features = list(df.columns[is_feature])
non_features = list(df.columns[~is_feature])

# Feature matrix without header/index; metadata keeps both.
df[features].to_csv('features.tsv', sep='\t', index=False, header=False)
df[non_features].to_csv('non_features.tsv', sep='\t')

In [None]:
# UMAP is stochastic: fix the seed so that re-running the notebook gives the
# same embedding.
fit = umap.UMAP(n_neighbors=15, n_components=2, random_state=42)
# The feature table is built by stacking numeric and text columns with NumPy,
# which stores every cell as a string; hand UMAP an explicit float matrix.
data = df[features].to_numpy(dtype=float)

In [None]:
u = fit.fit_transform(data)

# Keep, next to every embedded point, the unfiltered trace of the first
# picked channel (one 1-D array per epoch).
df['data'] = list(np.vstack(datas)[:,0,:])
df['x1'] = u[:,0].reshape(-1)
df['x2'] = u[:,1].reshape(-1)

fig, ax = plt.subplots()
points = ax.scatter(u[:,0], u[:,1], c=df['code'], s=1)
# A bare plt.legend() finds no labelled artist and draws nothing. Build one
# entry per event code instead; codes come from pd.Categorical, so they are
# ordered like its categories.
handles, _ = points.legend_elements(num=None)
ax.legend(handles, pd.Categorical(df.event_name).categories, title='event')
ax.set(title='UMAP embedding', xlabel='x1', ylabel='x2');

In [None]:
# umap.plot pulls in optional plotting dependencies, so it is imported here
# rather than with the core imports.
import umap.plot
umap.plot.output_notebook()


# `interactive_text_search_columns` expects a list of column names and is only
# used when `interactive_text_search=True`; passing `True` to it left the
# search box disabled. Search is restricted to the text columns of hover_data.
p = umap.plot.interactive(fit, labels=df['code'],
                          hover_data=df[non_features],
                          point_size=4,
                          theme='fire',
                          background='black',
                          #color_key= ['FR', 'HAHF', 'HALF', 'LAHF', 'LALF', 'RP', 'background', 'null'],
                          interactive_text_search=True,
                          interactive_text_search_columns=['event_name', 'day', 'subject'])
#
umap.plot.show(p)

In [None]:
# In this notebook `colormap` is only assigned in the bokeh cell further down,
# so a bare `colormap` here stops "Restart & Run All" with a NameError.
# Look it up defensively: the mapping is displayed once it exists, and nothing
# is displayed before that.
globals().get('colormap')

In [None]:
from bokeh.plotting import figure, output_file, show
from bokeh.models import ColumnDataSource, CustomJS
from bokeh.layouts import row
from bokeh.palettes import brewer, viridis
from bokeh.io import  output_notebook
output_notebook()

# One colour per event code. The brewer "Spectral" palettes only exist for
# 3 to 11 colours: use the 3-colour palette for fewer classes and fall back
# to viridis for more.
n_codes = len(df.code.unique())
if n_codes <= 11:
    colors = brewer["Spectral"][max(n_codes, 3)]
else:
    colors = viridis(n_codes)
colormap = {i: colors[i] for i in df.code.unique()}
colors = [colormap[x] for x in df.code]
df['color'] = colors

tooltips = [
    ("day", "@day"),
    ("event", "@event_name"),
    ("subject", "@subject")]

# Left panel: the UMAP embedding, one glyph per epoch.
s1 = ColumnDataSource(data=df[['x1','x2','day', 'event_name', 'subject', 'data', 'code', 'color']])
p1 = figure(width=400, height=400, tools='tap,hover,pan,wheel_zoom,box_zoom,reset', title="UMAP",
            tooltips=tooltips)
p1.scatter('x1', 'x2', source=s1, color='color')

# Right panel: the trace of the selected epoch (initially the first one).
first_trace = df['data'].iloc[0]
df2 = pd.DataFrame({'x': np.arange(0, len(first_trace)), 'y': first_trace})
s2 = ColumnDataSource(data=df2)
p2 = figure(width=400, height=400, title="Data")
p2.line('x', 'y', source=s2)

# When a point is tapped in the left panel, redraw the right panel with the
# trace stored in that row of s1.
s1.selected.js_on_change('indices', CustomJS(args=dict(s1=s1, s2=s2), code="""
        const inds = cb_obj.indices;
        // The selection can be cleared (tap on empty space): keep the current trace.
        if (inds.length === 0) {
            return;
        }
        const trace = s1.data.data[inds[0]];
        const d2 = s2.data;
        d2['x'] = []
        d2['y'] = []
        for (let i = 0; i < trace.length; i++) {
            d2['x'].push(i)
            d2['y'].push(trace[i])
        }
        s2.change.emit();
    """)
)


show(row(p1, p2))

In [None]:
# Import everything this cell uses so that it does not depend on the previous
# bokeh cell having been run.
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource
from bokeh.palettes import brewer, viridis
from bokeh.io import  output_notebook
output_notebook()

# One colour per event code. The brewer "Spectral" palettes only exist for
# 3 to 11 colours: use the 3-colour palette for fewer classes and fall back
# to viridis for more.
n_codes = len(df.code.unique())
if n_codes <= 11:
    colors = brewer["Spectral"][max(n_codes, 3)]
else:
    colors = viridis(n_codes)
colormap = {i: colors[i] for i in df.code.unique()}
colors = [colormap[x] for x in df.code]
df['color'] = colors

# Bokeh tooltips reference data-source columns with "@name"; "$name" is
# reserved for special variables such as $x or $index.
tooltips = [
    ("day", "@day"),
    ("event", "@event_name"),
    ("subject", "@subject")]

s1 = ColumnDataSource(data=df[['x1','x2','day', 'event_name', 'subject', 'data', 'code', 'color']])
p1 = figure(width=400, height=400, tools='tap,hover,pan,wheel_zoom,box_zoom,reset', title="UMAP",
            tooltips=tooltips)
# Use the per-event colour column computed above, as the linked plot does.
p1.scatter('x1', 'x2', source=s1, color='color')
show(p1)

In [None]:
# Recording to inspect, and the `start` value handed to `raw.plot`.
day = 23
subject = 11
start = 1656

# Build the path from its components. The previous raw f-string doubled every
# backslash, so the path handed to read_file contained "\\\\" separators.
recording_path = os.path.join(r'V:\switchdrive\Brainhack',
                              f'KMR{subject}',
                              f'd{day}',
                              f'KMR{subject}_d{day}_Raw_DS.Avg_ref.sef')
raw = read_file(recording_path)

# One browser window per band-pass filtered copy of the recording.
for band in bands:
    raw_ = raw.copy().filter(band[0], band[1])
    raw_.plot(scalings='auto', start=start, decim=1)