### Imports

In [None]:
from os import listdir
from os.path import exists
from importlib import reload
import numpy as np
import pandas as pd
import pyxdf
import mne
from utils import *
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_val_score
import time
import datetime
from datetime import datetime, timezone

print('Imports done...')

### Functions

In [None]:
# Helper functions:
def extract_eeg(stream, kick_last_ch=True):
    """
    Extracts the EEG data and the EEG timestamps from an XDF stream.

    The samples are transposed to a (channels, samples) array and multiplied by
    1e-6 (conversion to volts). A new array is returned, so the stream's own
    ``time_series`` is left untouched.

    :param stream: XDF stream dict with ``'time_series'`` (65 channels) and
        ``'time_stamps'``.
    :param kick_last_ch: If True, drop the last (65th) row, i.e. the unused
        BrainProducts marker channel.
    :return: eeg: array of shape (64, n_samples), or (65, n_samples) when
                  ``kick_last_ch`` is False, in volts.
             eeg_ts: the stream's timestamps.
    :raises ValueError: If the stream does not contain exactly 65 channels.
    """
    # Multiplying (instead of ``*=`` on the transposed view) creates a new
    # array, so the caller's stream data is not rescaled as a side effect.
    extr_eeg = np.asarray(stream['time_series']).T * 1e-6  # Convert to volts.
    if extr_eeg.shape[0] != 65:
        raise ValueError(f'Expected 65 channels in the EEG stream, got {extr_eeg.shape[0]}.')
    # Timestamps come from the stream that was passed in.
    extr_eeg_ts = stream['time_stamps']

    if kick_last_ch:
        # Kick the last row (unused Brainproduct markers):
        extr_eeg = extr_eeg[:64, :]

    return extr_eeg, extr_eeg_ts


def extract_eeg_infos(stream, n_channels=64, eog_indices=(16, 21, 40)):
    """
    Takes an EEG xdf stream and extracts the sampling rates, channel names and channel types.

    :param stream: EEG xdf stream.
    :param n_channels: Number of channels whose names/types are returned
        (default 64, i.e. without the BrainProducts marker channel).
    :param eog_indices: Channel indices labelled 'eog'; all others are 'eeg'.
    :return: sampling_rate: Configured (nominal) sampling rate.
    :return: names: Channel names taken from the stream description.
    :return: labels: Channel types ('eeg' or 'eog').
    :return: effective_sample_frequency: ``effective_srate`` from the stream info.
    """
    info = stream['info']
    sampling_rate = float(info['nominal_srate'][0])
    effective_sample_frequency = float(info['effective_srate'])

    # Channel names from the stream's channel description:
    channel_descs = info['desc'][0]['channels'][0]['channel']
    chn_names = [channel_descs[i]['label'][0] for i in range(n_channels)]

    labels = ['eeg'] * n_channels
    for idx in eog_indices:
        labels[idx] = 'eog'

    return sampling_rate, chn_names, labels, effective_sample_frequency


def extract_annotations(mark_stream, first_samp):
    """
    Build onset/duration/description lists from an XDF marker stream.

    The marker strings are first passed through ``fix_markers``. The marker
    timestamps are shifted so that ``first_samp`` becomes time zero.

    :param mark_stream: xdf stream with ``'time_series'`` (markers) and ``'time_stamps'``.
    :param first_samp: Timestamp of the first EEG sample, used to align the markers.
    :return: Dict with keys 'onsets', 'duration' (always 0) and 'description'.
    """
    markers = fix_markers(mark_stream['time_series'])
    relative_ts = mark_stream['time_stamps'] - first_samp

    onsets, durations, descriptions = [], [], []
    for index, marker in enumerate(markers):
        onsets.append(relative_ts[index])
        durations.append(0)
        descriptions.append(marker[0])

    return {'onsets': onsets, 'duration': durations, 'description': descriptions}

# Fix markers:
def fix_markers(orig_markers):
    """
    Correct trial-type markers whose label disagrees with the 'c...' markers that follow them.

    This works around a bug in the paradigm. For every trial-type marker
    (e.g. 'LTR-s'), the function looks at the following 8 markers (fewer near
    the end of the list). For each one whose string starts with a lowercase
    'c', it collects the third character as a position letter. Only the first
    four collected letters are used:

    * If fewer than four letters were found, or either of the first two is
      'c', the marker is left unchanged.
    * If neither of the first two letters equals the trial type's
      (lower-cased) first letter, the direction is rebuilt from the first
      letter via ``counter_letter`` (e.g. 'r' -> 'RTL-'). The suffix is 's' if
      letters 3 and 4 are both 'c', otherwise 'l'.
    * Otherwise only the suffix is corrected: '-s' becomes '-l' when neither
      letter 3 nor letter 4 is 'c', and '-l' becomes '-s' when both are 'c'.

    The last three markers of the list are never inspected.

    :param orig_markers: Sequence of markers. Each marker is a mutable sequence
        whose first element is the marker string; it is modified via
        ``orig_markers[i][0] = ...``, so tuples would fail.
        NOTE(review): presumably the list-of-lists from a pyxdf string stream — confirm.
    :type orig_markers: list
    :return: The same ``orig_markers`` object, modified in place.
    :rtype: list
    """

    trial_type_markers = ['LTR-s', 'LTR-l','RTL-s', 'RTL-l', 'TTB-s', 'TTB-l', 'BTT-s', 'BTT-l']
    # Maps the start-position letter to the target letter of the trial type name.
    # NOTE(review): a start letter outside l/r/b/t raises KeyError below.
    counter_letter = {'l': 'R', 'r': 'L', 'b': 'T', 't': 'B'}

    # Parse through markers (the last three are skipped)
    for i in range(len(orig_markers)-3):
        marker = orig_markers[i][0]
        if marker in trial_type_markers:
            following_markers = []
            # Collect the 3rd character of every marker that starts with a
            # lowercase 'c' among the next 8 markers (or up to the end of the list).
            # NOTE(review): an empty marker string would raise IndexError here.
            if (i+9) < len(orig_markers):
                for ii in range(i+1, i+9):
                    next_mark = orig_markers[ii][0]
                    if next_mark[0] == 'c':
                        following_markers.append(next_mark[2])
            else:
                for ii in range(i+1, len(orig_markers)):
                    next_mark = orig_markers[ii][0]
                    if next_mark[0] == 'c':
                        following_markers.append(next_mark[2])

            # Skip this marker if fewer than 4 position letters were found:
            if len(following_markers) < 4:
                continue

            # Skip if either of the first two letters is 'c':
            if following_markers[0] == 'c' or following_markers[1] == 'c':
                continue

            # First and last letter of the trial type marker:
            first_letter = marker[0].lower()
            last_letter = marker[-1].lower()

            # If neither of the first two letters matches the trial's start letter, rebuild the trial type:
            if (following_markers[0] != first_letter) and (following_markers[1] != first_letter):
                # New direction from the first letter, e.g. 'r' -> 'RTL-':
                new_type = following_markers[0].upper() + 'T' + counter_letter[following_markers[0]] + '-'

                if (following_markers[2] == 'c') and (following_markers[3] == 'c'):
                    new_type = new_type + 's'
                else:
                    new_type = new_type + 'l'

                orig_markers[i][0] = new_type

            # Otherwise only fix the short/long suffix based on letters 3 and 4:
            else:
                if (last_letter == 's') and (following_markers[2] != 'c') and (following_markers[3] != 'c'):
                    new_type = marker[:-1]
                    new_type += 'l'
                    orig_markers[i][0] = new_type

                elif (last_letter == 'l') and (following_markers[2] == 'c') and (following_markers[3] == 'c'):
                    new_type = marker[:-1]
                    new_type += 's'
                    orig_markers[i][0] = new_type


    return orig_markers

def add_bad_channel_to_df(bad_chn_row, ch_names, csv_name='bad_channels.csv'):
    """
    Add a row to a CSV file containing information about bad channels.

    The CSV has the columns ['Subject', 'Run', 'Paradigm', 'Bad_channel'] and is
    created if it does not exist yet. Duplicate rows are dropped before saving.

    :param bad_chn_row : list
        Values for one row, in column order: [subject, run, paradigm, channel].
    :param ch_names : list of str
        Valid channel names. The last element of ``bad_chn_row`` must be one of them.
    :param csv_name : str, optional
        The name of the CSV file. The default is 'bad_channels.csv'.
    :return: df_bads : pandas.DataFrame
        The table from the CSV file, with the new row added and duplicates removed.
    :raises NameError: If the channel name is not in ``ch_names``. Nothing is
        written in that case.
    """
    # Validate first, so a typo never creates or modifies the CSV.
    if bad_chn_row[-1] not in ch_names:
        raise NameError('Channel name not found')

    if exists(csv_name):
        df_bads = pd.read_csv(csv_name, index_col=0)
    else:
        df_bads = pd.DataFrame(columns=['Subject', 'Run', 'Paradigm', 'Bad_channel'])

    # Renumber the index so the new label len(df_bads) cannot collide with
    # (and overwrite) an existing row when the stored index has gaps.
    df_bads = df_bads.reset_index(drop=True)
    df_bads.loc[len(df_bads)] = list(bad_chn_row)

    print(f'Added {bad_chn_row} to the dataframe...')

    # Drop duplicates and keep a contiguous index for the next call:
    df_bads = df_bads.drop_duplicates(ignore_index=True)

    df_bads.to_csv(csv_name)

    return df_bads

def get_bads_for_subject(subject, csv_file='bad_channels.csv'):
    """
    Get the bad channels that were flagged more than once for a given subject.

    :param subject: Subject name.
    :type subject: str
    :param csv_file: CSV file written by ``add_bad_channel_to_df``. Default is 'bad_channels.csv'.
    :type csv_file: str

    :returns: list: Channels that appear in more than one row for this subject
        (any run or paradigm), ordered by descending count.

    :raises: FileNotFoundError: If the CSV file does not exist.
    """
    if not exists(csv_file):
        raise FileNotFoundError('File does not exist, please use the add_bad_channel_to_df() function.')

    df = pd.read_csv(csv_file, index_col=0)

    # Count how often each channel was flagged for this subject:
    subject_df = df[df['Subject'] == subject]
    channel_counts = subject_df['Bad_channel'].value_counts()

    # Keep only channels flagged more than once:
    return list(channel_counts[channel_counts > 1].index)

def get_all_additional_information(subject, csv_file='participant_info.csv'):
    """Look up the measurement metadata of one participant.

    :param subject: Participant identifier as stored in the 'Participant' column.
    :type subject: str
    :param csv_file: Path of the participant info CSV file.
    :type csv_file: str
    :return: Tuple of
        - meas_date (datetime): 'Measurement_Date' (format dd.mm.YYYY) as a UTC-aware datetime.
        - experimenter (str): Fixed experimenter name.
        - proj_name (str): Fixed project name.
        - subject_info (pandas.DataFrame): The CSV rows matching ``subject``.
        - line_freq (float): Power line frequency, fixed at 50.0.
        - gender: 'Gender' value of the first matching row.
        - dob: 'Date_Of_Birth' value of the first matching row.
        - age_at_meas: 'Age_At_Measurement' value of the first matching row.
    :rtype: tuple
    :raises TypeError: If ``subject`` or ``csv_file`` is not a string.
    :raises FileNotFoundError: If ``csv_file`` does not exist.
    :raises ValueError: If no row matches ``subject``.
    """
    for value, message in ((subject, 'Subject must be a string.'),
                           (csv_file, 'CSV file must be a string.')):
        if not isinstance(value, str):
            raise TypeError(message)
    if not exists(csv_file):
        raise FileNotFoundError('File does not exist. Check if the path is correct.')

    participants = pd.read_csv(csv_file, index_col=False)
    subject_info = participants.loc[participants['Participant'] == subject]
    if subject_info.empty:
        raise ValueError('Subject not found in CSV file.')

    def first_value(column):
        # Value of ``column`` in the first matching row.
        return subject_info[column].values[0]

    meas_date = datetime.strptime(first_value('Measurement_Date'), '%d.%m.%Y').replace(tzinfo=timezone.utc)

    return (meas_date,
            'Peter T.',
            'Decoding of range during goal-directed movement',
            subject_info,
            50.0,
            first_value('Gender'),
            first_value('Date_Of_Birth'),
            first_value('Age_At_Measurement'))

def get_subset_of_dict(full_dict, keys_of_interest):
    """Return a new dict with the entries of ``full_dict`` whose keys are in ``keys_of_interest``.

    Keys missing from ``full_dict`` are skipped. The result follows the order
    of ``keys_of_interest``.
    """
    return {key: full_dict[key] for key in keys_of_interest if key in full_dict}


def create_sliced_trial_list(event_dict, events_from_annot):
    """
    Slice an MNE event array into per-trial lists of event codes.

    A new slice starts at every trial-type marker (e.g. 'LTR-s'). Each slice
    contains that marker and all following event codes, up to *and including*
    the next trial-type marker. The last slice therefore has one element fewer.
    Events before the first trial-type marker are ignored.

    :param event_dict: Mapping from event name to event code.
    :param events_from_annot: Event array; column 0 is the sample index and the
        last column is the event code.
    :return: trial_list: list of lists of event codes, one per trial.
             first_samps: sample index of each trial-type marker.
             Both are empty when no trial-type marker is present.
    """
    trial_type_markers = ['LTR-s', 'LTR-l', 'RTL-s', 'RTL-l', 'TTB-s', 'TTB-l', 'BTT-s', 'BTT-l']
    # Codes of the trial-type markers present in this recording (set for O(1) lookup):
    trial_codes = {event_dict[key] for key in trial_type_markers if key in event_dict}

    trial_list = []
    first_samps = []
    current = None  # None until the first trial-type marker has been seen
    for sample, code in zip(events_from_annot[:, 0], events_from_annot[:, -1]):
        if code in trial_codes:
            if current is not None:
                # Close the running trial with the marker that starts the next one.
                current.append(code)
                trial_list.append(current)
            current = [code]
            first_samps.append(sample)
        elif current is not None:
            current.append(code)

    # Append the final (open) trial only if any trial-type marker was seen.
    if current is not None:
        trial_list.append(current)

    return trial_list, first_samps


def get_bad_epochs(event_dict, trial_list):
    """
    Find the indices of the trials (sub-lists) in ``trial_list`` that are invalid.

    A trial is valid only if all of the following hold:
        1. Its length is 9, or 8 for the last trial.
        2. The first entry is a trial_type marker.
        3. The second entry is 'Start', the fourth is 'Cue' and the seventh is 'Break'.
        4. The 3rd character of the names of entries 3 and 5 (first two LDR
           readings) equals the lower-cased first letter of the trial type.
        5. For long trials ('-l'), the 3rd character of entries 6 and 8 equals
           the lower-cased 3rd letter of the trial type. For short trials
           ('-s'), it must be 'c'.

    :param event_dict: A dictionary where keys are event names and values are corresponding event codes.
    :type event_dict: dict
    :param trial_list: One list of event codes per trial.
    :type trial_list: list
    :return: A list of indices corresponding to the invalid trials.
    :rtype: list
    """
    trial_type_markers = ['LTR-s', 'LTR-l', 'RTL-s', 'RTL-l', 'TTB-s', 'TTB-l', 'BTT-s', 'BTT-l']
    # event_dict only contains trial types that occur in the recording; skip the others.
    trial_vals = {event_dict[key] for key in trial_type_markers if key in event_dict}

    # Reverse lookup code -> name, built once. setdefault keeps the first name per code.
    code_to_name = {}
    for name, code in event_dict.items():
        code_to_name.setdefault(code, name)

    def _is_valid(sub_list, expected_len):
        # Checks the conditions listed in the docstring, in order.
        if len(sub_list) != expected_len:
            return False
        if sub_list[0] not in trial_vals:
            return False
        if (sub_list[1] != event_dict['Start'] or sub_list[3] != event_dict['Cue']
                or sub_list[6] != event_dict['Break']):
            return False

        trial_type = code_to_name[sub_list[0]]
        start_touch, start_release = code_to_name[sub_list[2]], code_to_name[sub_list[4]]
        target_touch, target_release = code_to_name[sub_list[5]], code_to_name[sub_list[7]]

        start_letter = trial_type[0].lower()
        if start_touch[2] != start_letter or start_release[2] != start_letter:
            return False
        if trial_type[4] == 'l':
            target_letter = trial_type[2].lower()
            return target_touch[2] == target_letter and target_release[2] == target_letter
        if trial_type[4] == 's':
            return target_touch[2] == 'c' and target_release[2] == 'c'
        return True

    last_idx = len(trial_list) - 1
    return [idx for idx, sub_list in enumerate(trial_list)
            if not _is_valid(sub_list, 8 if idx == last_idx else 9)]

def convert_samps_to_time(first_time, first_samp, samp_list, sfreq=None):
    """Convert sample numbers to time values.

    Without ``sfreq``, the time is ``samp * first_time / first_samp``. For an MNE
    raw, ``first_time == first_samp / sfreq``, so this equals ``samp / sfreq``.

    :param first_time: float time value of the first sample
    :param first_samp: int sample number of the first sample
    :param samp_list: list of int sample numbers to be converted
    :param sfreq: optional sampling frequency. If given, it is used directly
        and ``first_time``/``first_samp`` are ignored.
    :return: numpy ndarray of time values for the input sample numbers
    :raises ValueError: if ``first_samp`` is 0 and no ``sfreq`` is given (the
        ratio would be 0/0).
    """
    samples = np.asarray(samp_list)
    if sfreq is not None:
        return samples / sfreq
    if first_samp == 0:
        raise ValueError('first_samp is 0, cannot derive the sampling frequency; pass sfreq explicitly.')
    return samples * first_time / first_samp

def create_bad_annotations(starting_times, bad_events, duration, orig_time):
    """Build 'bad epoch' annotations for the selected trials.

    :param starting_times: 1D array with the onset time of every trial
    :type starting_times: numpy.ndarray
    :param bad_events: Indices into ``starting_times`` of the trials to mark as bad
    :type bad_events: numpy.ndarray or list
    :param duration: Length of every bad annotation
    :type duration: float
    :param orig_time: Reference time passed to ``mne.Annotations(orig_time=...)``
    :type orig_time: float
    :return: One annotation per bad trial, all described as 'bad epoch'
    :rtype: mne.Annotations
    """
    onsets = starting_times[bad_events]
    n_bad = len(onsets)
    return mne.Annotations(onsets, [duration] * n_bad, ['bad epoch'] * n_bad,
                           orig_time=orig_time)


### Constants

In [None]:
# Root folder of the recordings (the first line is the path on another machine):
# data_path = 'C:/Users/tumfart/Code/github/master-thesis/data/'
data_path = 'C:/Users/peter/Google Drive/measurements/eeg/'
subjects = ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A10']
paradigm = 'paradigm'  # one of: 'eye', 'paradigm'
plot = False  # set to True to plot every intermediate processing step
mne.set_log_level('WARNING')

trial_type_markers = ['LTR-s', 'LTR-l', 'RTL-s', 'RTL-l', 'TTB-s', 'TTB-l', 'BTT-s', 'BTT-l']

# One data folder per subject: <data_path><subject>/<paradigm>
paths = [f'{data_path}{subject}/{paradigm}' for subject in subjects]

### Read xdf-files for all subjects

In [None]:
# Iterate over each subject, read every XDF run and store it as an MNE raw file:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Extracting subject {subject}', end=' ')
    # Sorted so the run numbering (i + 1) is deterministic: os.listdir returns
    # entries in arbitrary order.
    file_names = sorted(f for f in listdir(path) if '.xdf' in f)

    for i, file_name in enumerate(file_names):
        print('#', end=' ')
        file = path + '/' + file_name

        # Read the raw stream:
        streams, header = pyxdf.load_xdf(file)

        # Split the streams (split_streams presumably comes from utils):
        eeg_stream, marker_stream = split_streams(streams)

        # Get the eeg data without the BrainProducts marker channel:
        eeg, eeg_ts = extract_eeg(eeg_stream, kick_last_ch=True)

        # Extract all infos from the EEG stream:
        fs, ch_names, ch_labels, eff_fs = extract_eeg_infos(eeg_stream)

        # Extract the triggers, with onsets relative to the first EEG timestamp:
        triggers = extract_annotations(marker_stream, first_samp=eeg_ts[0])

        # Define MNE annotations
        annotations = mne.Annotations(triggers['onsets'], triggers['duration'], triggers['description'], orig_time=None)

        # Create mne info:
        # TODO: Check what info can be added to the stream:
        info = mne.create_info(ch_names, fs, ch_labels)

        # Create the raw array and add montage and annotations.
        # NOTE(review): RawArray's first_samp is a sample index, but here it gets
        # the first EEG timestamp. Later cells use raw.first_time / raw.first_samp,
        # so this is left unchanged — confirm intent.
        raw = mne.io.RawArray(eeg, info, first_samp=eeg_ts[0])
        raw.set_montage('standard_1005')
        raw.set_annotations(annotations)

        # Store the raw file:
        store_name = path + '/' + subject + '_run_' + str(i + 1) + '_unprocessed_raw.fif'
        raw.save(store_name, overwrite=True)

        if plot:
            raw.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
                     remove_dc=False, title='Raw')

    print()

print(f'Finished reading, took me {round(time.time()-start)} seconds...')


### Highpass and notch filter all runs (same as the "Filter the signals" cell below; concatenation happens further down):

In [None]:
# NOTE: this cell performs the same highpass + notch filtering as the
# "Filter the signals" cell and writes the same output files.
# Iterate over each subject and filter every unprocessed run:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading raw files for subject {subject}', end=' ')
    file_names = sorted(f for f in listdir(path) if '_unprocessed_raw.fif' in f)

    for i, file_name in enumerate(file_names):
        print('#', end=' ')

        file = path + '/' + file_name
        raw = mne.io.read_raw(file, preload=True)
        if plot:
            raw.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
                     remove_dc=False, title='Raw')
            plot_spectrum(raw)

        # Highpass filter (0.4 Hz, IIR) on the 'eeg' channels:
        raw_highpass = raw.copy().filter(l_freq=0.4, h_freq=None, picks=['eeg'], method='iir')
        if plot:
            raw_highpass.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
                              remove_dc=False, title='Highpass filtered')
            plot_spectrum(raw_highpass)

        # Notch filter at 50 Hz:
        raw_notch = raw_highpass.copy().notch_filter(freqs=[50], picks=['eeg'])
        if plot:
            raw_notch.plot(duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False, title='Notch filtered')
            plot_spectrum(raw_notch)

        # Store with the same run number as the input file (taken from its name,
        # not from the listdir position):
        store_name = path + '/' + file_name.replace('_unprocessed_raw.fif', '_highpass_notch_filtered_raw.fif')
        raw_notch.save(store_name, overwrite=True)

    print()

print(f'Finished highpass and notch filtering, took me {round(time.time() - start)} seconds...')

### Filter the signals

In [None]:
# Iterate over each subject and highpass + notch filter every unprocessed run:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading raw files for subject {subject}', end=' ')
    file_names = sorted(f for f in listdir(path) if '_unprocessed_raw.fif' in f)

    for i, file_name in enumerate(file_names):
        print('#', end=' ')

        file = path + '/' + file_name
        raw = mne.io.read_raw(file, preload=True)
        if plot:
            raw.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
                     remove_dc=False, title='Raw')
            plot_spectrum(raw)

        # Highpass filter (0.4 Hz, IIR) on the 'eeg' channels:
        raw_highpass = raw.copy().filter(l_freq=0.4, h_freq=None, picks=['eeg'], method='iir')
        if plot:
            raw_highpass.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
                              remove_dc=False, title='Highpass filtered')
            plot_spectrum(raw_highpass)

        # Notch filter at 50 Hz:
        raw_notch = raw_highpass.copy().notch_filter(freqs=[50], picks=['eeg'])
        if plot:
            raw_notch.plot(duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False, title='Notch filtered')
            plot_spectrum(raw_notch)

        # Store with the same run number as the input file (taken from its name,
        # not from the listdir position):
        store_name = path + '/' + file_name.replace('_unprocessed_raw.fif', '_highpass_notch_filtered_raw.fif')
        raw_notch.save(store_name, overwrite=True)

    print()

print(f'Finished highpass and notch filtering, took me {round(time.time() - start)} seconds...')

### Visualize signals for bad channel identification

In [None]:
# Specify the subject and paradigm to inspect.
# NOTE: this rebinds the notebook-wide ``subject`` and ``paradigm`` variables,
# which later cells also read (e.g. when building output file names).
subject = 'A01'
paradigm = 'paradigm'
runs = 9 if paradigm == 'paradigm' else 2
names = [f'{subject}_run_{i + 1}_highpass_notch_filtered_raw.fif' for i in range(runs)]

for i, name in enumerate(names):
    file = f'{data_path}{subject}/{paradigm}/{name}'
    raw = mne.io.read_raw(file, preload=True)

    raw.plot(duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False,
             title=f'Notch & HP filtered. Run: {i+1}')

In [None]:
# Specify a single run to inspect.
# NOTE: this rebinds the notebook-wide ``subject`` and ``paradigm`` variables;
# the bad-channel cell reads ``subject``, ``run``, ``paradigm`` and ``raw`` from here.
subject = 'A10'
paradigm = 'eye'
run = 2
name = f'{subject}_run_{run}_highpass_notch_filtered_raw.fif'
file = f'{data_path}{subject}/{paradigm}/{name}'

raw = mne.io.read_raw(file, preload=True)

raw.plot(duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False,
         title=f'Notch & HP filtered. Run: {run}')

In [None]:
# Record a bad channel for the currently loaded run in bad_channels.csv.
# The row is [subject, run, paradigm, channel name]:
bad_df = add_bad_channel_to_df(
    [subject, run, paradigm, 'T8'],
    ch_names=raw.ch_names,
    csv_name='bad_channels.csv',
)




### Add bad channels to all raw infos:

In [None]:
# Write each subject's recurring bad channels (flagged more than once in
# bad_channels.csv) into the info of all of that subject's raw files:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading all fif files for subject {subject}', end=' ')
    raw_files = [f for f in listdir(path) if 'raw.fif' in f]
    bads = get_bads_for_subject(subject, csv_file='bad_channels.csv')

    for i, file_name in enumerate(raw_files):
        print(f'#', end=' ')

        file = f'{path}/{file_name}'
        raw = mne.io.read_raw(file, preload=True)
        raw.info['bads'] = bads

        # Save back to the same file:
        raw.save(file, overwrite=True)

    print()

print(f'Finished bad channel adding, took me {round(time.time() - start)} seconds...')

### Perform interpolation of bad channels:

In [None]:
# Interpolate the bad channels of every filtered run:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading all fif files for subject {subject}', end=' ')
    file_names = sorted(f for f in listdir(path) if '_highpass_notch_filtered_raw.fif' in f)

    for i, file_name in enumerate(file_names):
        print('#', end=' ')

        file = path + '/' + file_name
        raw = mne.io.read_raw(file, preload=True)

        # Interpolate bad channels; reset_bads=False keeps them listed in info['bads']:
        raw_interp = raw.copy().interpolate_bads(reset_bads=False)

        # Write a new file with the same run number as the input (taken from
        # its name, not from the listdir position):
        store_name = path + '/' + file_name.replace('_highpass_notch_filtered_raw.fif', '_bad_channels_interpolated_raw.fif')
        raw_interp.save(store_name, overwrite=True)

    print()

print(f'Finished interpolating bad channels, took me {round(time.time() - start)} seconds...')

### CAR re-referencing

In [None]:
# Re-reference every interpolated run to the common average (CAR):
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading all fif files for subject {subject}', end=' ')

    #TODO: Loaded fif file changes in final pipeline (because eye artifact correction was not yet implemented).
    file_names = sorted(f for f in listdir(path) if '_bad_channels_interpolated_raw.fif' in f)

    for i, file_name in enumerate(file_names):
        print('#', end=' ')

        file = path + '/' + file_name
        raw_interp = mne.io.read_raw(file, preload=True)

        # Common average reference:
        raw_avg_ref = raw_interp.copy().set_eeg_reference(ref_channels='average')

        # Write a new file with the same run number as the input file:
        store_name = path + '/' + file_name.replace('_bad_channels_interpolated_raw.fif', '_car_referenced_raw.fif')
        raw_avg_ref.save(store_name, overwrite=True)

    print()

print(f'Finished rereferencing eeg, took me {round(time.time() - start)} seconds...')

In [None]:
# Load run 1 of A01 twice for comparison: CAR-referenced and interpolated (before re-referencing):
run_dir = 'C:/Users/peter/Google Drive/measurements/eeg/A01/paradigm/'
raw_avg_ref = mne.io.read_raw(run_dir + 'A01_run_1_car_referenced_raw.fif')
raw_interp = mne.io.read_raw(run_dir + 'A01_run_1_bad_channels_interpolated_raw.fif')

In [None]:
raw_avg_ref.plot(duration=60, proj=False, n_channels=len(raw_avg_ref.ch_names),
                 remove_dc=False, title='CAR referenced.')
raw_interp.plot(duration=60, proj=False, n_channels=len(raw_interp.ch_names),
                remove_dc=False, title='Interpolated')

### Helper cell to add info:

In [None]:
# Add participant/measurement metadata to every raw file of each subject:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading all fif files for subject {subject}', end=' ')
    file_names = [f for f in listdir(path) if 'raw.fif' in f]

    # Metadata from participant_info.csv (proj_name and dob are currently not stored):
    (meas_date, experimenter, proj_name, subject_info,
     line_freq, gender, dob, age_at_meas) = get_all_additional_information(subject, csv_file='participant_info.csv')

    # NOTE(review): MNE's info['subject_info'] has a fixed set of keys (e.g.
    # 'his_id', 'sex', 'birthday'). Custom keys like these are probably not
    # written to the FIF file — verify after saving.
    big_subject_info = {
        'Subject ID': subject,
        'Gender': gender,
        'Age at measurement': age_at_meas,
    }

    for i, file_name in enumerate(file_names):
        print(f'#', end=' ')

        file = f'{path}/{file_name}'
        raw = mne.io.read_raw(file, preload=True)

        raw.info['subject_info'] = big_subject_info
        raw.info['experimenter'] = experimenter
        raw.set_meas_date(meas_date)
        raw.info['line_freq'] = line_freq

        # Save back to the same file:
        raw.save(file, overwrite=True)

    print()

print(f'Finished adding info, took me {round(time.time() - start)} seconds...')

### HEAR - High-variance electrode artifact removal algorithm

In [None]:
# TODO: Implement HEAR
# Get resting data:

# Check resting trials and exclude bad ones:

# Calculate variance µ^2_s





### Lowpass filter at 3 Hz

In [None]:
# Lowpass filter every CAR-referenced run at 3 Hz:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading raw files for subject {subject}', end=' ')
    file_names = sorted(f for f in listdir(path) if '_car_referenced_raw.fif' in f)

    for i, file_name in enumerate(file_names):
        print('#', end=' ')

        file = path + '/' + file_name
        raw = mne.io.read_raw(file, preload=True)

        # Lowpass filter (3 Hz, IIR) on the 'eeg' channels:
        raw_lowpass = raw.copy().filter(l_freq=None, h_freq=3.0, picks=['eeg'], method='iir')
        if plot:
            raw_lowpass.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
                             remove_dc=False, title='Lowpass filtered')
            plot_spectrum(raw_lowpass)

        # Store with the same run number as the input file:
        store_name = path + '/' + file_name.replace('_car_referenced_raw.fif', '_lowpass_filtered_raw.fif')
        raw_lowpass.save(store_name, overwrite=True)

    print()

print(f'Finished lowpass filtering, took me {round(time.time() - start)} seconds...')

### Combine the datasets into one dataset

In [None]:
import re


def _run_sort_key(file_name):
    """Sort key: run number N from '<subject>_run_N_...' (names without one go last), then the name."""
    match = re.search(r'_run_(\d+)_', file_name)
    return (int(match.group(1)) if match else float('inf'), file_name)


# Concatenate the lowpass-filtered runs of each subject in run order:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading all fif files for subject {subject}', end=' ')
    # Sort numerically by run: os.listdir order is arbitrary, and a plain
    # string sort would put run_10 before run_2.
    file_names = sorted((f for f in listdir(path) if '_lowpass_filtered_raw.fif' in f), key=_run_sort_key)

    raws = []
    for i, file_name in enumerate(file_names):
        print('#', end=' ')

        file = path + '/' + file_name
        raw = mne.io.read_raw(file, preload=True)
        raws.append(raw)

    concat_raw = mne.concatenate_raws(raws)

    # Store the concatenated raw file:
    store_name = path + '/' + subject + '_' + paradigm + '_concatenated_raw.fif'
    concat_raw.save(store_name, overwrite=True)
    print()

print(f'Finished concatenating, took me {round(time.time() - start)} seconds...')


In [None]:
# Plot the concatenated recording (all of its channels):
concat_raw.plot(duration=60, proj=False, n_channels=len(concat_raw.ch_names), remove_dc=False, title='Concatenated raw.')

### Mark bad data spans

In [None]:
# Validate every trial's marker sequence and annotate the invalid trials as bad:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading last fif file for subject {subject}', end=' ')
    file_names = sorted(f for f in listdir(path) if 'concatenated_raw.fif' in f)
    if not file_names:
        raise FileNotFoundError(f'No concatenated raw file found in {path}; run the concatenation cell first.')

    # Load file
    file_name = file_names[0]
    file = path + '/' + file_name
    raw = mne.io.read_raw(file, preload=True)

    events_from_annot, event_dict = mne.events_from_annotations(raw)

    # Slice the events into one list per trial, then check each trial's marker order:
    trial_list, starting_samples = create_sliced_trial_list(event_dict, events_from_annot)
    starting_times = convert_samps_to_time(raw.first_time, raw.first_samp, starting_samples)
    bad_events = get_bad_epochs(event_dict, trial_list)
    print(f'{len(bad_events)} bad trials', end=' ')

    # Annotate each bad trial (7 s from its trial-type marker) as 'bad epoch', so
    # it can be dropped with reject_by_annotation when generating the epochs:
    bad_annots = create_bad_annotations(starting_times, bad_events, duration=7, orig_time=raw.info['meas_date'])
    raw.set_annotations(raw.annotations + bad_annots)

    # Save the annotated raw:
    store_name = path + '/' + subject + '_' + paradigm + '_bad_annotations_raw.fif'
    raw.save(store_name, overwrite=True)

    print()

print(f'Finished adding bad annotations, took me {round(time.time() - start)} seconds...')

### Epoching

In [None]:
# Cut epochs around the trial-type markers of each subject's annotated recording:
start = time.time()
for subject, path in zip(subjects, paths):
    print(f'Reading last fif file for subject {subject}', end=' ')
    annotated_files = [f for f in listdir(path) if '_bad_annotations_raw.fif' in f]

    # Load the (first) bad-annotated raw file:
    file_name = annotated_files[0]
    file = f'{path}/{file_name}'
    raw = mne.io.read_raw(file, preload=True)

    events_from_annot, event_dict = mne.events_from_annotations(raw)

    # Only the trial-type markers present in this recording become epoch events:
    markers_of_interest = ['LTR-s', 'LTR-l', 'RTL-s', 'RTL-l', 'TTB-s', 'TTB-l', 'BTT-s', 'BTT-l']
    event_dict_of_interest = get_subset_of_dict(event_dict, markers_of_interest)

    # TODO select event ID's of interest, hand over dict for event_id to make it easier to extract them:
    # Window 2.0-7.0 s after each marker, no baseline, bad annotations rejected:
    epochs = mne.Epochs(raw, events_from_annot, event_id=event_dict_of_interest,
                        tmin=2.0, tmax=7.0, baseline=None,
                        reject_by_annotation=True, preload=True)

    epochs.save(f'{path}/{subject}_{paradigm}_epo.fif', overwrite=True)

    print()

print(f'Finished epoching, took me {round(time.time() - start)} seconds...')

In [None]:
# Show the helper's docstring, then downsample a copy of the current epochs to 10 Hz and plot it:
help(create_bad_annotations)
epochs = epochs.copy()
epochs.resample(10)
epochs.plot()

In [None]:
# Condition averages: index 0 = '-l' trials, index 1 = '-s' trials:
evokeds_list = [
    epochs['LTR-l', 'RTL-l', 'TTB-l', 'BTT-l'].average(),
    epochs['LTR-s', 'RTL-s', 'TTB-s', 'BTT-s'].average(),
]

In [None]:
# Label the averages in evokeds_list order (long first, then short):
conds = ('long', 'short')
evks = {cond: evoked for cond, evoked in zip(conds, evokeds_list)}

In [None]:
# Per-epoch evokeds per condition. '-l' markers are long and '-s' markers are
# short, consistent with ``conds`` / ``evks``:
evokeds2 = dict(long=list(epochs['LTR-l', 'RTL-l', 'TTB-l', 'BTT-l'].iter_evoked()),
                short=list(epochs['LTR-s', 'RTL-s', 'TTB-s', 'BTT-s'].iter_evoked()))
mne.viz.plot_compare_evokeds(evokeds2, combine='mean', picks=['Cz', 'C1', 'C2', 'FCz', 'CPz'], show_sensors='upper right')
# plt.savefig('distance_grand_averages.pdf')

In [None]:
# Restrict the subject loops to a single subject.
# NOTE(review): `paths` is not reassigned here, so zip(subjects, paths) pairs
# 'A05' with paths[0] -- confirm that paths[0] is A05's folder.
subjects = ['A05']

In [None]:
# Compare long vs. short evoked responses at channel FCz (MNE channel names are
# case-sensitive; 'Fcz' would not match the 'FCz' channel).
mne.viz.plot_compare_evokeds(evks, picks='FCz')

In [None]:
# Scratch: long-distance epochs of the LTR and RTL markers only.
temp = epochs['LTR-l', 'RTL-l']

In [None]:
def custom_func(x):
    """Channel combiner for ``plot_compare_evokeds``: peak over channels.

    :param x: array whose axis 1 is the channel axis.
    :return: ``x`` reduced by taking the maximum along axis 1.
    """
    channel_axis = 1
    per_time_peak = x.max(axis=channel_axis)
    return per_time_peak


# Plot the long/short comparison over all EEG channels once per combination
# method: MNE's 'mean', 'median' and 'gfp', plus the custom per-time maximum.
for combine in ('mean', 'median', 'gfp', custom_func):
    mne.viz.plot_compare_evokeds(evks, picks='eeg', combine=combine)

In [None]:
# Inspect the integer event code assigned to the 'BTT-s' annotation.
event_dict['BTT-s']

In [None]:
# Browse the epochs of the 'BTT-l' condition.
epochs['BTT-l'].plot()

In [None]:
# Read xdf:
# Read the raw streams of one recording (hard-coded local path: subject A02,
# run 002) for ad-hoc inspection; `streams` holds all streams of the file.
streams, header = pyxdf.load_xdf('C:/Users/peter/Google Drive/measurements/eeg/A02/paradigm/sub-A02_ses-S001_task-Paradigm[_acq-]_run-002_eeg.xdf')

### Helper cell to add bad epochs to a dataframe

In [None]:
# TODO: cell to view the epochs for a specific subject and marker:
marker_of_interest = 'LTR-s' # ['LTR-s', 'LTR-l','RTL-s', 'RTL-l', 'TTB-s', 'TTB-l', 'BTT-s', 'BTT-l']
subject = 'A03'
# NOTE(review): the epoching loop saves to path + '/' + subject + '_' + paradigm + '_epo.fif';
# here both folder and file name hard-code 'paradigm' -- confirm this matches
# data_path/paths and the value of `paradigm`.
file = data_path + subject + '/paradigm/' + subject + '_paradigm_epo.fif'

# Load epochs:
epochs = mne.read_epochs(file, preload=True)

# Browse only the epochs of the selected marker.
epochs[marker_of_interest].plot()


In [None]:
# Scratch: first epoch of the selected marker.
temp = epochs[marker_of_interest][0]

In [None]:
# Translate every numeric event code back into its annotation description.
# Build the inverse lookup once instead of calling list.index() per event.
code_to_name = {code: name for name, code in event_dict.items()}
replace_list = [code_to_name[code] for code in events_from_annot[:, 2]]

# events_from_annot is an integer array, so writing strings into it raises
# ValueError. Store the descriptions in an object-dtype copy instead and keep
# the MNE events array itself untouched.
events_labeled = events_from_annot.astype(object)
events_labeled[:, 2] = replace_list

In [None]:
# Loop variant of the code -> description lookup. The descriptions go into an
# object-dtype copy because the MNE events array is integer-typed.
event_names = list(event_dict.keys())
event_codes = list(event_dict.values())
events_labeled = events_from_annot.astype(object)
for i, code in enumerate(events_from_annot[:, 2]):
    # Position of this row's event code in the mapping gives its description.
    events_labeled[i, 2] = event_names[event_codes.index(code)]

### Get metrics of rejected channels per subject

In [None]:
# Count the channels marked bad per subject and summarise them.
start = time.time()

num_bads = []
for subject, path in zip(subjects, paths):
    print(f'Reading all fif files for subject {subject}', end=' ')
    # Only raw files can be opened with read_raw, so skip *_epo.fif files; sort
    # so the chosen file does not depend on the arbitrary listdir order.
    file_names = sorted(f for f in listdir(path) if '.fif' in f and '_epo.fif' not in f)
    if not file_names:
        print(f'-> no raw .fif file in {path}, skipping')
        continue

    # Load one .fif file; only its header (info) is needed, so skip preloading.
    file_name = file_names[0]
    file = path + '/' + file_name
    raw = mne.io.read_raw(file, preload=False)

    bads = raw.info['bads']
    num_bads.append(len(bads))
    print()

num_bads = np.asarray(num_bads)
if num_bads.size:
    print(f'Rejected on average {round(num_bads.mean(), 2)} +/- {round(num_bads.std(), 2)} channels')
else:
    print('No raw files found, no rejected-channel metrics computed')

print(f'Finished calculating rejected channel metrics, took me {round(time.time() - start)} seconds...')

In [None]:
# IPython magic: clear the interactive namespace. Every name used by the cells
# below (imports, helpers, `path`, ...) must be re-created before running them.
%reset

In [None]:
# List the subject's XDF recordings in file-name order (run-001, run-002, ...).
# The order matters because the runs are later concatenated in list order, and
# os.listdir returns files in arbitrary order. Other files in the folder (e.g.
# saved .fif files) are skipped because pyxdf cannot read them.
files = sorted(f for f in listdir(path) if f.lower().endswith('.xdf'))

eeg_streams = []
marker_streams = []
# Debug restriction: load only the first recording of the subject
# (slicing keeps this safe when the folder contains no recordings).
files = files[:1]
for file in files:
    file_name = path + '/' + file
    print('####', end='#')

    # Read streams
    streams, header = pyxdf.load_xdf(file_name)

    # Split the streams into the EEG stream and the marker stream:
    eeg_stream, marker_stream = split_streams(streams)

    eeg_streams.append(eeg_stream)
    marker_streams.append(marker_stream)


print()
print(f'Finished reading, found {len(eeg_streams)} EEG streams and {len(marker_streams)} marker streams...')

In [None]:
# Concatenate the EEG of all loaded runs along the time axis and record the
# timing gap between consecutive runs.
differences = [0]  # gap (s) from the end of run i-1 to the start of run i; 0 for the first run
max_eeg_ts = []
eeg_chunks = []
# NOTE: keep the loop variable named `eeg_stream`: extract_eeg() (defined at the
# top of the notebook) reads the timestamps from the global `eeg_stream`
# instead of from its `stream` argument.
for i, (eeg_stream, m_stream) in enumerate(zip(eeg_streams, marker_streams)):
    # Get the eeg data; extract_eeg already drops the unused 65th marker row.
    eeg, eeg_ts = extract_eeg(eeg_stream)
    max_eeg_ts.append(eeg_ts.max())

    # Extract all infos from the EEG stream:
    fs, ch_names, ch_labels, eff_fs = extract_eeg_infos(eeg_stream)

    if i == 0:
        first_ts = eeg_ts[0]
    else:
        differences.append(eeg_ts[0] - last_ts)

    eeg_chunks.append(eeg)
    last_ts = eeg_ts[-1]
    print('####', end='#')

if not eeg_chunks:
    raise ValueError('No EEG streams loaded; run the XDF loading cell first.')

# A single concatenation instead of re-copying the growing array on every run.
global_eeg = np.concatenate(eeg_chunks, axis=1)
cum_diff = np.cumsum(differences)  # total gap removed before each run
eeg = global_eeg
print()
print('Extracted EEG data, EEG infos...')

In [None]:
# Build MNE annotations (zero-duration, one per marker) from the marker streams.
# annotation generation from:
# https://github.com/WriessneggerLab/EEG-preprocessing/blob/eeg/src/EEGAnalysis.py
triggers = {'onsets': [], 'duration': [], 'description': []}
global_markers_ts = []
for i, m_stream in enumerate(marker_streams):
    # Extract the markers and timestamps:
    markers = m_stream['time_series']
    # With orig_time=None the onsets are interpreted relative to the first data
    # sample, so express the marker times relative to the first EEG timestamp
    # (first_ts) and remove the inter-run gaps dropped by the concatenation
    # (cum_diff[i]). The marker stream's own 'created_at' is not the start of
    # the EEG data and would shift every onset.
    markers_ts = m_stream['time_stamps'] - first_ts - cum_diff[i]

    global_markers_ts += list(markers_ts)
    # read every trigger in the stream
    for onset, marker_data in zip(markers_ts, markers):
        triggers['onsets'].append(onset)
        triggers['duration'].append(0)
        triggers['description'].append(marker_data[0])

# define MNE annotations
annotations = mne.Annotations(triggers['onsets'], triggers['duration'], triggers['description'], orig_time=None)

In [None]:
# Scratch: marker timestamps of the last stream as a plain list, plus a copy.
mrks_list = [*markers_ts]
a = [*mrks_list]

### Put extracted data into mne structure

In [None]:
# TODO: align annotations

# Wrap the concatenated EEG into an MNE Raw object with electrode positions and
# the marker annotations.
info = mne.create_info(ch_names, fs, ch_labels)

# first_samp is a sample index, not a time: convert the first timestamp
# (seconds) into samples so that raw.first_time corresponds to first_ts.
raw = mne.io.RawArray(eeg, info, first_samp=int(round(first_ts * fs)))
raw.set_montage('standard_1005')
raw.set_annotations(annotations)

if plot:
    raw.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
             remove_dc=False, title='Raw')

### Filter: high-pass at 0.4 Hz and band-stop (notch) at 50 Hz

In [None]:
# High-pass the EEG channels at 0.4 Hz (IIR) to suppress slow drifts.
raw_highpass = raw.copy().filter(l_freq=0.4, h_freq=None, picks=['eeg'], method='iir')
if plot:
    raw_highpass.plot(duration=60, proj=False, n_channels=len(raw.ch_names),
                      remove_dc=False, title='Highpass filtered')
    plot_spectrum(raw_highpass)

# Notch filter at 50 Hz (presumably power-line interference).
raw_notch = raw_highpass.copy().notch_filter(freqs=[50], picks=['eeg'])
if plot:
    raw_notch.plot(duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False, title='Notch filtered')
    plot_spectrum(raw_notch)

### Interpolate bad channels:

In [None]:
# TODO: check function --> need to mark them first
# Interpolate the channels listed in info['bads'] (uses the montage set on raw);
# reset_bads=False keeps them flagged as bad afterwards.
# NOTE(review): nothing in this section marks bad channels before this call,
# so it is presumably a no-op until they are marked -- see the TODO.
raw_interp = raw_notch.copy().interpolate_bads(reset_bads=False)

### Correct eye artifacts:

In [None]:
# TODO

### CAR:

In [None]:
# Re-reference to the common average (CAR).
# NOTE(review): channels still flagged in info['bads'] (reset_bads=False during
# interpolation) may be excluded from the average -- confirm this is intended.
raw_avg_ref = raw_interp.copy().set_eeg_reference(ref_channels='average')
if plot:
    raw_avg_ref.plot(duration=60, proj=False, n_channels=len(raw.ch_names), remove_dc=False, title='CAR Referenced')

### HEAR model:

In [None]:
# TODO?

### LP at 3.0Hz

In [None]:
# Low-pass the EEG channels at 3 Hz (IIR) to keep only the low-frequency content.
raw_lp = raw_avg_ref.copy().filter(l_freq=None, h_freq=3.0, picks=['eeg'], method='iir')

### Extract epochs before resampling (otherwise markers may get lost) and reject bad trials:

In [None]:
# The Raw object has no 'Markers' stim channel (only the 64 channels returned by
# extract_eeg); the markers are stored in raw.annotations, so build the events
# from the annotations. classes_map is assumed to map marker descriptions to
# integer event codes -- TODO confirm where it is defined.
events, _ = mne.events_from_annotations(raw_lp, event_id=classes_map)

# Epoch 1 s to 6 s after each marker without baseline correction, rejecting
# epochs whose peak-to-peak EEG amplitude exceeds 100 uV.
epochs = mne.Epochs(raw_lp, events, event_id=classes_map, tmin=1, tmax=6, preload=True, baseline=None, reject=dict(eeg=100e-6)) #, baseline=(1,2))

print(epochs)

if plot:
    epochs.plot(n_epochs=2)

### Resample to 10 Hz:

In [None]:
# Down-sample a copy of the epochs to 10 Hz; the preceding 3 Hz low-pass keeps
# the signal below the new 5 Hz Nyquist frequency.
epochs_resampled = epochs.copy().resample(10)
print('Preprocessing finished.')

### Implementing cue-aligned analysis (better according to the Reinmar paper)

### Distance decoding:

In [None]:
events = mne.find_events(raw_lp, stim_channel='Markers')
event_dict = {'short': 1, 'long': 2, 'short': 1, 'long': 2, 'short':1, 'long':2, 'short':1, 'long':2}

epochs_long_short = mne.Epochs(raw_lp, events, event_id=event_dict, tmin=1, tmax=6, preload=True, baseline=None, reject=dict(eeg=100e-6))



short = epochs_long_short['short'].average()

long = epochs_long_short['long'].average()

#evokeds = dict(short=short, long=long)
#mne.viz.plot_compare_evokeds(evokeds, picks='POz')

evokeds2 = dict(short=list(epochs_long_short['short'].iter_evoked()),
                long=list(epochs_long_short['long'].iter_evoked()))
mne.viz.plot_compare_evokeds(evokeds2, combine='mean', picks=['Cz', 'C1', 'C2', 'FCz', 'CPz'], show_sensors='upper right')
plt.savefig('distance_grand_averages.pdf')

#['Pz', 'POz', 'PO3', 'PO4', 'P2', 'P1', 'P2', 'Oz', 'O1', 'O2']

epochs_long_short = epochs_long_short.copy().resample(10)

In [None]:
# Single-sample decoding of movement distance: at every time sample, fit a
# shrinkage LDA on the channel vector and record its cross-validated accuracy.
n_folds = 100  # folds passed to cross_val_score (also reported in the plot title)

# Stack the epoch data and drop the EOG rows by their hard-coded indices.
# NOTE(review): rows 16, 21 and 40 are assumed to be the EOG channels -- confirm
# against ch_names/ch_labels.
X = np.array([epoch for epoch in epochs_long_short])
X = np.delete(X, [40, 21, 16], axis=1)[:, :61, :]

# Event code of every kept epoch, read directly from the events array instead
# of copying a one-epoch Epochs object per trial.
y = np.array(epochs_long_short.events[:, 2])

print(y)

# Binary target: even codes -> 0, odd codes -> 1.
y = np.array([0 if label % 2 == 0 else 1 for label in y])

print(y)


# Split training and test set:

clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
acc = []
cv_scores = []
n_times = X.shape[2]
for idx in range(n_times):
    x = X[:, :, idx]
    # Single hold-out split (kept in `acc`; not plotted).
    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
    clf.fit(X_train, y_train)
    acc.append(clf.score(X_test, y_test))

    scores = cross_val_score(clf, x, y, cv=n_folds)
    cv_scores.append(scores.mean())

    if idx % 10 == 0:
        print(idx)

print('Done')

# NOTE(review): t starts at 0 s although the epochs start at tmin=1 s after the
# marker -- confirm whether the axis and the reference line at 2 s should use
# epochs_long_short.times instead.
t = np.arange(len(acc)) / 10
#plt.plot(t, acc)

plt.plot(t, cv_scores)

# Moving average over `window` samples, drawn at the start of each window.
window = 7

ma = np.convolve(cv_scores, np.ones(window), 'valid') / window

plt.plot(t[:len(ma)], ma)
plt.plot([2,2], [min(cv_scores), max(cv_scores)])
plt.title(f'Single sample approach, {n_folds}-fold CV')
plt.savefig('distance_acc_single.pdf')

In [None]:
# 5 point LDA
# Windowed decoding: concatenate `win_len` consecutive samples of all channels
# into one feature vector per epoch and cross-validate a shrinkage LDA.
win_len = 5     # samples per window
sfreq = 10      # sampling rate (Hz) after the resampling to 10 Hz
n_folds = 100   # folds passed to cross_val_score (also reported in the plot title)

# Stack the epoch data and drop the EOG rows by their hard-coded indices.
# NOTE(review): rows 16, 21 and 40 are assumed to be the EOG channels -- confirm
# against ch_names/ch_labels.
X = np.array([epoch for epoch in epochs_long_short])
X = np.delete(X, [40, 21, 16], axis=1)[:, :61, :]

# Binary target from each kept epoch's event code: even -> 0, odd -> 1.
y = np.array([0 if label % 2 == 0 else 1 for label in epochs_long_short.events[:, 2]])


# Split training and test set:

clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
acc = []
cv_scores = []
n_times = X.shape[2]
# +1 so the last full window (ending at the final sample) is included.
for idx in range(n_times - win_len + 1):
    x = X[:, :, idx:idx + win_len]
    if idx % 10 == 0:
        print(idx)
        print(x.shape)
    # Reshape X to 2d array (epochs x (channels * window samples)):
    nsamples, nx, ny = x.shape
    x = x.reshape((nsamples, nx * ny))
    # Single hold-out split (kept in `acc`; not plotted).
    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
    clf.fit(X_train, y_train)
    acc.append(clf.score(X_test, y_test))

    scores = cross_val_score(clf, x, y, cv=n_folds)
    cv_scores.append(scores.mean())

print('Done')

# Each window is placed at its end: start time plus window length.
# NOTE(review): t ignores the epochs' tmin of 1 s -- confirm the time axis.
t = np.arange(len(acc)) / sfreq + win_len / sfreq
#plt.plot(t, acc)

plt.plot(t, cv_scores)

# Moving average over `window` windows, drawn at the end of each window.
window = 7

ma = np.convolve(cv_scores, np.ones(window), 'valid') / window

plt.plot(t[window-1:], ma)
plt.plot([2,2], [min(cv_scores), max(cv_scores)])
plt.xlabel('Time (s)')
plt.ylabel('Accuracy (a.u.)')
plt.title(f'Windowed approach accuracies, distance {n_folds}-fold CV')
plt.savefig('distance_acc_5point.pdf')

In [None]:
# IPython magic: clear the interactive namespace once the analysis is done.
%reset