In [None]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import ssm
from sklearn import preprocessing
from sklearn.model_selection import KFold
from scipy import stats

from notebooks.imports import *
from config import dir_config, main_config
from src.utils.glm_hmm_utils import *
import pickle
import copy


In [None]:
# Input (compiled) and output (processed) data directories, taken from the project config.
compiled_dir, processed_dir = (
    Path(dir_config.data.compiled),
    Path(dir_config.data.processed),
)

In [None]:
_TRIALS = 'valid_only'

#### Utilities: building the GLM-HMM input regressors

In [None]:

def extract_previous_data(trial_data, invalid_idx):
    """Build the previous-trial regressors (previous choice, previous target side).

    For trial t the value comes from trial t-1. If trial t-1 is invalid, the
    value of the most recent valid trial before t is carried forward. If no
    valid trial precedes t, a fair coin flip is used instead. Trial 0 has no
    predecessor, so it is padded with its own choice/target, unless trial 0 is
    itself invalid, in which case it is sampled at random.

    Parameters
    ----------
    trial_data : pandas.DataFrame
        Trials in temporal order, with a ``choice`` column (0: awayRF,
        1: toRF, -1: missing) and a ``target`` column (0: awayRF, 1: toRF).
        Rows are read positionally, so the index labels do not matter.
    invalid_idx : array-like of int
        Positional indices of the invalid trials.

    Returns
    -------
    prev_choice, prev_target : numpy.ndarray of int
        Both coded -1: awayRF, 1: toRF. Empty arrays for an empty ``trial_data``.

    Notes
    -----
    The global NumPy RNG is reseeded with 1 on every call. This keeps the
    random fills reproducible, but it also resets the stream for any later
    ``np.random`` draws.
    """
    npr.seed(1)
    n_trials = len(trial_data)
    if n_trials == 0:
        return np.array([], dtype=int), np.array([], dtype=int)

    # Positional access (not label-based `[0]`), so any index works.
    choice = trial_data.choice.to_numpy()
    target = trial_data.target.to_numpy()
    prev_choice = np.concatenate([choice[:1], choice[:-1]])  # 0:awayRF, 1:toRF of previous trial
    prev_target = np.concatenate([target[:1], target[:-1]]) * 2 - 1  # -1:awayRF, 1:toRF of previous trial

    # Trials whose immediately preceding trial is invalid (trial 0 counts if it is invalid itself).
    invalid_idx = np.asarray(invalid_idx, dtype=int)
    prev_invalid_idx = invalid_idx + 1
    if 0 in invalid_idx:
        prev_invalid_idx = np.append(0, prev_invalid_idx)
    prev_invalid_idx = prev_invalid_idx[prev_invalid_idx < n_trials]
    prev_valid_idx = np.setdiff1d(np.arange(n_trials), prev_invalid_idx)  # sorted ascending

    # With no trial that has a valid predecessor, every affected trial is sampled at random.
    first_valid = prev_valid_idx[0] if prev_valid_idx.size else n_trials

    for i in prev_invalid_idx:
        if i < first_valid:  # no earlier valid trial to carry forward: coin flip
            prev_choice[i] = np.random.binomial(1, 0.5)
            prev_target[i] = np.random.binomial(1, 0.5) * 2 - 1
        else:
            # Last entry of prev_valid_idx that is < i (i itself is not in prev_valid_idx).
            source = prev_valid_idx[np.searchsorted(prev_valid_idx, i) - 1]
            prev_choice[i] = prev_choice[source]
            prev_target[i] = prev_target[source]

    prev_choice = (prev_choice * 2) - 1  # -1:awayRF, 1:toRF of previous valid trial
    return prev_choice.astype(int), prev_target.astype(int)



def prepare_input_data(data, input_dim, invalid_idx):
    """Assemble the GLM-HMM input (design) matrix for one session.

    Column layout:
      0 - signed coherence: coherence * (+1 if target == 1, -1 if target == 0) / 100
      1 - constant bias (left at 1.0)
      2 - previous choice  (-1/+1, from extract_previous_data)
      3 - previous target  (-1/+1, from extract_previous_data)
    Any columns beyond index 3 keep their initial value of 1.0. ``input_dim``
    must therefore be at least 4.

    NOTE(review): a target of -1 (the caller's fill value for missing targets)
    gives a sign factor of -3 here. Presumably such trials are invalid and get
    dropped or masked downstream; confirm.

    Returns a one-element list holding an (n_trials, input_dim) float array.
    """
    n_trials = data.shape[0]
    design = np.ones((1, n_trials, input_dim))

    side_sign = 2 * data.target - 1
    design[0, :, 0] = data.coherence * side_sign / 100

    prev_choice, prev_target = extract_previous_data(data, invalid_idx)
    design[0, :, 2] = prev_choice
    design[0, :, 3] = prev_target
    return list(design)

### Create design matrices (inputs, outputs, masks)

In [None]:
# GLM-HMM model dimensions.
n_states, obs_dim, num_categories, input_dim = (
    2,  # number of discrete latent states
    1,  # observed dimensions: a single choice (toRF / awayRF)
    2,  # number of output categories
    4,  # inputs: signed coherence, bias (1), previous choice, previous target side
)


#### Data preparation

In [None]:
# Build per-session design matrices for the GP task (rows with task_type == 1).
# Per session this collects:
#   - inputs: one (T, input_dim) array
#   - choices: a (T, 1) int array
#   - mask: a (T, 1) bool array, True = trial used for fitting
#   - the invalid-trial indices, GP trial numbers and prob_toRF
# Missing choice/target/outcome values are coded as -1; outcome < 0 marks an invalid trial.
if _TRIALS not in ('all_trials', 'valid_only'):
    raise ValueError(f"_TRIALS must be 'all_trials' or 'valid_only', got {_TRIALS!r}")

session_metadata = pd.read_csv(Path(compiled_dir, "sessions_metadata.csv"), index_col=None)
inputs_session_wise = []
choices_session_wise = []
invalid_idx_session_wise = []
masks_session_wise = []
GP_trial_num_session_wise = []
prob_toRF_session_wise = []

for session_id in session_metadata['session_id']:

    trial_data = pd.read_csv(Path(compiled_dir, session_id, f"{session_id}_trial.csv"), index_col=None)
    GP_trial_data = trial_data[trial_data.task_type == 1].reset_index()
    GP_trial_data['choice'] = GP_trial_data['choice'].fillna(-1)
    GP_trial_data['target'] = GP_trial_data['target'].fillna(-1)
    GP_trial_data['outcome'] = GP_trial_data['outcome'].fillna(-1)

    invalid_idx = np.where(GP_trial_data.outcome < 0)[0]
    valid_idx = np.where(GP_trial_data.outcome >= 0)[0]

    inputs = prepare_input_data(GP_trial_data, input_dim, invalid_idx)
    choices = GP_trial_data.choice.values.reshape(-1, 1).astype('int')

    if _TRIALS == 'all_trials':
        # Keep every trial. Missing choices get a random 0/1 placeholder so the
        # categorical emission sees a legal value. npr.choice(2, n) samples from
        # {0, 1}. The draw count is the number of missing choices, which need not
        # equal len(invalid_idx). Invalid trials and placeholder choices are masked out.
        missing_choice = choices == -1
        choices[missing_choice] = npr.choice(2, int(missing_choice.sum()))
        mask = np.ones_like(choices, dtype=bool)
        mask[invalid_idx] = False
        mask[missing_choice] = False
        GP_trial_num = np.array(GP_trial_data.trial_number)
        prob_toRF = np.array(GP_trial_data.prob_toRF)

    else:  # 'valid_only': drop invalid trials entirely
        choices = choices[valid_idx, :]
        inputs[0] = inputs[0][valid_idx, :]
        mask = np.ones_like(choices, dtype=bool)
        GP_trial_num = np.array(GP_trial_data.trial_number)[valid_idx]
        prob_toRF = np.array(GP_trial_data.prob_toRF)[valid_idx]

    masks_session_wise.append(mask)
    inputs_session_wise += inputs
    choices_session_wise.append(choices)
    invalid_idx_session_wise.append(invalid_idx)
    GP_trial_num_session_wise.append(GP_trial_num)
    prob_toRF_session_wise.append(prob_toRF)


# Pooled ("global") data: all sessions stacked into a single array each.
inputs_aggregated = [np.vstack(inputs_session_wise)]
choices_aggregated = [np.vstack(choices_session_wise)]
masks_aggregated = [np.vstack(masks_session_wise)]


In [None]:
# Keep untouched copies for the saved tables. Then z-score the signed-coherence
# column (column 0) in place, using only the rows whose mask is True.
unnormalized_inputs_aggregated = copy.deepcopy(inputs_aggregated)
unnormalized_inputs_session_wise = copy.deepcopy(inputs_session_wise)


def _scale_stimulus_in_place(design, mask):
    """Standardize column 0 of `design` over the rows selected by mask[:, 0]."""
    rows = mask[:, 0]
    design[rows, 0] = preprocessing.scale(design[rows, 0], axis=0)


_scale_stimulus_in_place(inputs_aggregated[0], masks_aggregated[0])
for idx_session in range(len(session_metadata)):
    _scale_stimulus_in_place(inputs_session_wise[idx_session], masks_session_wise[idx_session])

In [None]:
# Pooled fit on all sessions stacked together (each *_aggregated list holds one array).
# 1000 iterations and 20 initializations per model, as passed here. The results are
# indexed by number of states, then by initialization (as consumed in the next cell).
models_glm_hmm, fit_lls_glm_hmm = global_fit(
    choices_aggregated,
    inputs_aggregated,
    masks=masks_aggregated,
    n_iters=1000,
    n_initializations=20,
)

In [None]:
# get best model of 20 initializations for each state
init_params = {
    'glm_weights': {},
    'transition_matrices': {}
}
# For each number of states, keep the parameters of the initialization with the
# highest log-likelihood. They seed the session-wise fits (passed as init_params).
# The loop variable is deliberately not named `n_states`: that would overwrite
# the model constant from the configuration cell, leaving it at 5 afterwards.
# np.arange keeps the same (NumPy int) key types as before.
STATE_COUNTS = np.arange(2, 6)  # K = 2..5; assumed to match the keys global_fit produced
for n_states_k in STATE_COUNTS:
    lls_k = fit_lls_glm_hmm[n_states_k]
    best_idx = lls_k.index(max(lls_k))  # first initialization with the maximal log-likelihood
    best_model = models_glm_hmm[n_states_k][best_idx]
    init_params['glm_weights'][n_states_k] = best_model.observations.params
    init_params['transition_matrices'][n_states_k] = best_model.transitions.params



In [None]:
# session-wise fitting with 5 fold cross-validation
# Session-wise cross-validated fits, initialized from the best pooled-fit parameters.
# NOTE(review): the notebook describes this as 5-fold CV. The fold count is not
# passed here, so it is set inside session_wise_fit_cv; confirm there.
models_session_state_fold, train_ll_session, test_ll_session = session_wise_fit_cv(
    choices_session_wise,
    inputs_session_wise,
    masks=masks_session_wise,
    n_sessions=len(session_metadata['session_id']),
    init_params=init_params,
    n_iters=1000,
)

In [None]:
# store data and models for aggregated
# Tabulate the design matrices next to the fitted models and log-likelihoods, for both
# the pooled ("global") and per-session analyses. Everything is pickled to
# processed_dir/glm_hmm_<_TRIALS>.pkl.
# `stimulus` is the signed coherence before z-scoring; `normalized_stimulus` is after.


def _design_frame(choices, raw_inputs, scaled_inputs, mask, **extra_columns):
    """Return one design matrix as a DataFrame; extra columns are appended after `mask`."""
    columns = {
        "choices": choices.reshape(-1),
        "stimulus": raw_inputs[:, 0],
        "normalized_stimulus": scaled_inputs[:, 0],
        "bias": scaled_inputs[:, 1],
        "previous_choice": scaled_inputs[:, 2],
        "previous_target": scaled_inputs[:, 3],
        "mask": mask.reshape(-1),
    }
    columns.update(extra_columns)
    return pd.DataFrame(columns)


# pooled data and models
agg_data = _design_frame(
    choices_aggregated[0],
    unnormalized_inputs_aggregated[0],
    inputs_aggregated[0],
    masks_aggregated[0],
)
global_fits = {
    'models': models_glm_hmm,
    'fits_lls_glm_hmm': fit_lls_glm_hmm,
    "data": agg_data,
}

# per-session data and models
session_data = {}
for idx_session, session_id in enumerate(session_metadata['session_id']):
    session_data[session_id] = _design_frame(
        choices_session_wise[idx_session],
        unnormalized_inputs_session_wise[idx_session],
        inputs_session_wise[idx_session],
        masks_session_wise[idx_session],
        trial_num=GP_trial_num_session_wise[idx_session].reshape(-1),
        prob_toRF=prob_toRF_session_wise[idx_session].reshape(-1),
    )

session_wise_fits = {
    'models': models_session_state_fold,
    'train_ll': train_ll_session,
    'test_ll': test_ll_session,
    'data': session_data,
}

models_and_data = {'global': global_fits, 'session_wise': session_wise_fits}

with open(Path(processed_dir, f'glm_hmm_{_TRIALS}.pkl'), 'wb') as f:
    pickle.dump(models_and_data, f)