# Load and concatenate

In this script we load the data from the original BrainVision format into MNE-Python, and then save it as a `.fif` file.

For a given subject, load all available datafiles (corresponding to tasks),
and do the following:
- set EEG reference to None (=keep original reference)
- set EEG, ECG, EOG channel types
- add a standard 10-20 montage so that the channels have coordinates for plotting
- extract events and scale the event markers for each task (**A**ctive or **Y**oked, **F**ixed or **V**ariable, or description)
    - AF          += 100
    - AV          += 200
    - YF          += 300
    - YV          += 400
    - description += 500
- load bad segments and channels and add them to the data
- concatenate the raw data across tasks
- save the concatenated raw data and corresponding events
  (including their updated event markers)

In [None]:
%matplotlib qt

In [1]:
import itertools
import multiprocessing
import os
import os.path as op

import mne

from utils import (
    BIDS_ROOT,
    TASK_BASED_INCREMENTS,
    provide_trigger_dict,
    task_not_present_for_subject,
)

In [None]:
# IO: every template takes the subject number as field 0 (zero-padded to two
# digits); the per-task templates additionally take the task name as field 1.
_derivatives_dir = op.join(BIDS_ROOT, "derivatives", "sub-{0:02}")

# Input: the BrainVision header file of each recording
eeg_path_template = op.join(
    BIDS_ROOT, "sub-{0:02}", "eeg", "sub-{0:02}_task-{1}_eeg.vhdr"
)

# Input: annotations (bad segments) and bad channels, one file per task
fname_annots_template = op.join(
    _derivatives_dir, "sub-{0:02}_task-{1}_annotations.txt"
)
fname_channels_template = op.join(
    _derivatives_dir, "sub-{0:02}_task-{1}_badchannels.txt"
)

# Output: the concatenated raw data and the matching events, one file per subject
fname_rawconcat_template = op.join(_derivatives_dir, "sub-{0:02}_concat_eeg-raw.fif.gz")
fname_events_template = op.join(_derivatives_dir, "sub-{0:02}_concat_eeg-eve.fif.gz")

# Bundle all templates so they can be handed to the worker as one argument
name_templates = {
    "eeg": eeg_path_template,
    "annots": fname_annots_template,
    "channels": fname_channels_template,
    "rawconcat": fname_rawconcat_template,
    "events": fname_events_template,
}

In [None]:
# Subject numbers 1 through 40 (the upper bound of range is exclusive)
subjects = range(1, 41)

# Names of the recordings that may exist for each subject
tasks = [
    "ActiveFixed",
    "ActiveVariable",
    "YokedFixed",
    "YokedVariable",
    "description",
]

# Number of subjects processed in parallel: leave six cores free,
# but never use fewer than two worker processes
NJOBS = max(multiprocessing.cpu_count() - 6, 2)

# If True, existing output files are replaced; if False, they are kept
overwrite = True

In [None]:
def _read_bad_channels(fname):
    """Return the channel names listed one per line in the file ``fname``.

    Surrounding whitespace is stripped and blank lines are skipped, so that
    an empty or trailing line never yields an empty channel name.
    """
    with open(fname, "r") as fin:
        stripped = (line.strip() for line in fin)
        return [name for name in stripped if name]


def load_and_concatenate(subj, tasks, name_templates, overwrite):
    """Load BrainVision raw data and save as fif.

    For a given subject, load all available datafiles (corresponding to tasks),
    and do the following:
    - set EEG reference to None (=keep original reference)
    - set EEG, ECG, EOG channel types
    - add a standard 10-20 montage to have coordinates for plotting
    - extract events and scale the event markers for each task (**A**ctive or
      **Y**oked, **F**ixed or **V**ariable, or description)
        - AF          += 100
        - AV          += 200
        - YF          += 300
        - YV          += 400
        - description += 500
    - load bad segments and channels and add them to the data
    - concatenate the raw data across tasks
    - save the concatenated raw data and corresponding events
      (including their updated event markers)

    Existing output files are only replaced once the new data has been
    assembled, so a failure while loading leaves previous outputs intact.

    Parameters
    ----------
    subj : int
        The subject identifier in the range(1, 41).
    tasks : list of str
        The task names.
    name_templates : dict
        A dictionary of string templates. Needs the following keys:
        "eeg", "annots", "channels", "rawconcat", "events"
    overwrite : bool
        Whether to overwrite existing files.

    Returns
    -------
    None
        Results are written to the files named by the "rawconcat" and
        "events" templates. Nothing is written if both outputs already exist
        and `overwrite` is False, or if no task is available for the subject.

    """
    import warnings

    # Unpack places to load and save data from
    eeg_path_template = name_templates["eeg"]
    fname_annots_template = name_templates["annots"]
    fname_channels_template = name_templates["channels"]
    fname_rawconcat = name_templates["rawconcat"].format(subj)
    fname_events = name_templates["events"].format(subj)

    # An existing output is kept (skipped) unless overwriting was requested
    skip_events = op.exists(fname_events) and not overwrite
    skip_raw = op.exists(fname_rawconcat) and not overwrite

    # If we have all needed files already, return early
    if skip_events and skip_raw:
        return

    # Without any recording there is nothing to concatenate: bail out here
    # instead of failing later on an empty list of raw objects
    present_tasks = [
        task for task in tasks if not task_not_present_for_subject(subj, task)
    ]
    if not present_tasks:
        warnings.warn(
            "No task data available for subject {}; "
            "nothing to concatenate.".format(subj)
        )
        return

    # Get the markers that were used in the experiment
    marker_ids = [ord(i) for i in provide_trigger_dict().values()]

    # The standard montage (for plotting later) is the same for every task
    montage = mne.channels.make_standard_montage("standard_1020")

    all_events = []
    files = []
    bads = set()
    for task in present_tasks:
        # Get the data
        raw = mne.io.read_raw_brainvision(
            eeg_path_template.format(subj, task), preload=True
        )

        # Suppress an automatic "average reference"
        raw.set_eeg_reference(ref_channels=[])

        # Set the EOG and ECG channels to their type
        raw.set_channel_types({"ECG": "ecg", "HEOG": "eog", "VEOG": "eog"})

        raw.set_montage(montage)

        # Extract events, incrementing event markers according to a
        # task dependent mapping
        event_id = {
            "Stimulus/S{: >3}".format(marker): marker + TASK_BASED_INCREMENTS[task]
            for marker in marker_ids
        }
        events, _ = mne.events_from_annotations(raw, event_id)

        # Add bad segments to the data.
        # annotations have orig_date "anonymized", so it will be set to
        # None upon reading.
        # For concatenating our annotations with raw.annotations however,
        # they need to have the same orig_time property.
        # Taking the raw.annotations.orig_time results in the expected
        # alignment of raw and annotations
        annots = mne.read_annotations(fname_annots_template.format(subj, task))
        annots = mne.Annotations(
            annots.onset,
            annots.duration,
            annots.description,
            raw.annotations.orig_time,
        )
        raw.set_annotations(raw.annotations + annots)

        # Collect bads across all tasks
        # Later set concatenated raw.info["bads"]
        # See also: raw.load_bad_channels
        bads.update(_read_bad_channels(fname_channels_template.format(subj, task)))

        # append task based datafiles to list for concatenation
        all_events.append(events)
        files.append(raw)

    # Set the union of bad channels across tasks to each task
    # (sorted, so that the stored order is reproducible between runs)
    union_bads = sorted(bads)
    for raw in files:
        raw.info["bads"] = list(union_bads)

    # Concatenate datafiles and events
    raw, events = mne.concatenate_raws(files, events_list=all_events)

    # Save; a stale file is removed only now that its replacement is ready
    if not skip_events:
        if op.exists(fname_events):
            os.remove(fname_events)
        mne.write_events(fname_events, events)

    if not skip_raw:
        if op.exists(fname_rawconcat):
            os.remove(fname_rawconcat)
        raw.save(fname_rawconcat)

In [None]:
# Process the subjects in parallel: one argument tuple per subject, all
# sharing the same task list, file name templates and overwrite setting
pool_inputs = [(subj, tasks, name_templates, overwrite) for subj in subjects]

with multiprocessing.Pool(processes=NJOBS) as pool:
    pool.starmap(load_and_concatenate, pool_inputs)