# Replay in Aversive Environments - Sequenceness analysis

#### _This is a template that will be parameterised and run via [Papermill](http://papermill.readthedocs.io/) for each subject_

This notebook uses the classifier trained on the localiser data to detect spontaneous state reactivation during the planning and rest phases of the task.

Analysis steps:

1. Loading task data and classifier
2. Applying the classifier to the task data to generate time x state matrices of reactivation probabilities
3. Running the sequenceness estimation procedure (cross-correlation or GLM, chosen with the `method` parameter) using a sliding window approach

## Imports

In [None]:
import os
import sys
sys.path.insert(0, 'code')
import mne
import matplotlib.pyplot as plt
import numpy as np
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from state_prediction import *
from sequenceness import *
from utils import *
%matplotlib inline

np.random.seed(100)

## Parameters

In [None]:
# DEFAULT PARAMETERS - OVERRIDDEN BY PAPERMILL EXECUTION
session_id = '001'  # ID of the scanning session
output_dir = 'data/derivatives'  # Where the output data should go
window_width = 40  # Width of the sliding window used for sequenceness analysis
classifier_window = [5, 6] # Window used for classification
classifier_center_idx = 37  # The center index of the classification window, post stimulus onset
max_lag = 20  # Maximum time-lag to look at sequenceness for
correct_alpha = True  # Correct for alpha oscillations (only if using GLM)
glm_constant = False  # Use constant (only if using GLM)
method = 'cc'  # Method for assessing sequenceness, 'cc' for cross-correlation (e.g. Kurth-Nelson, Eldar), 'glm' for GLM (e.g. Liu)
scale_data = False  # Scale state reactivation probabilities prior to sequenceness analysis
n_stim = 14  # Number of stimuli
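
The defaults above are replaced at run time when Papermill executes the template once per subject. A minimal sketch of such a driver is shown below. The template file name, the `executed_notebooks` directory and the list of session IDs are hypothetical placeholders; `papermill.execute_notebook` is the only Papermill call used.

```python
import os


def build_jobs(session_ids, template='sequenceness_template.ipynb',
               notebook_dir='executed_notebooks', **overrides):
    """Build one (input, output, parameters) job per session.

    The template name and output directory are illustrative placeholders.
    """
    jobs = []
    for session_id in session_ids:
        parameters = {'session_id': session_id}
        parameters.update(overrides)  # e.g. method='glm', window_width=40
        output_path = os.path.join(
            notebook_dir, 'sub-{0}_sequenceness.ipynb'.format(session_id))
        jobs.append((template, output_path, parameters))
    return jobs


def run_jobs(jobs):
    """Execute each job with Papermill (imported lazily so jobs can be built without it)."""
    import papermill as pm
    for input_path, output_path, parameters in jobs:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        pm.execute_notebook(input_path, output_path, parameters=parameters)


jobs = build_jobs(['001', '002'], method='cc', max_lag=20)
# run_jobs(jobs)  # uncomment to execute the notebooks
```

Papermill injects the overriding values in a new cell placed after the cell tagged `parameters`, so any parameter not passed keeps the default defined in this notebook.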

## State detection

### Load the classifier

First we load the classifier that we previously trained on the localiser data

In [None]:
clf = joblib.load(os.path.join(output_dir, 'classifier', 'classifier_idx_{0}'.format(classifier_center_idx + 50), 'sub-{0}_classifier_idx_{1}.pkl').format(session_id, classifier_center_idx + 50))

### Get the task data

We're interested in the planning and rest phases, so we load the corresponding epochs (saved as the `planning` and `task_outcome` runs).

In [None]:
planning_epochs = mne.read_epochs(os.path.join(output_dir, 'preprocessing/task', 'sub-{0}_ses-01_task-AversiveLearningReplay_run-planning_proc_ICA-epo.fif.gz').format(session_id))
outcome_epochs = mne.read_epochs(os.path.join(output_dir, 'preprocessing/task', 'sub-{0}_ses-01_task-AversiveLearningReplay_run-task_outcome_proc_ICA-epo.fif.gz').format(session_id))

# Get the data as a numpy array, excluding non-MEG channels
picks_meg = mne.pick_types(planning_epochs.info, meg=True, ref_meg=False)
planning_X = planning_epochs.get_data()[:, picks_meg, :] # MEG signals: n_epochs, n_channels, n_times
outcome_X = outcome_epochs.get_data()[:, picks_meg, :]

In [None]:
# Sanity checks: the classifier cannot handle NaN or infinite values
assert not np.isnan(planning_X).any(), "NaNs present in planning data"
assert not np.isnan(outcome_X).any(), "NaNs present in outcome data"
assert not np.isinf(planning_X).any(), "Infs present in planning data"
assert not np.isinf(outcome_X).any(), "Infs present in outcome data"

### State detection

Here we iterate over trials. The data for each trial are reshaped into the format `[n_trials, n_sensors, n_timepoints]`, where the first dimension is 1 and the final dimension holds the timepoint of interest plus the adjacent timepoints used as extra features. We then use the `predict_proba` method of the fitted classifier to get predicted state reactivation probabilities for every timepoint within the trial.


This involves a lot of for loops and could probably be made far more efficient...
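
The procedure described above can be sketched roughly as follows. This is not the code in `state_prediction`: the function name `predict_states_sketch`, the stand-in `UniformClassifier`, and the reading of `shifts=[5, 6]` as "5 samples before and up to 6 samples after the timepoint of interest" are all assumptions made for illustration, and the real `predict_states` may handle trial edges differently.

```python
import numpy as np


class UniformClassifier:
    """Stand-in for a fitted classifier: equal probability for every state."""

    def __init__(self, n_states):
        self.n_states = n_states

    def predict_proba(self, features):
        return np.full((features.shape[0], self.n_states), 1.0 / self.n_states)


def predict_states_sketch(X, clf, shifts=(5, 6)):
    """Predict state probabilities at every timepoint that has a full window.

    X: array of shape (n_trials, n_sensors, n_times).
    Returns an array of shape (n_trials, n_valid_times, n_states).
    """
    before, after = shifts  # assumed meaning: samples before / after the centre
    n_times = X.shape[2]
    trial_probs = []
    for trial in X:
        # One feature vector per timepoint: sensors x adjacent samples, flattened
        features = np.asarray([trial[:, t - before:t + after].ravel()
                               for t in range(before, n_times - after + 1)])
        trial_probs.append(clf.predict_proba(features))
    return np.stack(trial_probs)


toy_X = np.random.default_rng(0).normal(size=(2, 4, 30))
toy_probs = predict_states_sketch(toy_X, UniformClassifier(n_states=14))
```

Building all feature vectors for a trial first and calling `predict_proba` once per trial avoids one classifier call per timepoint, which is one way the loops could be reduced.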

In [None]:
planning_state_reactivation = predict_states(planning_X, clf, shifts=classifier_window, n_stim=n_stim)
assert not np.isnan(planning_state_reactivation).any(), "NaNs present in planning state reactivation array"
assert not np.isinf(planning_state_reactivation).any(), "Infs present in planning state reactivation array"

outcome_state_reactivation = predict_states(outcome_X, clf, shifts=classifier_window, n_stim=n_stim)
assert not np.isnan(outcome_state_reactivation).any(), "NaNs present in outcome state reactivation array"
assert not np.isinf(outcome_state_reactivation).any(), "Infs present in outcome state reactivation array"

In [None]:
# Save the state reactivation arrays, creating the output directories if needed.
# Note: np.save appends '.npy' to file names that lack it, so these files end in '.pkl.npy'.
planning_dir = os.path.join(output_dir, 'state_reactivation_arrays', 'planning', 'classifier_idx_{0}'.format(classifier_center_idx))
os.makedirs(planning_dir, exist_ok=True)
np.save(os.path.join(planning_dir, 'sub-{0}_planning_state_reactivation_idx_{1}.pkl'.format(session_id, classifier_center_idx)), planning_state_reactivation)

outcome_dir = os.path.join(output_dir, 'state_reactivation_arrays', 'outcome', 'classifier_idx_{0}'.format(classifier_center_idx))
os.makedirs(outcome_dir, exist_ok=True)
np.save(os.path.join(outcome_dir, 'sub-{0}_outcome_state_reactivation_idx_{1}.pkl'.format(session_id, classifier_center_idx)), outcome_state_reactivation)

In [None]:
# Convert to StateReactivation class
outcome_seq = StateReactivation(outcome_state_reactivation)
planning_seq = StateReactivation(planning_state_reactivation)

## Sequenceness analysis

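After determining the state reactivation probabilities for each trial, we can submit these data to the sequenceness analysis. The estimation method is set by the `method` parameter: `'cc'` for cross-correlation (the default above) or `'glm'` for the GLM approach.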
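
To make the idea concrete, here is a simplified sketch of a cross-correlation sequenceness measure: for each time lag, reactivation of a state is correlated with later reactivation of its successor (forward) and of its predecessor (backward), and the difference is taken. The function name and its exact averaging are assumptions for illustration; the `sequenceness` module used in this notebook may differ in its details.

```python
import numpy as np


def cross_correlation_sequenceness(probs, transitions, max_lag):
    """Forward, backward and forward-minus-backward sequenceness for lags 1..max_lag.

    probs: (n_times, n_states) state reactivation time series.
    transitions: (n_states, n_states), non-zero at [i, j] if state i leads to state j.
    """
    n_times = probs.shape[0]
    pairs = np.argwhere(transitions > 0)
    forward = np.zeros(max_lag)
    backward = np.zeros(max_lag)
    for k, lag in enumerate(range(1, max_lag + 1)):
        early, late = probs[:n_times - lag], probs[lag:]
        # Forward: state i now, its successor j `lag` samples later
        forward[k] = np.mean([np.corrcoef(early[:, i], late[:, j])[0, 1]
                              for i, j in pairs])
        # Backward: successor j now, state i `lag` samples later
        backward[k] = np.mean([np.corrcoef(early[:, j], late[:, i])[0, 1]
                               for i, j in pairs])
    return forward, backward, forward - backward


# Toy data: state 1 repeats the time course of state 0 three samples later
rng = np.random.default_rng(0)
n_times, n_states, true_lag = 500, 4, 3
toy = rng.normal(size=(n_times, n_states))
toy[:, 1] = np.roll(toy[:, 0], true_lag)
toy_transitions = np.zeros((n_states, n_states))
toy_transitions[0, 1] = 1
forward, backward, difference = cross_correlation_sequenceness(
    toy, toy_transitions, max_lag=10)
```

In this toy example the forward-minus-backward difference peaks at the lag used to construct the data, which is the signature the analysis looks for in the real reactivation probabilities.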

### Load transition matrix

Here we load the transition matrix of the task, which is necessary for sequenceness analysis. We then subset this matrix to get the four arms of the task tree.

In [None]:
transition_matrix = np.loadtxt(r'task/Task_information/transition_matrix.txt')

matrices = []

# Select individual arms: those starting at states 0 and 1 end at state 12,
# those starting at states 2 and 3 end at state 13
for start in [0, 1, 2, 3]:
    end = 12 if start in [0, 1] else 13
    matrices.append(select_path(transition_matrix, start, end))
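
For readers without access to `utils`, the following sketch shows one plausible way to extract a single arm from a transition matrix. It assumes that `select_path` keeps only the transitions lying on a path from the start state to the end state and zeroes everything else; that behaviour, and the name `select_path_sketch`, are assumptions rather than a description of the actual helper.

```python
import numpy as np


def select_path_sketch(transition_matrix, start, end):
    """Return a copy of the matrix keeping only transitions on a path start -> end."""

    def find_path(state, visited):
        # Depth-first search along non-zero transitions
        if state == end:
            return [state]
        for nxt in np.flatnonzero(transition_matrix[state]):
            nxt = int(nxt)
            if nxt not in visited:
                rest = find_path(nxt, visited | {nxt})
                if rest is not None:
                    return [state] + rest
        return None

    path = find_path(start, {start})
    if path is None:
        raise ValueError('no path from {0} to {1}'.format(start, end))
    arm = np.zeros_like(transition_matrix)
    for a, b in zip(path[:-1], path[1:]):
        arm[a, b] = transition_matrix[a, b]
    return arm


# Toy tree with two arms converging on state 4: 0 -> 2 -> 4 and 1 -> 3 -> 4
toy_tree = np.zeros((5, 5))
for a, b in [(0, 2), (2, 4), (1, 3), (3, 4)]:
    toy_tree[a, b] = 1
toy_arm = select_path_sketch(toy_tree, start=0, end=4)
```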

### Calculate sequenceness

In [None]:
# Keyword arguments shared by the outcome and planning analyses
sequenceness_kwargs = dict(alpha=correct_alpha, width=window_width, remove_first=False,
                           constant=glm_constant, set_zero=False, scale=scale_data, method=method)
outcome_windowed_sequenceness = outcome_seq.get_windowed_sequenceness(max_lag, matrices, **sequenceness_kwargs)
planning_windowed_sequenceness = planning_seq.get_windowed_sequenceness(max_lag, matrices, **sequenceness_kwargs)
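
The sliding window idea itself can be illustrated independently of the sequenceness measure: a statistic is computed on every run of `width` consecutive timepoints, giving one value per window position. The helper `windowed_apply` below is hypothetical, and `get_windowed_sequenceness` may step, pad or centre its windows differently.

```python
import numpy as np


def windowed_apply(probs, width, func):
    """Apply func to every full window of `width` consecutive timepoints.

    probs: (n_times, n_states). Returns one stacked result per window start.
    """
    n_times = probs.shape[0]
    if width > n_times:
        raise ValueError('window is longer than the time series')
    return np.stack([func(probs[start:start + width])
                     for start in range(n_times - width + 1)])


# Toy example: mean value per state within each 40-sample window
rng = np.random.default_rng(0)
toy_series = rng.random((100, 14))
windowed_mean = windowed_apply(toy_series, 40, lambda window: window.mean(axis=0))
```

With 100 timepoints and a width of 40 there are 61 full windows, so a time series shorter than `window_width` yields no windows at all.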

## Save the sequenceness data

In [None]:
# Save the windowed sequenceness results, creating the output directories if needed
planning_seq_dir = os.path.join(output_dir, 'sw_sequenceness', 'planning', 'classifier_idx_{0}'.format(classifier_center_idx))
os.makedirs(planning_seq_dir, exist_ok=True)
joblib.dump(planning_windowed_sequenceness, os.path.join(planning_seq_dir, 'sub-{0}_planning_sequenceness_idx_{1}__{2}.pkl'.format(session_id, classifier_center_idx, method)))

outcome_seq_dir = os.path.join(output_dir, 'sw_sequenceness', 'outcome', 'classifier_idx_{0}'.format(classifier_center_idx))
os.makedirs(outcome_seq_dir, exist_ok=True)
joblib.dump(outcome_windowed_sequenceness, os.path.join(outcome_seq_dir, 'sub-{0}_outcome_sequenceness_idx_{1}__{2}.pkl'.format(session_id, classifier_center_idx, method)))