In [None]:
%matplotlib inline

In [None]:
import functools
import heapq
import math
import os
import pathlib

import matplotlib.pyplot as plt
import mne
import numpy as np
import sklearn

## Make sure to install autoreject from master as there is a bug in the latest release: http://autoreject.github.io/
from autoreject import AutoReject, get_rejection_threshold, set_matplotlib_defaults
from IPython.display import Markdown, display

## Setup Project and Unicorn specific settings

In [None]:
# Repository root: the parent of the directory the notebook is started from.
project_root = pathlib.Path().absolute().parent

# All recordings live below this folder. The glob results are sorted because
# glob order depends on the file system, and the order of the files determines
# the order in which recordings are concatenated (and therefore which epochs
# end up in the seeded train/test split further down).
sample_data_folder = project_root / "visual_stress_data"
sample_data_raw_files = sorted(sample_data_folder.glob("**/*.fif"))
print("Found {} *_raw.fif file(s)".format(len(sample_data_raw_files)))

petri_sample_data_raw_files = sorted(sample_data_folder.glob("**/petri*.fif"))
print("Found {} petri*_raw.fif file(s)".format(len(petri_sample_data_raw_files)))

gel_sample_data_raw_files = sorted(sample_data_folder.glob("**/gel*.fif"))
print("Found {} gel*_raw.fif file(s)".format(len(gel_sample_data_raw_files)))

# Maps the generic channel names stored in the recordings to electrode labels.
channel_name_mapping = {
    "EEG0": "Fz",
    "EEG1": "C3",
    "EEG2": "Cz",
    "EEG3": "C4",
    "EEG4": "Pz",
    "EEG5": "PO7",
    "EEG6": "Oz",
    "EEG7": "PO8",
}

# Event name -> integer event id used throughout the notebook.
event_dict = {"focused": 1, "blurred": 2, "end": 3}

### Load montage file and display montage
#### API Docs
- [read_custom_montage](https://mne.tools/stable/generated/mne.channels.read_custom_montage.html#mne-channels-read-custom-montage)
- [plot](https://mne.tools/stable/generated/mne.channels.DigMontage.html#mne.channels.DigMontage.plot)

In [None]:
# Electrode-position file kept in the project's Resources folder.
channel_locs_file = project_root.joinpath(
    "Resources", "locs_electrode_placement_gtec_unicorn_standard.locs"
)

# Parse the electrode layout, then draw it (the returned figure is discarded).
montage = mne.channels.read_custom_montage(channel_locs_file)
_ = montage.plot()

In [None]:
def get_raw_from_fif(file):
    """Load one FIF recording and return it preprocessed.

    Steps, in order: read the file with data preloaded, rebuild the event
    array, rename the channels via ``channel_name_mapping`` and attach
    ``montage``, keep the EEG channels except ``Fz``, band-pass filter
    (2 to 40), crop away everything before the first event, re-reference to
    the channel average and finally multiply every sample by ``1e-7``.

    Parameters
    ----------
    file : path-like
        Recording to load; passed straight to ``mne.io.read_raw_fif``.

    Returns
    -------
    The preprocessed raw object.
    """
    recording = mne.io.read_raw_fif(file, preload=True)

    # Each stored event carries its values under the "list" key. Reverse their
    # order and add one to what becomes the event id, because mne doesn't like
    # an event id of 0.
    converted_events = []
    for stored_event in recording.info["events"]:
        values = stored_event["list"]
        converted_events.append([values[2], values[1], values[0] + 1])
    recording.info["events"] = np.array(converted_events)

    # Give the channels their electrode names, then attach their positions.
    mne.rename_channels(recording.info, channel_name_mapping)
    recording.set_montage(montage)

    # Keep only EEG channels, leaving out Fz.
    recording = recording.pick("eeg", exclude=['Fz'])

    # Band-pass filter. When splitting into epochs later, keep the Nyquist
    # frequency in mind.
    recording.filter(l_freq=2, h_freq=40, n_jobs=-1)

    # Discard everything recorded before the first event.
    first_event_sample = recording.info["events"][0][0]
    recording.crop(tmin=first_event_sample / recording.info["sfreq"])

    # Re-reference to the average of all remaining channels, applied directly.
    recording.set_eeg_reference("average", projection=False)

    def _rescale(samples):
        # The original note says this scales values to microvolts.
        # NOTE(review): confirm that 1e-7 is the intended factor.
        return samples * 1e-7

    recording = recording.apply_function(_rescale)

    return recording

## Preparation
Load all data

In [None]:
# Quieten MNE while the files are loaded. ``return_old_level=True`` is needed
# for ``set_log_level`` to hand back the previous level; without it the call
# returns None and the "restore" below would reset logging to the configured
# default instead of the level that was active before.
old_log_level = mne.set_log_level("WARNING", return_old_level=True)
try:
    raws = [get_raw_from_fif(f) for f in gel_sample_data_raw_files]
finally:
    # Restore the previous verbosity even if loading one of the files fails.
    mne.set_log_level(old_log_level)

Some debugging outputs:

In [None]:
# Note: concatenate_raws will modify raws[0]!
# Collect the per-recording event arrays first so they can be merged
# together with the recordings themselves.
per_file_events = [r.info["events"] for r in raws]
raw, events = mne.concatenate_raws(raws, events_list=per_file_events)

# Store the merged event array on the concatenated recording.
raw.info["events"] = events

display(Markdown("**n_times of modified raw:** {}".format(raw.n_times)))

raw.info

### Augment data
To make it simpler to split the data into short epochs later, insert artificial events after every blurred and focused event.

In [None]:
# Spacing of the virtual events: one every three seconds. The spacing in
# samples is derived from the recording's own sampling rate instead of a
# hard-coded value, and cast to int so the event array keeps an integer dtype.
window_seconds = 3
window_samples = int(round(raw.info["sfreq"] * window_seconds))

# For every pair of consecutive events, repeat the earlier event's id every
# ``window_samples`` samples, for as many whole windows as fit before the
# later event starts.
augumented_events_lists = [
    [
        [e_1[0] + window_samples * i, 0, e_1[2]]
        for i in range((e_2[0] - e_1[0]) // window_samples)
    ]
    for e_1, e_2 in zip(raw.info["events"][:-1], raw.info["events"][1:])
]

# Merge the per-interval lists (each is already sorted by sample index) into
# one sorted sequence, dropping the temporary "end" events (id 3) on the way.
# NOTE: this cell rewrites raw.info["events"]; re-running it without re-running
# the cells above would augment the already augmented events.
raw.info["events"] = np.array(
    [e for e in heapq.merge(*augumented_events_lists) if e[2] != event_dict["end"]]
)

### Raw Plot
**Let's have a look at our data.** Set duration to some high number (>10000) to see all data.

In [None]:
# Display options for the overview plot of the concatenated recording.
raw_plot_options = dict(
    duration=160,
    scalings="2e-6",
    clipping="clamp",  # or "transparent" "clamp"
    show_scalebars=False,
    show_scrollbars=False,
)

# Overlay the events, coloured by event id.
_ = raw.plot(
    events=raw.info["events"],
    event_color={1: "green", 2: "blue", 3: "black"},
    **raw_plot_options,
)

## Preprocessing

### Generate Epochs
Each epoch is three seconds long

In [None]:
# Cut one epoch per event, covering 0 to 3 relative to the event. The whole
# epoch serves as the baseline and no amplitude-based rejection is applied
# here.
epoch_options = dict(
    event_id={"focused": 1, "blurred": 2},
    tmin=0,
    tmax=3,
    baseline=(None, None),
    preload=True,
    reject_by_annotation=True,
    reject=None,
    flat=None,
)
epochs = mne.Epochs(raw, raw.info["events"], **epoch_options)

### Drop bad epochs

In [None]:
# Estimate rejection thresholds from the epochs themselves.
reject_criteria = get_rejection_threshold(epochs)
print('Reject criteria: {}'.format((reject_criteria)))

display(Markdown("**Epochs dropped due to bad channels:**"))

# Lower the log level, as drop_bad generates a lot of output if many epochs are
# dropped. ``return_old_level=True`` is needed for ``set_log_level`` to return
# the previous level; without it the call returns None and the restore below
# would reset logging to the configured default.
old_log_level = mne.set_log_level("WARNING", return_old_level=True)
try:
    epochs.drop_bad(reject=reject_criteria)
finally:
    # Restore the previous verbosity even if dropping fails.
    mne.set_log_level(old_log_level)

_ = epochs.plot_drop_log()

### Equalize Events 

In [None]:
# The two conditions that are compared against each other.
conds_we_care_about = ["focused", "blurred"]

display(Markdown("**Epochs dropped to prevent bias:**"))

# Operates in-place on ``epochs``.
epochs.equalize_event_counts(conds_we_care_about)

# Per-condition selections, in the same order as ``conds_we_care_about``.
focused_epochs, blurred_epochs = (epochs[name] for name in conds_we_care_about)

In [None]:
# Report how many epochs remain after rejection and equalization.
n_epochs_left = len(epochs)
print('We are left with {} epochs for training.'.format(n_epochs_left))

## Exploratory Analysis

In [None]:
# Channel count of the recording, looked up once. The previous version called
# raw.get_data() (a copy of the whole recording) once per epoch just to take
# its length.
n_channels = len(raw.ch_names)


def _epoch_colors(epochs_subset):
    """Return one colour list per epoch: red for event id 1, blue otherwise.

    Each inner list holds ``n_channels`` entries, one per channel.
    """
    return [
        ["red" if e[2] == 1 else "blue"] * n_channels
        for e in epochs_subset.events
    ]


# All epochs, 40 per page.
_ = epochs.plot(
    events=events,
    scalings="auto",
    show_scrollbars=False,
    n_epochs=40,
    event_colors={1: "red", 2: "blue"},
    epoch_colors=_epoch_colors(epochs),
)

# The same view for each condition on its own.
_ = focused_epochs.plot(
    events=events,
    scalings="auto",
    show_scrollbars=False,
    epoch_colors=_epoch_colors(focused_epochs),
)
_ = blurred_epochs.plot(
    events=events,
    scalings="auto",
    show_scrollbars=False,
    epoch_colors=_epoch_colors(blurred_epochs),
)

In [None]:
# Image plot per condition: focused first, then blurred.
for condition_epochs in (focused_epochs, blurred_epochs):
    _ = condition_epochs.plot_image()

Can we make out any trends in the power spectrum?

In [None]:
# Power spectral density per condition, without averaging
# (focused first, then blurred).
for condition_epochs in (focused_epochs, blurred_epochs):
    _ = condition_epochs.plot_psd(average=False)

In [None]:
# Normalized PSD topographies of the EEG channels, focused first, then blurred.
for condition_epochs in (focused_epochs, blurred_epochs):
    _ = condition_epochs.plot_psd_topomap(ch_type="eeg", normalize=True)

### Time Frequency

In [None]:
# Frequencies of interest: 7 up to (but excluding) 30, in steps of 3.
frequencies = np.arange(7, 30, 3)

# Morlet-wavelet power for each condition, plotted right after it is computed.
# After the loop ``power`` holds the result for the blurred epochs.
for condition_epochs in (focused_epochs, blurred_epochs):
    power = mne.time_frequency.tfr_morlet(
        condition_epochs,
        freqs=frequencies,
        n_cycles=2,
        return_itc=False,
        decim=2,
    )
    _ = power.plot()

## Machine learning

 ### Model Processing and Machine Learning Algorithms

#### Support Vector Machines

In [None]:
# Changing data to PSD: an MNE PSDEstimator built with its default arguments.
# It is used as the first stage of the classification pipeline created below.
freq_data = mne.decoding.PSDEstimator()

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Flatten every epoch into a single feature row.
epoch_array = epochs.get_data()
numpy_data = epoch_array.reshape(len(epoch_array), -1)
print("Shape of original data is {}".format(numpy_data.shape))

# Event ids 1 and 2 become class labels 0 and 1.
labels = epochs.events[:, 2] - 1

# Hold out a third of the epochs for testing; the split is seeded.
X_train, X_test, y_train, y_test = train_test_split(
    numpy_data,
    labels,
    test_size=0.33,
    random_state=42,
)
print("Shape of X_train is {}".format(X_train.shape))
print("Shape of labels is {}".format(y_train.shape))

# Baseline model: PSD estimate -> standardization -> support vector classifier.
clf = make_pipeline(freq_data, StandardScaler(), SVC(gamma='auto'))
clf.fit(X_train, y_train)

In [None]:
# Predict the class labels for the held-out test rows with the fitted pipeline.
predictions = clf.predict(X_test)

In [None]:
# plot_roc_curve / plot_confusion_matrix no longer exist in recent scikit-learn
# releases, so they are not imported here; importing them at the top of the
# cell would make the whole cell fail before any metric is printed.
from sklearn.metrics import (
    RocCurveDisplay,
    accuracy_score,
    classification_report,
    f1_score,
    hamming_loss,
    roc_curve,
)

print("F1 score is {}".format(f1_score(y_test, predictions)))
print("Accuracy score is {}".format(accuracy_score(y_test, predictions)))
print("Hamming Loss of the classifier is {}\n".format(hamming_loss(y_test, predictions)))

print("\n CLASSIFICATION REPORT \n")
print(classification_report(y_test, predictions))

# ROC curve of the fitted pipeline on the test split. Newer scikit-learn
# offers RocCurveDisplay.from_estimator; fall back to the older helper
# function when that classmethod is not available.
# (A confusion matrix can be drawn analogously with ConfusionMatrixDisplay.)
if hasattr(RocCurveDisplay, "from_estimator"):
    roc_display = RocCurveDisplay.from_estimator(clf, X_test, y_test)
else:
    from sklearn.metrics import plot_roc_curve

    roc_display = plot_roc_curve(clf, X_test, y_test)

roc_display

In [None]:
import pickle

# Persist the fitted pipeline next to the notebook so it can be reused.
model_path = 'model_svm_gel_3sec_fixed.pkl'
with open(model_path, 'wb') as model_file:
    pickle.dump(clf, model_file)