# Replay in Aversive Environments - Sequenceness analysis

#### _This is a template that will be parameterised and run via [Papermill](http://papermill.readthedocs.io/) for each subject_

This notebook uses the classifier trained on the localiser data to detect spontaneous state reactivation during the planning and rest phases of the task.

Analysis steps:

1. Loading task data and classifier
2. Applying the classifier to the task data to generate time X state reactivation probability matrices
3. Running the GLM-based sequenceness estimation procedure

## Imports

In [None]:
# import os
# os.chdir('..')
# %load_ext autoreload
# %autoreload 2

import sys
sys.path.insert(0, 'code')
import mne
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score
from sklearn.preprocessing import FunctionTransformer
from sklearn.model_selection import RandomizedSearchCV, cross_val_predict
from sklearn.externals import joblib
import os
import papermill as pm
from state_prediction import *
from sequenceness import *
from utils import *
%matplotlib inline

# Seed NumPy's global random number generator so repeated runs of the notebook draw the same random numbers
np.random.seed(100)

## Parameters

In [None]:
# Default parameter values; the notebook is a template run via Papermill once per subject,
# which is expected to override these (at minimum session_id).
session_id = 'MG05513'  # ID of the scanning session
output_dir = 'data/derivatives'  # Where the output data should go
eye_tracking = True  # If True, eye-tracking measures will be used for exclusion of blink-related ICA components
n_stim = 8  # Number of stimuli, including null
shifts = [-5, 6]  # Additional timepoints to use as features
max_lag = 40  # Maximum lag to use in sequenceness analysis

In [None]:
# Look up the behavioural ID that belongs to this MEG session
subject_ids = pd.read_csv('subject_ids.csv', sep='\t')
session_rows = subject_ids.loc[subject_ids['meg'] == session_id]
behavioural_id = session_rows['behavioural'].values[0]

# For each log directory, take the first file whose name contains both the behavioural ID and 'stimuli'.
# os.listdir returns names in arbitrary order, so if several files match the one chosen is not guaranteed.
localiser_dir = 'localiser/Data'
localiser_matches = [fname for fname in os.listdir(localiser_dir) if behavioural_id in fname and 'stimuli' in fname]
localiser_stimuli_file = os.path.join(localiser_dir, localiser_matches[0])

task_log_dir = 'task/Data/behavioural/logs'
task_matches = [fname for fname in os.listdir(task_log_dir) if behavioural_id in fname and 'stimuli' in fname]
task_stimuli_file = os.path.join(task_log_dir, task_matches[0])

In [None]:
def get_stimuli(log_file):
    """
    Read the stimulus IDs listed in a comma-separated stimulus log file.

    The file contents are split on commas and the first run of two digits found in
    each entry is taken as that entry's stimulus ID. Entries that are empty or
    whitespace-only (e.g. produced by a trailing comma) are skipped.

    Args:
        log_file: Path to the stimulus log file.

    Returns:
        List of two-digit stimulus ID strings, in the order they appear in the file.

    Raises:
        ValueError: If a non-blank entry does not contain two consecutive digits.
    """
    with open(log_file, 'r') as f:
        entries = f.read().split(',')

    stimuli = []
    for entry in entries:
        # A trailing comma or newline leaves a blank entry that carries no stimulus
        if not entry.strip():
            continue
        match = re.search(r'[0-9]{2}', entry)
        if match is None:
            # Fail with a message naming the offending entry rather than an AttributeError on None
            raise ValueError("No two-digit stimulus ID found in entry {0!r} of {1}".format(entry, log_file))
        stimuli.append(match.group())
    return stimuli

def match_stimuli(localiser, task):
    """
    Locate each task stimulus within the localiser stimulus list.

    Args:
        localiser: Path to the localiser stimulus log file.
        task: Path to the task stimulus log file.

    Returns:
        List with one integer per task stimulus (in task order): the position at
        which that stimulus first occurs in the localiser stimulus list.
    """
    localiser_order = get_stimuli(localiser)
    task_order = get_stimuli(task)
    # Echo both orderings so they can be checked in the executed notebook
    print(localiser_order, task_order)

    new_idx = []
    for stimulus in task_order:
        # list.index raises ValueError if a task stimulus is absent from the localiser list
        new_idx.append(localiser_order.index(stimulus))
    return new_idx

In [None]:
# Indices that re-order the localiser stimuli into task order, followed by the index of the final class.
# n_stim counts the null class, so n_stim - 1 is the last index (7 for the default n_stim = 8);
# this assumes the null class is the classifier's last output column - TODO confirm.
correct_idx = match_stimuli(localiser_stimuli_file, task_stimuli_file) + [n_stim - 1]

## State detection

### Load the classifier

First we load the classifier that we previously trained on the localiser data

In [None]:
# Load this session's classifier, previously trained on the localiser data.
# The session ID is formatted into the filename only, so braces in output_dir cannot break the path.
# NOTE: joblib.load unpickles the file; only load classifier files produced by this pipeline.
clf = joblib.load(os.path.join(output_dir, 'classifier', 'sub-{0}_classifier.pkl'.format(session_id)))

### Get the task data

We're interested in the planning and rest phases so we'll select these.

In [None]:
# Format the session ID into the filename only, so braces in output_dir cannot break the path
task_epochs_file = 'sub-{0}_ses-01_task-AversiveLearningReplay_run-task_proc_ICA-epo.fif.gz'.format(session_id)
task_epochs = mne.read_epochs(os.path.join(output_dir, 'preprocessing/task', task_epochs_file))

# Select the two phases of interest by their event names
planning_epochs = task_epochs['planning']
rest_epochs = task_epochs['rest']

# Get the data as a numpy array, excluding non-MEG channels
picks_meg = mne.pick_types(task_epochs.info, meg=True, ref_meg=False)
planning_X = planning_epochs.get_data()[:, picks_meg, :]  # MEG signals: n_epochs, n_channels, n_times
rest_X = rest_epochs.get_data()[:, picks_meg, :]

In [None]:
# Fail early if either MEG array contains missing or infinite values
assert not np.isnan(planning_X).any(), "Nans present in planning data"
assert not np.isnan(rest_X).any(), "Nans present in rest data"
assert not np.isinf(planning_X).any(), "Infs present in planning data"
assert not np.isinf(rest_X).any(), "Infs present in rest data"

### State detection

Here we iterate over trials and reshape the data for each trial into the format `[n_trials, n_sensors, n_timepoints]`, where the first dimension is 1 and the final dimension is the timepoint of interest plus the additional adjacent timepoints used as extra features. We then use the `predict_proba` method of the fitted classifier to get predicted state reactivation probabilities for every timepoint within the trial.


This involves a lot of for loops and could probably be made far more efficient...

In [None]:
# pca = joblib.load(os.path.join(output_dir, 'classifier', 'sub-{0}_pca.pkl').format(session_id)) 

In [None]:
# planning_X = pca.transform(planning_X)
# rest_X = pca.transform(rest_X)

In [None]:
# Apply the classifier to the planning-phase data and check the output is free of NaN/inf values
planning_state_reactivation = predict_states(planning_X, clf, shifts=shifts)
assert not np.isnan(planning_state_reactivation).any(), "Nans present in planning state reactivation array"
assert not np.isinf(planning_state_reactivation).any(), "Infs present in planning state reactivation array"

# Same for the rest-phase data
rest_state_reactivation = predict_states(rest_X, clf, shifts=shifts)
assert not np.isnan(rest_state_reactivation).any(), "Nans present in rest state reactivation array"
assert not np.isinf(rest_state_reactivation).any(), "Infs present in rest state reactivation array"

# Save state probabilities.
# NOTE: the arrays are written here before a later cell re-orders their last axis with correct_idx,
# so the saved files keep the column order returned by predict_states.
planning_array_dir = os.path.join(output_dir, 'state_reactivation_arrays', 'planning')
rest_array_dir = os.path.join(output_dir, 'state_reactivation_arrays', 'rest')
# np.save does not create missing folders, so make sure the output directories exist first
os.makedirs(planning_array_dir, exist_ok=True)
os.makedirs(rest_array_dir, exist_ok=True)
np.save(os.path.join(planning_array_dir, 'sub-{0}_planning_state_reactivation'.format(session_id)), planning_state_reactivation)
np.save(os.path.join(rest_array_dir, 'sub-{0}_rest_state_reactivation'.format(session_id)), rest_state_reactivation)

In [None]:
# Re-order the last axis of both arrays with correct_idx (all leading axes are kept as they are).
# NOTE: this overwrites the arrays in place, so re-running the cell applies the re-ordering a second time.
state_order = (Ellipsis, correct_idx)
planning_state_reactivation = planning_state_reactivation[state_order]
rest_state_reactivation = rest_state_reactivation[state_order]

In [None]:
# Convert to StateReactivation class
rest_state_reactivation = StateReactivation(rest_state_reactivation)
planning_state_reactivation = StateReactivation(planning_state_reactivation)


In [None]:
# # Convert to StateReactivation class
# planning_state_reactivation = StateReactivation(planning_X, clf)
# rest_state_reactivation = StateReactivation(rest_X, clf)

# planning_state_reactivation.predict_states(shifts=shifts)
# assert np.isnan(planning_state_reactivation.reactivation_array).any() == False, "Nans present in planning state reactivation array"
# assert np.isinf(planning_state_reactivation.reactivation_array).any() == False, "Infs present in planning state reactivation array"

# rest_state_reactivation.predict_states(shifts=shifts)
# assert np.isnan(rest_state_reactivation.reactivation_array).any() == False, "Nans present in rest state reactivation array"
# assert np.isinf(rest_state_reactivation.reactivation_array).any() == False, "Infs present in rest state reactivation array"

# # Save state probabilities
# # np.save(os.path.join(output_dir, 'state_reactivation_arrays', 'planning', 'sub-{0}_planning_state_reactivation'.format(session_id)), planning_state_reactivation)
# # np.save(os.path.join(output_dir, 'state_reactivation_arrays', 'rest', 'sub-{0}_rest_state_reactivation'.format(session_id)), rest_state_reactivation)

### Plot state detection probabilities

We can plot these state X time arrays to view how states are reactivated over time on each trial.

_This is commented out here as it's an interactive plot which takes a while to run/load and increases the size of the notebook_

In [None]:
# # REST PHASE
# plot_state_prob(rest_state_probabilities, 'Rest phase')

# # PLANNING PHASE
# plot_state_prob(planning_state_probabilities, 'Planning phase')

## Sequenceness analysis

After determining the state reactivation probabilities for each trial, we can submit this data to the sequenceness analysis. We use a GLM approach here.

### Load transition matrix

Here we load the transition matrix of the task, which is necessary for sequenceness analysis. We then subset this matrix to get the two branches of the task tree.

In [None]:
transition_matrix = np.loadtxt(r'task/Task_information/transition_matrix.txt')

# Full transition matrix first, followed by the result of select_path for states 5 and 6
# (presumably the end states of the two branches of the task tree - confirm against select_path)
matrices = [transition_matrix] + [select_path(transition_matrix, end_state) for end_state in (5, 6)]

# Cross-state transitions for each stage: one matrix per stage with a single transition set
v, v2, v3 = np.zeros((7, 7)), np.zeros((7, 7)), np.zeros((7, 7))
v[1, 2] = 1
v2[3, 4] = 1
v3[5, 6] = 1

# Off-matrix transitions: complement of the transition matrix (v4) and of its transpose (v5)
v4 = 1 - transition_matrix
v5 = 1 - transition_matrix.T
for state in range(7):
    v4[state, :state + 1] = 0  # zero the diagonal and everything below it
    v5[:state + 1, state] = 0  # zero the diagonal and everything above it

# Cross-path transitions: the three cross-state matrices combined
v6 = v + v2 + v3

# Two-step transitions (1 -> 5 and 2 -> 6), and the same transitions reversed
v7 = np.zeros([7, 7])
v7[[1, 2], [5, 6]] = 1
v8 = v7.T

matrices += [v, v2, v3, v4, v5, v6, v7, v8]

### Run the analysis for each trial

This returns a dictionary with entries representing forwards and reverse sequenceness, along with the difference between the two.

This is repeated for 1000 permuted transition matrices that share no transitions with the true matrix or its inverse.

In [None]:
# permuted_matrices = generate_permuted_matrices(transition_matrix, n_permutations=20, n_transitions=3)

In [None]:
# plot_matrices(pepermuted_matrices = generate_permuted_matrices(transition_matrix, n_permutations=20, n_transitions=3)rmuted_matrices)

In [None]:
# rest_sequenceness, null_rest_sequenceness = rest_state_reactivation.get_sequenceness(max_lag, matrices, alpha=True, remove_first=True, permuted_matrices=permuted_matrices, constant=False)
# planning_sequenceness, null_planning_sequenceness = planning_state_reactivation.get_sequenceness(max_lag, matrices, alpha=True, remove_first=True, permuted_matrices=permuted_matrices, constant=False)

In [None]:
# rest_sequenceness_segments, _ = rest_state_reactivation.get_sequenceness_segments(max_lag, matrices, alpha=False, remove_first=False, permuted_matrices=[], constant=True, n_segments=5)
# planning_sequenceness_segments, _ = planning_state_reactivation.get_sequenceness_segments(max_lag, matrices, alpha=False, remove_first=False, permuted_matrices=[], constant=True, n_segments=5)

In [None]:
# f, ax = plt.subplots(2, figsize=(6, 10), facecolor='white')

# labels = ['Whole matrix', 'Arm 1', 'Arm 2']

# for i in range(3):
#     ax[0].plot(rest_sequenceness['forwards'][..., 9].mean(axis=0), label=labels[i])
# ax[0].set_xlabel('Lag')
# ax[0].set_title('Rest sequenceness')
# ax[0].set_ylabel(r'Backward $\leftarrow$ Sequenceness $\rightarrow$ Forward')
# ax[0].legend()

# for i in range(3):
#     ax[1].plot(planning_sequenceness['backwards'][..., 9].mean(axis=0), label=labels[i])
# ax[1].set_xlabel('Lag')
# ax[1].set_title('Planning sequenceness')
# ax[1].set_ylabel(r'Backward $\leftarrow$ Sequenceness $\rightarrow$ Forward')
# ax[1].legend();



## Save the sequenceness data

In [None]:
# NOTE: the objects written here are the StateReactivation instances created earlier in the notebook,
# stored under the '*_sequenceness.pkl' filenames; the dumps of the sequenceness results are commented out below.
planning_sequenceness_dir = os.path.join(output_dir, 'sequenceness', 'planning')
rest_sequenceness_dir = os.path.join(output_dir, 'sequenceness', 'rest')
# joblib.dump does not create missing folders, so make sure the output directories exist first
os.makedirs(planning_sequenceness_dir, exist_ok=True)
os.makedirs(rest_sequenceness_dir, exist_ok=True)
joblib.dump(planning_state_reactivation, os.path.join(planning_sequenceness_dir, 'sub-{0}_planning_sequenceness.pkl'.format(session_id)))
joblib.dump(rest_state_reactivation, os.path.join(rest_sequenceness_dir, 'sub-{0}_rest_sequenceness.pkl'.format(session_id)))
# joblib.dump(planning_sequenceness, os.path.join(output_dir, 'sequenceness', 'planning', 'sub-{0}_planning_sequenceness.pkl'.format(session_id)))
# joblib.dump(rest_sequenceness, os.path.join(output_dir, 'sequenceness', 'rest', 'sub-{0}_rest_sequenceness.pkl'.format(session_id)))
# joblib.dump(null_planning_sequenceness, os.path.join(output_dir, 'sequenceness', 'planning', 'sub-{0}_null_planning_sequenceness.pkl'.format(session_id)))
# joblib.dump(null_rest_sequenceness, os.path.join(output_dir, 'sequenceness', 'rest', 'sub-{0}_null_rest_sequenceness.pkl'.format(session_id)))
# joblib.dump(planning_sequenceness_segments, os.path.join(output_dir, 'sequenceness', 'planning', 'sub-{0}_planning_sequenceness_segments.pkl'.format(session_id)))
# joblib.dump(rest_sequenceness_segments, os.path.join(output_dir, 'sequenceness', 'rest', 'sub-{0}_rest_sequenceness_segments.pkl'.format(session_id)))

## Outcome reactivation

Load the classifier

In [None]:
# Load this session's outcome classifier.
# The session ID is formatted into the filename only, so braces in output_dir cannot break the path.
# NOTE: joblib.load unpickles the file; only load classifier files produced by this pipeline.
outcome_clf = joblib.load(os.path.join(output_dir, 'classifier', 'sub-{0}_outcome_classifier.pkl'.format(session_id)))

Predict outcome reactivations for planning, rest, and final state

In [None]:
# Format the session ID into the filename only, so braces in output_dir cannot break the path
final_state_epochs_file = 'sub-{0}_ses-01_task-AversiveLearningReplay_run-task_final_state_proc_ICA-epo.fif.gz'.format(session_id)
final_state_epochs = mne.read_epochs(os.path.join(output_dir, 'preprocessing/task', final_state_epochs_file))
# Re-uses the MEG channel picks computed from the task epochs' info;
# assumes the final-state epochs have the same channel layout - TODO confirm
final_state_X = final_state_epochs.get_data()[:, picks_meg, :]

In [None]:
# Apply the outcome classifier to the rest, planning and final-state data, in that order.
# NOTE(review): unlike the state classifier call, `shifts` is not passed here, so predict_states
# falls back to its default - confirm this matches how the outcome classifier was trained.
rest_outcome_reactivation, planning_outcome_reactivation, final_state_outcome_reactivation = (
    predict_states(phase_X, outcome_clf, n_stim=3)
    for phase_X in (rest_X, planning_X, final_state_X)
)

In [None]:
# Save state probabilities
np.save(os.path.join(output_dir, 'outcome_reactivation_arrays', 'planning', 'sub-{0}_planning_outcome_reactivation'.format(session_id)), planning_outcome_reactivation)
np.save(os.path.join(output_dir, 'outcome_reactivation_arrays', 'rest', 'sub-{0}_rest_outcome_reactivation'.format(session_id)), rest_outcome_reactivation)
np.save(os.path.join(output_dir, 'outcome_reactivation_arrays', 'final_state', 'sub-{0}_final_state_outcome_reactivation'.format(session_id)), final_state_outcome_reactivation)