In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import ssm
from sklearn import preprocessing
from sklearn.model_selection import KFold
from scipy import stats

from notebooks.imports import *
from config import dir_config, main_config
from src.utils.glm_hmm_utils import *
import pickle
import copy

### Configuration

In [None]:
from config import dir_config, main_config

# Data directories come from the project configuration.
raw_dir = Path(dir_config.data.raw)
processed_dir = Path(dir_config.data.processed)

# Load the processed metadata table and the processed trial data table.
metadata = pd.read_csv(processed_dir / "processed_metadata_accu_60.csv")
data = pd.read_csv(processed_dir / "processed_data_accu_60_all.csv")

In [None]:
# Restrict the analysis to the selected experiment site(s).
experiment_sites = ["Stanford"]

metadata = metadata[metadata['experiment_site'].isin(experiment_sites)].reset_index(drop=True)
data = data[data['subject_id'].isin(metadata['subject_id'])].reset_index(drop=True)

# Build session_id as "<SUBJECT_ID>_<TREATMENT>" (upper-cased) in both tables.
# Column-wise string operations replace the former row-wise `apply(axis=1)`:
# they give the same strings, are much faster on the trial table, and still
# return a Series when the filtered frame is empty.
metadata['session_id'] = (
    metadata['subject_id'].astype(str).str.upper()
    + '_'
    + metadata['treatment'].astype(str).str.upper()
)
data['session_id'] = (
    data['subject_id'].astype(str).str.upper()
    + '_'
    + data['medication'].astype(str).str.upper()
)

# Missing choice/target/outcome values are encoded as -1 so the columns can be integer typed.
for column in ('choice', 'target', 'outcome'):
    data[column] = data[column].fillna(-1).astype(int)

### Helper Functions

In [None]:
def extract_previous_trial_data(trial_data, invalid_idx):
    """Build previous-trial regressors (choice, target, color) for one session.

    Every trial receives the values of the trial before it; the first trial
    has no predecessor and receives its own values. Whenever the preceding
    trial is invalid, the regressors are taken from the most recent trial
    whose predecessor was valid. If no such trial exists yet (including the
    case where there is none in the whole session), they are sampled
    uniformly at random.

    Parameters
    ----------
    trial_data : object exposing ``choice``, ``target`` and ``color``
        Typically a pandas DataFrame. The columns are read positionally, so
        the frame does not need a 0-based index.
    invalid_idx : array-like of int
        Positional indices of the invalid trials.

    Returns
    -------
    tuple of three int ndarrays ``(prev_choice, prev_target, prev_color)``
        ``prev_choice`` and ``prev_target`` are mapped with ``x * 2 - 1``
        (0/1 -> -1/+1); ``prev_color`` keeps the original coding.

    Notes
    -----
    The global NumPy random state is re-seeded with 1 on every call, so the
    random fills (and any draws made by the caller afterwards) are
    reproducible.
    """
    npr.seed(1)
    n_trials = len(trial_data)

    # Read columns positionally; `Series[0]` is label-based and breaks on a non-reset index.
    choice = np.asarray(trial_data.choice)
    target = np.asarray(trial_data.target)
    color = np.asarray(trial_data.color)

    # Shift by one trial; the first trial re-uses its own value.
    prev_choice = np.concatenate([choice[:1], choice[:-1]])  # 0:awayPrior, 1:toPrior of previous valid trial
    prev_target = np.concatenate([target[:1], target[:-1]]) * 2 - 1  # -1:awayPrior, 1:toPrior of previous valid trial
    prev_color = np.concatenate([color[:1], color[:-1]])  # 0:equalPrior, 1:UnequalPrior of previous valid trial

    # indices where the previous trial is invalid/valid
    invalid_idx = np.asarray(invalid_idx, dtype=int)
    prev_invalid_idx = invalid_idx + 1
    if 0 in invalid_idx:
        # trial 0 "precedes itself", so it is affected when it is invalid
        prev_invalid_idx = np.append(0, prev_invalid_idx)
    prev_valid_idx = np.setdiff1d(np.arange(n_trials), prev_invalid_idx)  # sorted, unique

    # With no valid predecessor anywhere, every affected trial is filled randomly.
    first_valid = prev_valid_idx[0] if prev_valid_idx.size else n_trials

    for i in prev_invalid_idx[prev_invalid_idx < n_trials]:
        if i < first_valid:  # randomly sample if no previous valid trials
            prev_choice[i] = npr.binomial(1, 0.5)
            prev_target[i] = npr.binomial(1, 0.5) * 2 - 1
            prev_color[i] = npr.binomial(1, 0.5)
        else:
            # last entry of prev_valid_idx that is strictly smaller than i
            source = prev_valid_idx[np.searchsorted(prev_valid_idx, i) - 1]
            prev_choice[i] = prev_choice[source]
            prev_target[i] = prev_target[source]
            prev_color[i] = prev_color[source]

    prev_choice = (prev_choice * 2) - 1  # -1:awayPrior, 1:toPrior of previous valid trial
    return prev_choice.astype(int), prev_target.astype(int), prev_color.astype(int)

def prepare_input_data(data, input_dim, invalid_idx):
    """Assemble the design matrix of one session.

    Column layout: 0 = signed coherence / 100, 1 = stimulus color, 2 = bias
    (constant 1). For ``input_dim`` of 4, 5 or 6 the columns from 3 onwards
    are filled, in order, with previous choice, previous target and previous
    color. Columns that are never written keep the value 1.

    Returns a list holding a single array of shape ``(n_trials, input_dim)``.
    """
    n_trials = data.shape[0]
    # Start from ones so the untouched bias column is already correct.
    design = np.ones((1, n_trials, input_dim))
    design[0, :, 0] = data.signed_coherence / 100
    design[0, :, 1] = data.color

    # Always computed (this also re-seeds the RNG inside the helper).
    history = extract_previous_trial_data(data, invalid_idx)
    if input_dim in (4, 5, 6):
        # zip stops at the shorter sequence, so only the requested columns are filled
        for column, regressor in zip(range(3, input_dim), history):
            design[0, :, column] = regressor
    return list(design)

### Data processing

In [None]:
# Quick textual overview of the trial table.
print("------------- info ----------------")
# DataFrame.info() writes its report itself and returns None; wrapping it in
# print() used to emit a stray "None" line after the report.
data.info()
print("------------- Head ----------------")
print(data.head())
print("\n\n------------- describe ----------------\n\n")
print(data.describe())
print("------------- nan counts ----------------")
print(data.isnull().sum())
print("\n\n------------- dtypes ----------------\n\n")
print(data.dtypes)
print("\n\n------------- shape ----------------\n\n")
print(data.shape)

#### Data preparation

In [None]:
# Unique session identifiers per treatment label found in the metadata.
off_med_sessions = metadata.loc[metadata['treatment'] == 'OFF', 'session_id'].unique()
on_med_sessions = metadata.loc[metadata['treatment'] == 'ON', 'session_id'].unique()
off_med_sessions, on_med_sessions

#### Create design matrix (input, output, mask)

In [None]:
# Model dimensions used when building the design matrices below.
n_states = 2       # number of discrete states
obs_dim = 1           # number of observed dimensions: choice(toPrior/awayPrior)
num_categories = 2    # number of categories for output
# Columns written by prepare_input_data, in order: 0 signed coherence / 100, 1 stimulus color,
# 2 bias (constant 1), 3 previous choice, 4 previous target. With input_dim = 5 only these five
# exist; a sixth column (previous color: equal/unequal prior) is filled only when input_dim == 6.
input_dim = 5        # number of input regressors (see column layout above)

### Off medication sessions

In [None]:
inputs_session_wise = []
choices_session_wise = []
invalid_idx_session_wise = []
masks_session_wise = []
reaction_time_session_wise = []

# off medication sessions
for session in off_med_sessions:
    session_data = data[data['session_id'] == session].reset_index(drop=True)

    # trials with a negative outcome code are treated as invalid
    invalid_idx = np.where(session_data.outcome < 0)[0]

    inputs = prepare_input_data(session_data, input_dim, invalid_idx)
    choices = session_data.choice.values.reshape(-1,1).astype('int')

    # for training, replace -1 with random sample from 0,1
    # The number of draws is taken from the entries actually being replaced;
    # sizing it by len(invalid_idx) raised a shape error whenever the count of
    # missing choices differed from the count of negative outcomes.
    missing_choice = choices == -1
    choices[missing_choice] = npr.choice([0,1], int(missing_choice.sum()))

    # mask is True for trials that enter the fit
    mask = np.ones_like(choices, dtype=bool)
    mask[invalid_idx] = False
    mask[missing_choice] = False  # a random placeholder choice must never be fitted
    reaction_time = np.array(session_data.reaction_time)

    masks_session_wise.append(mask)
    inputs_session_wise += inputs
    choices_session_wise.append(choices)
    reaction_time_session_wise.append(reaction_time)

# pooled ("global") data: all sessions stacked into one sequence
off_med_inputs_aggregate, off_med_choices_aggregate, off_med_masks_aggregate = [], [], []
off_med_inputs_aggregate.append(np.vstack(inputs_session_wise))
off_med_choices_aggregate.append(np.vstack(choices_session_wise))
off_med_masks_aggregate.append(np.vstack(masks_session_wise))

# keep an unscaled copy before the per-session arrays are modified in place
unnormalized_off_med_inputs = copy.deepcopy(inputs_session_wise)

# scaling signed coherence (column 0) using unmasked trials only;
# np.vstack returned a new array, so the pooled and per-session scalings are independent
pooled_valid = off_med_masks_aggregate[0][:,0]
off_med_inputs_aggregate[0][pooled_valid,0] = preprocessing.scale(off_med_inputs_aggregate[0][pooled_valid,0], axis=0)
for session_inputs, session_mask in zip(inputs_session_wise, masks_session_wise):
    session_valid = session_mask[:,0]
    session_inputs[session_valid,0] = preprocessing.scale(session_inputs[session_valid,0], axis=0)


In [None]:
# Fit on the pooled off-medication data, 20 initializations with up to 1000 iterations each.
models_glm_hmm_off_med, fit_lls_glm_hmm_off_med = global_fit(
    off_med_choices_aggregate,
    off_med_inputs_aggregate,
    masks=off_med_masks_aggregate,
    n_iters=1000,
    n_initializations=20,
)

In [None]:
# get best model of 20 initializations for each state
# The loop variable is deliberately NOT called `n_states`: reusing that name
# silently overwrote the `n_states` configuration constant defined earlier.
init_params = {
    'glm_weights': {},
    'transition_matrices': {}
}
for num_states in np.arange(2,6):
    state_lls = fit_lls_glm_hmm_off_med[num_states]
    best_idx = state_lls.index(max(state_lls))  # initialization with the highest log-likelihood
    best_model = models_glm_hmm_off_med[num_states][best_idx]
    init_params['glm_weights'][num_states] = best_model.observations.params
    init_params['transition_matrices'][num_states] = best_model.transitions.params

In [None]:
# session-wise fitting with 5 fold cross-validation, initialised from the best global parameters
(
    models_session_state_fold_off_med,
    train_ll_session_off_med,
    test_ll_session_off_med,
) = session_wise_fit_cv(
    choices_session_wise,
    inputs_session_wise,
    masks=masks_session_wise,
    n_sessions=len(off_med_sessions),
    init_params=init_params,
    n_iters=1000,
)

In [None]:
# Bundle everything produced for the off-medication condition:
# 'global' holds the pooled fit, 'session' the per-session data and cross-validated fits.
off_medication_results = {
    'global': dict(
        inputs=off_med_inputs_aggregate,
        choices=off_med_choices_aggregate,
        masks=off_med_masks_aggregate,
        models=models_glm_hmm_off_med,
        fit_lls=fit_lls_glm_hmm_off_med,
        best_params=init_params,
    ),
    'session': dict(
        session_ids=off_med_sessions,
        unnormalized_inputs=unnormalized_off_med_inputs,
        inputs=inputs_session_wise,
        choices=choices_session_wise,
        masks=masks_session_wise,
        reaction_time=reaction_time_session_wise,
        models=models_session_state_fold_off_med,
        train_lls=train_ll_session_off_med,
        test_lls=test_ll_session_off_med,
    ),
}


# with open(Path(processed_dir, f'glm_hmm_off_meds_result.pkl'), 'wb') as f:
#     pickle.dump(off_medication_results, f)

### On medication sessions

In [None]:
inputs_session_wise = []
choices_session_wise = []
invalid_idx_session_wise = []
masks_session_wise = []
reaction_time_session_wise = []

# on medication sessions
for session in on_med_sessions:
    session_data = data[data['session_id'] == session].reset_index(drop=True)

    # trials with a negative outcome code are treated as invalid
    invalid_idx = np.where(session_data.outcome < 0)[0]

    inputs = prepare_input_data(session_data, input_dim, invalid_idx)
    choices = session_data.choice.values.reshape(-1,1).astype('int')

    # for training, replace -1 with random sample from 0,1
    # The number of draws is taken from the entries actually being replaced;
    # sizing it by len(invalid_idx) raised a shape error whenever the count of
    # missing choices differed from the count of negative outcomes.
    missing_choice = choices == -1
    choices[missing_choice] = npr.choice([0,1], int(missing_choice.sum()))

    # mask is True for trials that enter the fit
    mask = np.ones_like(choices, dtype=bool)
    mask[invalid_idx] = False
    mask[missing_choice] = False  # a random placeholder choice must never be fitted
    reaction_time = np.array(session_data.reaction_time)

    masks_session_wise.append(mask)
    inputs_session_wise += inputs
    choices_session_wise.append(choices)
    reaction_time_session_wise.append(reaction_time)

# pooled ("global") data: all sessions stacked into one sequence
on_med_inputs_aggregate, on_med_choices_aggregate, on_med_masks_aggregate = [], [], []
on_med_inputs_aggregate.append(np.vstack(inputs_session_wise))
on_med_choices_aggregate.append(np.vstack(choices_session_wise))
on_med_masks_aggregate.append(np.vstack(masks_session_wise))

# keep an unscaled copy before the per-session arrays are modified in place
unnormalized_on_med_inputs = copy.deepcopy(inputs_session_wise)

# scaling signed coherence (column 0) using unmasked trials only;
# np.vstack returned a new array, so the pooled and per-session scalings are independent
pooled_valid = on_med_masks_aggregate[0][:,0]
on_med_inputs_aggregate[0][pooled_valid,0] = preprocessing.scale(on_med_inputs_aggregate[0][pooled_valid,0], axis=0)
for session_inputs, session_mask in zip(inputs_session_wise, masks_session_wise):
    session_valid = session_mask[:,0]
    session_inputs[session_valid,0] = preprocessing.scale(session_inputs[session_valid,0], axis=0)


In [None]:
# Fit on the pooled on-medication data, 20 initializations with up to 1000 iterations each.
models_glm_hmm_on_med, fit_lls_glm_hmm_on_med = global_fit(
    on_med_choices_aggregate,
    on_med_inputs_aggregate,
    masks=on_med_masks_aggregate,
    n_iters=1000,
    n_initializations=20,
)

In [None]:
# get best model of 20 initializations for each state
# The loop variable is deliberately NOT called `n_states`: reusing that name
# silently overwrote the `n_states` configuration constant defined earlier.
init_params = {
    'glm_weights': {},
    'transition_matrices': {}
}
for num_states in np.arange(2,6):
    state_lls = fit_lls_glm_hmm_on_med[num_states]
    best_idx = state_lls.index(max(state_lls))  # initialization with the highest log-likelihood
    best_model = models_glm_hmm_on_med[num_states][best_idx]
    init_params['glm_weights'][num_states] = best_model.observations.params
    init_params['transition_matrices'][num_states] = best_model.transitions.params

In [None]:
# session-wise fitting with 5 fold cross-validation, initialised from the best global parameters
(
    models_session_state_fold_on_med,
    train_ll_session_on_med,
    test_ll_session_on_med,
) = session_wise_fit_cv(
    choices_session_wise,
    inputs_session_wise,
    masks=masks_session_wise,
    n_sessions=len(on_med_sessions),
    init_params=init_params,
    n_iters=1000,
)

In [None]:
# Bundle everything produced for the on-medication condition:
# 'global' holds the pooled fit, 'session' the per-session data and cross-validated fits.
on_medication_results = {
    'global': dict(
        inputs=on_med_inputs_aggregate,
        choices=on_med_choices_aggregate,
        masks=on_med_masks_aggregate,
        models=models_glm_hmm_on_med,
        fit_lls=fit_lls_glm_hmm_on_med,
        best_params=init_params,
    ),
    'session': dict(
        session_ids=on_med_sessions,
        unnormalized_inputs=unnormalized_on_med_inputs,
        inputs=inputs_session_wise,
        choices=choices_session_wise,
        masks=masks_session_wise,
        reaction_time=reaction_time_session_wise,
        models=models_session_state_fold_on_med,
        train_lls=train_ll_session_on_med,
        test_lls=test_ll_session_on_med,
    ),
}

# with open(Path(processed_dir, f'glm_hmm_on_meds_result.pkl'), 'wb') as f:
#     pickle.dump(on_medication_results, f)

In [None]:
# Display the field names stored under the 'session' entry of the off-medication results.
off_medication_results["session"].keys()


In [None]:
# Combine both conditions: global fits stay separate, session-wise entries are
# concatenated with the off-medication sessions first.
glm_hmm_results = {
    "off_med_global": off_medication_results['global'],
    "on_med_global": on_medication_results['global'],
    "session_wise": {
        "session_ids": off_medication_results['session']['session_ids'].tolist() + on_medication_results['session']['session_ids'].tolist(),
        "unnormalized_inputs": off_medication_results['session']['unnormalized_inputs'] + on_medication_results['session']['unnormalized_inputs'],
        "inputs": off_medication_results['session']['inputs'] + on_medication_results['session']['inputs'],
        "choices": off_medication_results['session']['choices'] + on_medication_results['session']['choices'],
        "masks": off_medication_results['session']['masks'] + on_medication_results['session']['masks'],
        "reaction_time": off_medication_results['session']['reaction_time'] + on_medication_results['session']['reaction_time'],
        "train_ll": np.concatenate([off_medication_results['session']['train_lls'], on_medication_results['session']['train_lls']], axis=0),
        "test_ll": np.concatenate([off_medication_results['session']['test_lls'], on_medication_results['session']['test_lls']], axis=0)
    }
}

# Merge the per-session models into a NEW dict. Assigning the off-medication
# dict directly and calling .update() on it mutated
# `models_session_state_fold_off_med` (and therefore off_medication_results),
# and re-running this cell shifted the on-medication keys by a growing offset.
session_models = dict(models_session_state_fold_off_med)
# Shift the keys of on-medication models and merge
# assumes the off-medication keys are the integers 0..n-1 — TODO confirm against session_wise_fit_cv
off_session_len = len(models_session_state_fold_off_med)
session_models.update({
    key + off_session_len: value for key, value in models_session_state_fold_on_med.items()
})
glm_hmm_results['session_wise']['models'] = session_models


with open(Path(processed_dir, 'glm_hmm_separate_initialization_result.pkl'), 'wb') as f:
    pickle.dump(glm_hmm_results, f)
