In [13]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import ssm
from sklearn import preprocessing
from sklearn.model_selection import KFold
from scipy import stats

from notebooks.imports import *
from config import dir_config, main_config
from src.utils.glm_hmm_utils import *
import pickle


In [2]:
# Input directory (compiled per-session CSVs are read from here) and output
# directory (pickled design matrices and model fits are written here).
compiled_dir, processed_dir = map(Path, (dir_config.data.compiled, dir_config.data.processed))

#### Utilities: building the input regressors

In [3]:
def extract_previous_data(trial_data, rng=None):
    """Build previous-trial choice and target regressors for every trial.

    For trial ``i`` the regressors are the choice and target of the most recent
    *earlier* trial whose ``outcome`` is valid (``>= 0``). Trials with no valid
    earlier trial -- always including the first trial of the session -- get an
    independent random draw for each regressor.

    Parameters
    ----------
    trial_data : pandas.DataFrame
        One row per trial, in presentation order, with columns ``choice``
        (0:awayRF, 1:toRF), ``target`` (0:awayRF, 1:toRF) and ``outcome``
        (-1 marks an invalid trial).
    rng : numpy.random.Generator, optional
        Source of the random draws. Defaults to the global ``np.random`` state,
        so calling ``np.random.seed`` beforehand still makes the result
        reproducible.

    Returns
    -------
    prev_choice, prev_target : numpy.ndarray of int
        Both coded -1:awayRF, 1:toRF; same length as ``trial_data``.
    """
    if rng is None:
        rng = np.random
    choice = trial_data.choice.to_numpy(dtype=float)
    target = trial_data.target.to_numpy(dtype=float)
    outcome = trial_data.outcome.to_numpy(dtype=float)
    n_trials = len(choice)

    prev_choice = np.empty(n_trials)
    prev_target = np.empty(n_trials)
    # Carry the most recent valid trial forward in one O(n) pass. We start with no
    # history, so trial 0 is drawn at random instead of being given its own
    # choice/target as "previous" (which would leak the label into the inputs), and a
    # session with no valid trial at all is handled instead of raising.
    last_valid = None
    for i in range(n_trials):
        if last_valid is None:
            prev_choice[i] = rng.binomial(1, 0.5)
            prev_target[i] = rng.binomial(1, 0.5)
        else:
            prev_choice[i], prev_target[i] = last_valid
        if outcome[i] >= 0:
            last_valid = (choice[i], target[i])

    # Recode 0/1 -> -1/1 (awayRF/toRF).
    return (prev_choice * 2 - 1).astype(int), (prev_target * 2 - 1).astype(int)


def prepare_input_data(data, input_dim):
    """Build the GLM-HMM input (design) matrix for one session.

    Column layout of the returned matrix:
      0 - signed stimulus: coherence * (+1 if target is toRF else -1) / 100
      1 - bias term (left at 1)
      2 - previous valid choice   (-1:awayRF, 1:toRF)
      3 - previous valid target   (-1:awayRF, 1:toRF)
    Any further columns (input_dim > 4) are left at 1; input_dim must be >= 4.

    Parameters
    ----------
    data : pandas.DataFrame
        One session's trials with ``coherence``, ``target``, ``choice`` and
        ``outcome`` columns (coherence presumably in percent -- confirm).
    input_dim : int
        Number of input columns.

    Returns
    -------
    list of numpy.ndarray
        A single-element list holding the (n_trials, input_dim) matrix.
    """
    n_trials = data.shape[0]
    design = np.ones((n_trials, input_dim))

    design[:, 0] = data.coherence * (2 * data.target - 1) / 100
    design[:, 2], design[:, 3] = extract_previous_data(data)
    return [design]

### Create the design matrix (inputs, outputs, masks)

In [4]:
# Model dimensions:
#   n_states       - number of discrete (latent) states
#   obs_dim        - number of observed dimensions: choice (toRF/awayRF)
#   num_categories - number of categories for the output
#   input_dim      - inputs: current signed coherence, 1 (bias),
#                    previous choice (toRF/awayRF), previous target side (toRF/awayRF)
n_states, obs_dim, num_categories, input_dim = 2, 1, 2, 4


#### Data preparation

In [5]:
# Build per-session inputs/choices/masks from the compiled trial CSVs, keeping only
# task_type == 1 trials, then stack them into a single aggregated session.
session_metadata = pd.read_csv(Path(compiled_dir, "sessions_metadata.csv"), index_col=None)
inputs_session_wise = []
choices_session_wise = []
invalid_idx_session_wise = []
masks_session_wise = []

for session_id in session_metadata['session_id']:
    trial_data = pd.read_csv(Path(compiled_dir, session_id, f"{session_id}_trial.csv"), index_col=None).fillna(-1)
    # drop=True: the old row index is not needed and would otherwise be added as an
    # "index" column (which fails if the CSV already has such a column).
    GP_trial_data = trial_data[trial_data.task_type == 1].reset_index(drop=True)
    inputs_session_wise += prepare_input_data(GP_trial_data, input_dim)
    choices = GP_trial_data.choice.values.reshape(-1, 1).astype('int')

    # The mask (True = valid choice) is taken before invalid choices are filled in below.
    masks_session_wise.append(np.array(choices >= 0))
    invalid_idx = np.where(choices == -1)[0].reshape(-1, 1)
    invalid_idx_session_wise.append(invalid_idx)

    # For training, replace -1 with a random sample from {0, 1}. npr.choice(2, k)
    # samples from arange(2); the previous npr.choice(1, k) could only ever return 0.
    # NOTE(review): these trials are presumably ignored via the mask above -- confirm
    # that global_fit / session_wise_fit_cv honour the masks.
    choices[choices == -1] = npr.choice(2, invalid_idx.shape[0])
    choices_session_wise.append(choices)

# prepare_input_data returns one matrix per session.
assert len(inputs_session_wise) == len(session_metadata)

# The aggregated data is a single "session" containing all trials.
inputs_aggregated = [np.vstack(inputs_session_wise)]
choices_aggregated = [np.vstack(choices_session_wise)]
masks_aggregated = [np.vstack(masks_session_wise)]

# Z-score the signed coherence (column 0): once over the pooled data and separately
# within each session (np.vstack copied the data, so the two do not interfere).
inputs_aggregated[0][:, 0] = preprocessing.scale(inputs_aggregated[0][:, 0], axis=0)
for session_inputs in inputs_session_wise:
    session_inputs[:, 0] = preprocessing.scale(session_inputs[:, 0], axis=0)



In [14]:
# Bundle the design matrices (inputs, choices, masks) and pickle them to processed_dir.
data_aggregated = dict(
    inputs=inputs_aggregated,
    choices=choices_aggregated,
    masks=masks_aggregated,
    # NOTE(review): these are per-session invalid-trial indices stored alongside the
    # aggregated data -- confirm that downstream readers expect that.
    invalid_idx=invalid_idx_session_wise,
)

data_session_wise = dict(
    inputs=inputs_session_wise,
    choices=choices_session_wise,
    masks=masks_session_wise,
)

for pkl_name, payload in (("data_aggregated.pkl", data_aggregated),
                          ("data_session_wise.pkl", data_session_wise)):
    with open(Path(processed_dir, pkl_name), 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [11]:
# Fit GLM-HMMs on the pooled (aggregated) data, with 20 initializations and n_iters=1000.
models_glm_hmm, fit_lls_glm_hmm = global_fit(
    choices_aggregated,
    inputs_aggregated,
    masks_aggregated,
    n_iters=1000,
    n_initializations=20,
)

Fitting GLM globally...


  0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 2 states...
Fitting 3 states...
Fitting 4 states...
Fitting 5 states...


In [29]:
# Persist all globally fitted models together with their fit log-likelihoods.
global_fits = dict(
    models=models_glm_hmm,
    fits_lls_glm_hmm=fit_lls_glm_hmm,
)

with open(processed_dir / 'models_glm_hmm_global_YH.pkl', 'wb') as f:
    pickle.dump(global_fits, f)

In [30]:
# For each number of states (2-5), keep the parameters of the initialization with the
# highest fit log-likelihood; these are passed to the session-wise fits as init_params.
init_params = {
    'glm_weights': {},
    'transition_matrices': {}
}
# Use a dedicated loop variable: looping over `n_states` would overwrite the model
# dimension constant of that name (leaving it at np.int64(5)). range() also gives
# plain-int dictionary keys instead of np.int64.
for n_hidden_states in range(2, 6):
    # NOTE(review): assumes each entry of fit_lls_glm_hmm[k] is a scalar log-likelihood;
    # a NaN entry can be mis-selected by max() -- confirm what global_fit stores.
    lls = fit_lls_glm_hmm[n_hidden_states]
    best_model = models_glm_hmm[n_hidden_states][lls.index(max(lls))]
    init_params['glm_weights'][n_hidden_states] = best_model.observations.params
    init_params['transition_matrices'][n_hidden_states] = best_model.transitions.params



In [None]:
# Session-wise fitting with 5-fold cross-validation: one entry per session in
# session_metadata, given init_params (best global fit per number of states).
models_session_state_fold, train_ll_session, test_ll_session = session_wise_fit_cv(
    choices_session_wise,
    inputs_session_wise,
    masks_session_wise,
    len(session_metadata),
    init_params,
    n_iters=1000,
)

In [33]:
# Persist the per-session, per-fold models with their train/test log-likelihoods.
session_wise_fits = dict(
    models=models_session_state_fold,
    train_ll=train_ll_session,
    test_ll=test_ll_session,
)

with open(processed_dir / 'models_glm_hmm_session_wise_YH.pkl', 'wb') as f:
    pickle.dump(session_wise_fits, f)