In [125]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import ssm
from sklearn import preprocessing
from sklearn.model_selection import KFold
from scipy import stats

from notebooks.imports import *
from config import dir_config, main_config
from src.utils.glm_hmm_utils import *



In [127]:
# Data directories: `compiled_dir` holds the per-session CSVs read below,
# `processed_dir` is where derived artefacts (e.g. pickled inputs) are written.
compiled_dir, processed_dir = (
    Path(dir_config.data.compiled),
    Path(dir_config.data.processed),
)

#### utils

In [128]:
def extract_previous_data(trial_data, rng=None):
    """Build previous-trial choice and target regressors for every trial.

    For trial ``t`` the values come from trial ``t - 1``; trial 0 has no
    predecessor and reuses its own values. When the predecessor's outcome is
    invalid (``outcome == -1``), the values are carried forward from the most
    recent earlier trial whose outcome was valid (``outcome >= 0``). Trials with
    no such earlier trial get a fair random draw instead.

    Parameters
    ----------
    trial_data : pandas.DataFrame
        Needs ``choice`` (0: awayRF, 1: toRF), ``target`` (0: awayRF, 1: toRF)
        and ``outcome`` columns. Rows are used positionally, in order.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the random draws. Defaults to the global ``numpy.random``
        state, which reproduces the original behaviour.

    Returns
    -------
    prev_choice : numpy.ndarray
        Previous valid choice per trial, coded 0: awayRF, 1: toRF.
    prev_target : numpy.ndarray
        Previous valid target side per trial, coded -1: awayRF, 1: toRF.
    """
    random_source = np.random if rng is None else rng

    # Positional arrays: independent of the frame's index labels.
    choice = trial_data["choice"].to_numpy()
    target = trial_data["target"].to_numpy()
    outcome = trial_data["outcome"].to_numpy()
    if choice.size == 0:
        return choice.copy(), target * 2 - 1

    # Shift by one trial; trial 0 reuses its own values (no predecessor).
    prev_choice = np.concatenate([choice[:1], choice[:-1]])
    prev_target = np.concatenate([target[:1], target[:-1]]) * 2 - 1
    prev_outcome = np.concatenate([outcome[:1], outcome[:-1]])

    prev_valid_idx = np.flatnonzero(prev_outcome >= 0)
    prev_invalid_idx = np.flatnonzero(prev_outcome == -1)

    # For each invalid entry, position in prev_valid_idx of the last valid entry
    # strictly before it; -1 when there is none (also covers "no valid entries").
    last_valid_pos = np.searchsorted(prev_valid_idx, prev_invalid_idx) - 1
    has_history = last_valid_pos >= 0

    # No earlier valid trial: sample randomly (choice then target, in trial order).
    for i in prev_invalid_idx[~has_history]:
        prev_choice[i] = random_source.binomial(1, 0.5)
        prev_target[i] = random_source.binomial(1, 0.5) * 2 - 1

    # Otherwise copy from the most recent valid entry (valid entries are never overwritten).
    source_idx = prev_valid_idx[last_valid_pos[has_history]]
    fill_idx = prev_invalid_idx[has_history]
    prev_choice[fill_idx] = prev_choice[source_idx]
    prev_target[fill_idx] = prev_target[source_idx]
    return prev_choice, prev_target




#### create design matrix (input, output, mask)

In [129]:
# GLM-HMM model dimensions
n_states, obs_dim, num_categories, input_dim = (
    2,  # number of discrete latent states
    1,  # observed dimensions: a single choice per trial (toRF/awayRF)
    2,  # number of output categories (0: awayRF, 1: toRF)
    4,  # inputs: signed coherence, bias (1), previous choice, previous target side
)


In [130]:
session_metadata = pd.read_csv(Path(compiled_dir, "sessions_metadata.csv"), index_col=None)

# Per-session GLM-HMM data, one entry per session in session_metadata order:
#   inputs_session_wise[s]      : (n_trials, input_dim) float regressors
#   choices_session_wise[s]     : (n_trials, 1) int choices; invalid (-1) entries replaced by a random 0/1
#   masks_session_wise[s]       : (n_trials, 1) bool, False where the recorded choice was -1
#   invalid_idx_session_wise[s] : (n_invalid, 1) row indices of those invalid trials
inputs_session_wise = []
choices_session_wise = []
invalid_idx_session_wise = []
masks_session_wise = []

for session_id in session_metadata['session_id']:
    # Missing values become -1, which is the "invalid" code checked below and in extract_previous_data.
    trial_data = pd.read_csv(Path(compiled_dir, session_id, f"{session_id}_trial.csv"), index_col=None).fillna(-1)
    # Keep ALL GP trials (task_type == 1): invalid ones are masked rather than dropped so the
    # previous-trial regressors follow the real trial sequence.
    GP_trial_data = trial_data[trial_data.task_type == 1].reset_index(drop=True)
    num_trials_per_sess = GP_trial_data.shape[0]

    inpts = np.ones((num_trials_per_sess, input_dim))  # column 1 stays at 1: bias term

    # Column 0: signed coherence (sign from target side), z-scored within the session.
    current_stimulus = GP_trial_data.coherence * (2 * GP_trial_data.target - 1)
    inpts[:, 0] = current_stimulus / 100
    inpts[:, 0] = preprocessing.scale(inpts[:, 0])

    # Columns 2-3: previous valid trial's choice and target side, both coded -1: awayRF, 1: toRF.
    prev_choice, prev_target = extract_previous_data(GP_trial_data)
    inpts[:, 2] = prev_choice * 2 - 1
    inpts[:, 3] = prev_target
    inputs_session_wise.append(inpts)

    choices = GP_trial_data.choice.values.reshape(-1, 1).astype('int')

    masks = choices >= 0
    masks_session_wise.append(masks)
    invalid_idx = np.where(choices == -1)[0].reshape(-1, 1)
    invalid_idx_session_wise.append(invalid_idx)

    # For training, replace -1 with a random sample from {0, 1}; these trials are flagged False in `masks`.
    # npr.choice(2, k) draws from {0, 1} (npr.choice(1, k) would only ever return 0).
    choices[choices == -1] = npr.choice(2, invalid_idx.shape[0])
    choices_session_wise.append(choices)

# Pooled data: all sessions stacked into a single "session" for the global fit.
inputs_aggregated = [np.vstack(inputs_session_wise)]
choices_aggregated = [np.vstack(choices_session_wise)]
masks_aggregated = [np.vstack(masks_session_wise)]


In [131]:
import pickle

# Persist the stacked inputs list to the processed-data directory.
inputs_pickle_path = Path(processed_dir).joinpath("inputs_aggregated.pickle")
with open(inputs_pickle_path, 'wb') as handle:
    pickle.dump(inputs_aggregated, handle, pickle.HIGHEST_PROTOCOL)

In [132]:
# # Reload the pickled inputs list under the same name the rest of the notebook uses
# import pickle
# with open(Path(processed_dir, "inputs_aggregated.pickle"), 'rb') as handle:
#     inputs_aggregated = pickle.load(handle)

In [None]:
# Fit the pooled data (all sessions stacked) with 20 random initializations per state count.
models_glm_hmm, fit_lls_glm_hmm = global_fit(choices_aggregated, inputs_aggregated, masks_aggregated, n_iters=1000, n_initializations=20)

# For each number of states, keep the best-likelihood initialization as the starting point
# for the session-wise fits.
# NOTE(review): assumes global_fit returns results indexed by state count 2..5 — confirm in glm_hmm_utils.
init_params = {
    'glm_weights': {},
    'transition_matrix': {},
}
# Distinct loop name so the module-level `n_states` constant is not overwritten.
for num_states in range(2, 6):
    best_idx = np.argmax(fit_lls_glm_hmm[num_states])
    best_model = models_glm_hmm[num_states][best_idx]
    init_params['glm_weights'][num_states] = best_model.observations.params
    # Stored under the key declared in init_params above (the previous misspelled key raised KeyError).
    # TODO(review): confirm this is the key session_wise_fit_cv reads.
    init_params['transition_matrix'][num_states] = best_model.transitions.params

# Session-wise fitting with 5-fold cross-validation.
models_session_state_fold, train_ll_session, test_ll_session = session_wise_fit_cv(choices_session_wise, inputs_session_wise, masks_session_wise,
                                                                                    len(session_metadata['session_id']), init_params, n_iters=1000)

fitting GLM globally.....


  0%|          | 0/1000 [00:00<?, ?it/s]

fitting GLM-HMM globally.....
fitting 2 states.....
Initialization   1


  0%|          | 0/1000 [00:00<?, ?it/s]