# Replay in Aversive Environments - MEG Preprocessing

#### _This is a template that will be parameterised and run via [Papermill](http://papermill.readthedocs.io/) for each subject_

This notebook performs preprocessing of localiser and task data. 

Preprocessing steps:

1. Identification and loading of raw data
2. Maxwell filtering
3. Filtering
4. ICA
5. Event detection and epoching
6. Downsampling


## Imports

In [1]:
from mne.io import read_raw_ctf
import mne
import matplotlib.pyplot as plt
from mne.preprocessing import ICA, create_eog_epochs, create_ecg_epochs
import numpy as np
import plotly
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import os
import time
import datetime
import yaml
import papermill as pm
import pandas as pd
np.random.seed(100)
%matplotlib inline



## Parameters

In [None]:
data_dir = 'data/'  # Directory containing data
session_id = 'MG05513'  # ID of the scanning session
n_runs = 9  # Number of runs
output_dir = 'data/derivatives'  # Where the output data should go
eye_tracking = True  # If True, eye-tracking measures will be used for exclusion of blink-related ICA components
n_stim = 8  # Number of stimuli
cores = 1  # Number of cores to use for parallel processing
blink_components = None

os.environ['OMP_NUM_THREADS'] = str(cores)
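
As a rough illustration of how this template might be run for each subject, the sketch below builds a parameter dictionary using the names from the cell above and hands it to `papermill.execute_notebook`. The helper names (`build_parameters`, `run_subject`), the template file name and the output directory are hypothetical; only the parameter names and default values are taken from this notebook.

```python
import os


def build_parameters(session_id, blink_components=None):
    """Return the parameters for one session, using the names from the Parameters cell."""
    return {
        'data_dir': 'data/',
        'session_id': session_id,
        'n_runs': 9,
        'output_dir': 'data/derivatives',
        'eye_tracking': True,
        'n_stim': 8,
        'cores': 1,
        # None lets the ICA cell detect blink components automatically
        'blink_components': blink_components,
    }


def run_subject(session_id, template='preprocessing_template.ipynb', notebook_dir='notebooks'):
    """Execute the template for one session (hypothetical paths; needs papermill installed)."""
    import papermill as pm  # imported lazily so build_parameters works without papermill

    output_path = os.path.join(notebook_dir, 'preprocessing_sub-{0}.ipynb'.format(session_id))
    pm.execute_notebook(template, output_path, parameters=build_parameters(session_id))
    return output_path


params = build_parameters('MG05513')
print(params['session_id'], params['n_runs'])
```

Papermill injects these values after the cell tagged `parameters`, so the defaults in the cell above are overridden for each run.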

## Get data

Data is stored in [BIDS format](https://www.nature.com/articles/sdata2018110). MNE doesn't currently read directly from BIDS, however, so we build the paths to each run by hand.
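
The path construction and run ordering used in the next cell can be sketched without any MEG data. The directory listing below is made up for illustration and the helper names are hypothetical; the `sub-<id>/ses-01/meg` layout is the one this notebook uses. Unlike the cell below, this sketch matches on a `.ds` suffix rather than a `.ds` substring.

```python
import os


def meg_dir(data_dir, session_id, session='ses-01'):
    """Build the per-subject MEG directory following the sub-/ses-/meg layout."""
    return os.path.join(data_dir, 'sub-{0}'.format(session_id), session, 'meg')


def find_runs(filenames, session_id):
    """Keep CTF .ds entries that belong to this session, in ascending order."""
    return sorted(f for f in filenames if f.endswith('.ds') and str(session_id) in f)


# Hypothetical directory listing; real run names may differ
listing = ['MG05513_Replay_02.ds', 'notes.txt', 'MG05513_Replay_01.ds', 'MG09999_Replay_01.ds']
runs = find_runs(listing, 'MG05513')
print(meg_dir('data/', 'MG05513'))
print(runs)
```

Note that `sorted` orders names lexicographically, so run numbers only sort correctly if they are zero-padded or there are fewer than ten runs.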

In [None]:
# Set the data directory for this subject
data_dir = os.path.join(data_dir, 'sub-{0}'.format(session_id), 'ses-01', 'meg')

# Find all files in the directory and make sure they're in the right order (i.e. ascending)
data = os.listdir(data_dir)
data = sorted([i for i in data if '.ds' in i and str(session_id) in i])

# Check we have the right number of runs
assert len(data) == n_runs, "Wrong number of data files, found {0}".format(len(data))

# See what has been found
print(data)

# Get all the data and read it in
raws = []
run_idx = range(0, n_runs)

# Read in each data set
for i in run_idx:
    start_time = time.time()
    raws.append(read_raw_ctf(os.path.join(data_dir, data[i]), preload=True))
    time_taken = time.time() - start_time
    print("Time taken = {0}".format(str(datetime.timedelta(seconds=time_taken))))

# Concatenate the runs
raw = mne.concatenate_raws(raws)
del raws  # delete the list of raw data to conserve memory

# Label eye-tracking channels as EOG
raw.set_channel_types({'UADC001-2910': 'eog', 'UADC002-2901': 'eog', 'UADC003-2901': 'eog'})

## Preprocessing

### Apply Maxwell filter
We apply a Maxwell filter partly because some MNE functions don't seem to work properly when CTF compensation is turned on. Maxwell filtering also appears to do a better job than CTF compensation at removing noise from movement etc. (see https://martinos.org/mne/stable/auto_tutorials/plot_brainstorm_phantom_ctf.html).

In [None]:
raw.apply_gradient_compensation(0)  # Remove CTF compensation
mf_kwargs = dict(origin=(0., 0., 0.), st_duration=10.)
raw = mne.preprocessing.maxwell_filter(raw, **mf_kwargs)

### Plot the raw data

In [None]:
raw.plot()

### Filter

Bandpass filter between 0.5 Hz and 50 Hz, using a windowed FIR filter with MNE default settings.
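
To make "MNE default settings" a little more concrete, the sketch below reproduces, as we understand them from the MNE documentation, the `'auto'` rules for the transition bandwidths and the filter length of a Hamming-windowed `firwin` design. It uses the band edges passed to `raw.filter` in this notebook (0.5 and 50 Hz). The 1200 Hz sampling rate is an assumption made for illustration, and `auto_fir_settings` is a hypothetical helper, not an MNE function.

```python
def auto_fir_settings(l_freq, h_freq, sfreq):
    """Approximate MNE's 'auto' FIR settings for fir_design='firwin' (Hamming window)."""
    # Transition bandwidths: a quarter of the band edge, at least 2 Hz, capped by the edge itself
    l_trans = min(max(l_freq * 0.25, 2.0), l_freq)
    h_trans = min(max(h_freq * 0.25, 2.0), sfreq / 2.0 - h_freq)

    # Filter duration is 3.3 / (narrowest transition band) seconds, rounded to an odd length
    n_taps = int(round(3.3 * sfreq / min(l_trans, h_trans)))
    if n_taps % 2 == 0:
        n_taps += 1

    return {
        'l_trans_bandwidth': l_trans,
        'h_trans_bandwidth': h_trans,
        'lower_cutoff_6db': l_freq - l_trans / 2.0,
        'upper_cutoff_6db': h_freq + h_trans / 2.0,
        'n_taps': n_taps,
        'duration_s': n_taps / float(sfreq),
    }


# Assumed sampling rate of 1200 Hz (not stated in this notebook)
settings = auto_fir_settings(l_freq=0.5, h_freq=50.0, sfreq=1200.0)
for key, value in settings.items():
    print(key, value)
```

The narrow 0.5 Hz transition band at the low edge is what drives the filter length; the values MNE actually uses are printed in its log when `raw.filter` runs.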

In [None]:
print("FILTERING")
raw.filter(0.5, 50, method='fir', fir_design='firwin')

## ICA

ICA is performed on the raw data and is set to find the number of components that explains 95% of the variance. 
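
Passing a float such as `n_components=0.95` asks for the smallest number of components whose cumulative explained variance exceeds that fraction. A minimal sketch of that selection rule, using made-up component variances and a hypothetical helper name, might look like this:

```python
def n_components_for_variance(variances, threshold=0.95):
    """Smallest number of leading components whose cumulative variance share exceeds threshold."""
    total = float(sum(variances))
    cumulative = 0.0
    for n, variance in enumerate(sorted(variances, reverse=True), start=1):
        cumulative += variance
        if cumulative / total > threshold:
            return n
    return len(variances)


# Made-up variances for illustration (they sum to 100)
pca_variances = [50, 25, 12, 6, 4, 3]
n_keep = n_components_for_variance(pca_variances)
print(n_keep)
```

With these values the first four components explain 93% of the variance and the first five explain 97%, so five components are kept.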

We don't do much in terms of selecting noise-related components here. The data is generally pretty clean and doesn't seem to benefit much from extra denoising, so we simply detect blink-related components automatically based on the eye-tracking channels, along with cardiac components found by `find_bads_ecg`.

For some subjects eye tracking was poor due to equipment problems. If we're not able to identify a blink-related component automatically, this may be due to an absent eye-tracking channel or simply to noisy eye-tracking data. In this case the notebook raises an error and we can rerun it with manually selected components to remove (via the `blink_components` parameter).
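
The intended selection logic can be summarised in plain Python: keep at most two blink components and three ECG components, and stop with an error if no blink component is available. This is a sketch of the behaviour described above rather than the cell below verbatim, and `components_to_exclude` is a hypothetical helper.

```python
def components_to_exclude(blink_components, ecg_components, max_blink=2, max_ecg=3):
    """Combine blink and ECG component indices, capping how many of each are removed."""
    if not isinstance(blink_components, list) or len(blink_components) == 0:
        raise ValueError("No blink related components identified automatically - do this manually")
    return blink_components[:max_blink] + ecg_components[:max_ecg]


# Made-up component indices for illustration
exclude = components_to_exclude([0, 5, 7], [3, 9, 11, 12])
print(exclude)
```

When automatic detection fails, rerunning with something like `blink_components=[0, 5]` supplies the indices by hand.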

In [None]:
# Run ICA
picks_meg = mne.pick_types(raw.info, meg=True, ref_meg=False)
reject = dict(mag=5e-12, grad=4000e-13)
ica = ICA(n_components=0.95, method='fastica',
          random_state=100, max_iter=100).fit(raw, decim=10, picks=picks_meg, reject=reject)

# Plot components
ica.plot_components()

# Save decomposition
ica.save(os.path.join(output_dir, 'ICA', 'sub-{0}_ses-01_task-AversiveLearningReplay_proc-ICA.fif.gz').format(session_id))

# Find blink-related components
if blink_components is None or blink_components == 'None':
    blink_components, scores = ica.find_bads_eog(raw, threshold=1.5)
    ica.plot_scores(scores, exclude=blink_components, labels='blink')
    show_picks = np.abs(scores).argsort()[::-1][:5]
    ica.plot_components(blink_components, colorbar=True)
    
# Find ECG components
ecg_epochs = create_ecg_epochs(raw, tmin=-.5, tmax=.5, picks=picks_meg)

ecg_components, scores = ica.find_bads_ecg(ecg_epochs, method='ctps')
ica.plot_scores(scores, exclude=ecg_components, labels='ecg')
ecg_components = ecg_components[:3]

if isinstance(blink_components, list) and len(blink_components) > 0:  # an empty list means nothing was detected
    
    print("APPLYING ICA")

    # Only select a maximum of 2 components
    blink_components = blink_components[:2]

    ica.exclude = blink_components + ecg_components
    ica.apply(raw)

else:
    raise ValueError("No blink related components identified automatically - do this manually")

### Find events

This identifies triggers recorded on the `UPPT001` stimulus channel.

In [None]:
print("FINDING EVENTS")
events = mne.find_events(raw, stim_channel='UPPT001', shortest_event=1)

### Create epochs

Here we split the continuous data into epochs. Stimulus triggers are even numbers starting at 2 and increasing in steps of 2 (odd numbers are avoided because sending them also triggers shocks in the actual task). Code 99 is used for null trials, which only occur in the localiser. We don't reject any trials at this stage because we want to decide on rejections later.

For the main task epochs we only select the planning and rest periods, as these are the periods during which we'll be looking for replay.
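
The localiser event dictionary built in the next cell can be reproduced in isolation. The sketch below mirrors the notebook's `range(2, n_stim * 2, 2)` expression; the events listed are made up and `localiser_event_ids` is a hypothetical helper. Because `range` excludes its stop value, `n_stim = 8` yields seven stimulus codes (2 to 14); whether a code of 16 was ever sent is not something this notebook states.

```python
def localiser_event_ids(n_stim, null_code=99):
    """Map event names to trigger codes, mirroring the expression used in this notebook."""
    ids = {'stimulus_{0}'.format(code): code for code in range(2, n_stim * 2, 2)}
    ids['null'] = null_code
    return ids


event_ids = localiser_event_ids(8)
stimulus_codes = sorted(code for name, code in event_ids.items() if name != 'null')

# Made-up events as (sample, previous value, code) rows, the layout used by MNE event arrays
events = [(1000, 0, 2), (2200, 0, 99), (3400, 0, 60), (4600, 0, 14)]
localiser_events = [event for event in events if event[2] in event_ids.values()]

print(stimulus_codes)
print(localiser_events)
```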

In [None]:
print("EPOCHING")

# Get event names
# Localiser event codes are even numbers starting at 2 (range() excludes its stop value, so the highest code here is 2 * n_stim - 2)
localiser_event_names = dict([('stimulus_{0}'.format(i), i) for i in list(range(2, n_stim * 2, 2))] + [('null', 99)])

# We get task event numbers from the task config file
with open('task/replay_task_settings.yaml', 'rb') as f:
    task_config = yaml.safe_load(f)  # safe_load avoids the missing-Loader error raised by recent PyYAML versions
task_event_names = task_config['triggers']
task_event_names['final_state'] = 26

# Sometimes we record a rest event before trials begin for some reason - if this happens we need to remove this.
if np.where(events[:, 2] == task_event_names['rest'])[0][0] < np.where(np.isin(events[:, 2], [task_event_names['planning'], task_event_names['outcome_only_warning']]))[0][0]:
    events = np.delete(events, np.where(events[:, 2] == task_event_names['rest'])[0][0], axis=0)
    
# Occasionally planning triggers get coded as 62 instead of 60, so change any 62s to 60s. This might be just in cases with the incorrect trigger timing (M200-203)
events[:, 2][events[:, 2] == 62] = 60
events[:, 2][events[:, 2] == 98] = 99  # For one subject, null events were recorded as 98 rather than 99 for no apparent reason

# Create the epoch objects
localiser_epochs = mne.Epochs(raw, events[np.isin(events[:, 2], list(localiser_event_names.values()))], tmin=-0.1, tmax=0.8, preload=True, event_id=localiser_event_names,
                              reject=None)
task_epochs = mne.Epochs(raw, events[np.isin(events[:, 2], list(task_event_names.values()))], 
                         tmin=0, tmax=np.max([task_config['MEG_durations']['start_duration'], task_config['MEG_durations']['rest_duration']]), 
                         preload=True, event_id={k: task_event_names[k] for k in ('planning', 'rest', 'outcome_only_warning')}, reject=None)

# Outcome + rest epochs
outcome_rest_epochs = mne.Epochs(raw, events[np.isin(events[:, 2], list(task_event_names.values()))], 
                         tmin=-0.1, tmax=8, 
                         preload=True, event_id={k: task_event_names[k] for k in ('shock_outcome', 'no_shock_outcome')}, reject=None)

# Epochs for other task events
final_state_epochs = mne.Epochs(raw, events[np.isin(events[:, 2], list(task_event_names.values()))], 
                         tmin=0, tmax=task_config['MEG_durations']['shock_symbol_delay'], 
                         preload=True, event_id={k: task_event_names[k] for k in ('final_state', 'outcome_only_outcome')}, reject=None)
outcome_epochs = mne.Epochs(raw, events[np.isin(events[:, 2], list(task_event_names.values()))], 
                         tmin=-0.1, tmax=task_config['MEG_durations']['shock_delay'], 
                         preload=True, event_id={k: task_event_names[k] for k in ('shock_outcome', 'no_shock_outcome')}, reject=None)

In [None]:
print("LOCALISER EVENTS")

for k, v in localiser_event_names.items():
    print("Number of {0} events = {1}".format(k, np.sum(events[:, 2] == v)))
    
print("TASK EVENTS")

for k, v in task_event_names.items():
    print("Number of {0} events = {1}".format(k, np.sum(events[:, 2] == v)))

### Plot the events

In [None]:
data = []

for n, i in enumerate(np.unique(events[:, 2])):
    data.append(go.Scatter(x=events[events[:, 2] == i][:, 0], y=[n] * len(events[events[:, 2] == i][:, 0]), mode = 'markers'))
    

layout = dict(title='Events', showlegend=False,
              xaxis=dict(title='Samples'), yaxis=dict(title='Event ID', tickvals=np.arange(len(np.unique(events[:, 2]))), ticktext=[str(i) for i in np.unique(events[:, 2])]), 
              width=1500, height=600)

fig = dict(data=data, layout=layout)
iplot(fig, filename='events')


### Check the events look right

First, we should have 600 localiser events (or thereabouts - a couple of runs got cut short by a trial or two)

In [None]:
assert len(localiser_epochs) > 590, 'Unexpected number of localiser trials, found {0}, expected 600'.format(len(localiser_epochs))

We can load in the behavioural data to compare the number of recorded trials there to what we find in the MEG triggers.

In [None]:
subject_ids = pd.read_csv('subject_ids.csv', sep='\t')
behavioural_id = subject_ids.loc[subject_ids['meg'] == session_id]['behavioural'].values[0]
behavioural_data = pd.read_csv(os.path.join('task/Data/behavioural', [i for i in os.listdir('task/Data/behavioural') if behavioural_id in i][0]))
behavioural_data.head()

Here we iterate over trials in the behavioural data and check whether the trial number is `NaN`. If the experiment is paused for any reason (e.g. electrodes coming loose), we record a `NaN` trial and the experiment restarts from this trial when unpaused. Identifying these trials in the behavioural data lets us remove the triggers from any restarted trial by dropping the associated planning/rest epochs from the MEG data.
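
The mapping from paused trials to epoch indices can be sketched on its own. This is an alternative formulation rather than the loop in the next cell: it collects all indices against the original epoch ordering first and returns them in descending order, so that earlier removals cannot shift later ones. It assumes every behavioural row contributes exactly two epochs (planning, then rest), which the `trial * 2` indexing implies but which may not hold for outcome-only trials. The trial numbers are made up and `epochs_to_drop` is a hypothetical helper.

```python
import math


def epochs_to_drop(trial_numbers, epochs_per_trial=2):
    """Return epoch indices (descending) belonging to rows whose trial number is NaN."""
    paused = [i for i, t in enumerate(trial_numbers)
              if isinstance(t, float) and math.isnan(t)]
    indices = []
    for row in paused:
        # Row i owns epochs i * epochs_per_trial up to (i + 1) * epochs_per_trial - 1
        indices.extend(range(row * epochs_per_trial, (row + 1) * epochs_per_trial))
    return sorted(indices, reverse=True)


# Made-up behavioural trial numbers; the third row marks a paused trial
trial_numbers = [0.0, 1.0, float('nan'), 2.0, 3.0]
drop = epochs_to_drop(trial_numbers)
print(drop)
```

The resulting list could then be passed to `task_epochs.drop` in a single call.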

In [None]:
for trial in range(len(behavioural_data)):
    if np.isnan(behavioural_data['trial_number'][trial]):  # check whether this is nan (trial where the task was paused)
        print("Dropping trial {0}".format(trial))
        if task_epochs[trial * 2 + 1]._name == 'rest':
            task_epochs = task_epochs.drop(trial * 2 + 1)
        task_epochs = task_epochs.drop(trial * 2)

The number of planning and outcome-only trials should now add up to 100.

In [None]:
n_planning_oo_trials = len(task_epochs['planning']) + len(task_epochs['outcome_only_warning'])
assert n_planning_oo_trials == 100, 'Number of planning + outcome only trials is {0}, expected 100'.format(n_planning_oo_trials)

### Downsample

Resampling to 100 Hz.

In [None]:
del raw  # Delete the raw data variable to save memory

print("DOWNSAMPLING")
print('Original sampling rate: {0} Hz'.format(localiser_epochs.info['sfreq']))
localiser_epochs = localiser_epochs.copy().resample(100, npad='auto')
task_epochs = task_epochs.copy().resample(100, npad='auto')
final_state_epochs = final_state_epochs.copy().resample(100, npad='auto')
outcome_epochs = outcome_epochs.copy().resample(100, npad='auto')
outcome_rest_epochs = outcome_rest_epochs.copy().resample(100, npad='auto')
print('Downsampled sampling rate: {0} Hz'.format(localiser_epochs.info['sfreq']))

### Save the epoched data

In [None]:
localiser_epochs.save(os.path.join(output_dir, 'preprocessing/localiser', 'sub-{0}_ses-01_task-AversiveLearningReplay_run-localiser_proc_ICA-epo.fif.gz').format(session_id))
task_epochs.save(os.path.join(output_dir, 'preprocessing/task', 'sub-{0}_ses-01_task-AversiveLearningReplay_run-task_proc_ICA-epo.fif.gz').format(session_id))
final_state_epochs.save(os.path.join(output_dir, 'preprocessing/task', 'sub-{0}_ses-01_task-AversiveLearningReplay_run-task_final_state_proc_ICA-epo.fif.gz').format(session_id))
outcome_epochs.save(os.path.join(output_dir, 'preprocessing/task', 'sub-{0}_ses-01_task-AversiveLearningReplay_run-task_outcome_proc_ICA-epo.fif.gz').format(session_id))
outcome_rest_epochs.save(os.path.join(output_dir, 'preprocessing/task', 'sub-{0}_ses-01_task-AversiveLearningReplay_run-task_outcome_rest_proc_ICA-epo.fif.gz').format(session_id))
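
The five save calls above share one naming pattern, which can be captured in a small helper. This is only a sketch: `derivative_path` is a hypothetical function, and the run labels are read off the file names used in this notebook.

```python
import os


def derivative_path(output_dir, subdir, session_id, run_label):
    """Build an epochs file path following the naming pattern used in this notebook."""
    fname = 'sub-{0}_ses-01_task-AversiveLearningReplay_run-{1}_proc_ICA-epo.fif.gz'.format(
        session_id, run_label)
    return os.path.join(output_dir, subdir, fname)


# (sub-directory, run label) pairs taken from the save calls above
outputs = [
    ('preprocessing/localiser', 'localiser'),
    ('preprocessing/task', 'task'),
    ('preprocessing/task', 'task_final_state'),
    ('preprocessing/task', 'task_outcome'),
    ('preprocessing/task', 'task_outcome_rest'),
]
paths = [derivative_path('data/derivatives', subdir, 'MG05513', label) for subdir, label in outputs]
for path in paths:
    print(path)
```

Creating the target directories beforehand, for example with `os.makedirs(os.path.dirname(path), exist_ok=True)`, avoids a failed save if they do not exist yet.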