# **Fit GLM-HMM to all DMDM data with L2 regularization**
---
We next fit the GLM-HMM to all the animals in the dataset. This step is important because it searches thoroughly through the parameter space and finds candidate parameters needed to fit the GLM-HMM to individual animals later. To keep the model simple, we use only the outcomes `y` as the dependent variable here.

## **HPC setting**
Ashwood's original code is written as Python scripts. Here, we rewrite it as a Jupyter notebook to make it more user-friendly to run on an HPC cluster with `dask`. [This](https://github.com/pierreglaser/hpc-tutorial/tree/main) is a very useful resource for getting familiar with `dask`.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# allocate the computing resources
from dask_jobqueue import SLURMCluster
from distributed import Client
from joblib import Memory, Parallel, delayed, parallel_backend
from threadpoolctl import threadpool_limits
from tqdm import tqdm

# Describe the SLURM jobs that back each dask worker. No job is requested
# here; jobs are requested by `cluster.scale(n)` below.
# NOTE(review): dask-jobqueue documents this keyword as `n_workers` — confirm
# that `workers=` is accepted by the installed version.
cluster = SLURMCluster(
    workers=0,      # create the workers "lazily" (upon cluster.scale)
    memory='64g',   # amount of RAM per job
    processes=1,    # number of worker processes each job is cut into (see dask-jobqueue docs)
    cores=1,        # total number of cores per job (see dask-jobqueue docs)
    walltime="48:00:00",
    worker_extra_args=["--resources GPU=1"], # the only way to add GPUs
    local_directory='/nfs/nhome/live/skuroda/jobs', # set your path to save log
    log_directory='/nfs/nhome/live/skuroda/jobs' # set your path to save log
)

# joblib cache location; it is wiped by `memory.clear(...)` in the shutdown cells.
memory = Memory('/nfs/nhome/live/skuroda/joblib-cache') # set your path

# Number of workers requested from the scheduler; also the count that later
# cells wait for with `client.wait_for_workers(n)`.
n = 10
cluster.scale(n)
client = Client(cluster)
# URL of the dask dashboard for monitoring the jobs.
print(client.dashboard_link)

  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import parse_bytes


http://192.168.234.51:8787/status


## **Fit GLM-HMM to all animals**
---

In [3]:
# ------- load modules -------
import autograd.numpy as np
import numpy as onp
import autograd.numpy.random as npr
import sys

from glm_hmm_utils import fit_glm_hmm
sys.path.append('../') # a lazy trick to search parent dir
# https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
from data_io import get_file_dir, load_session_fold_lookup, load_data, load_glm_vectors
from data_labels import create_abort_mask, partition_data_by_session, partition_data_by_session_L2

from functools import partial
from collections import OrderedDict

In [4]:
# ------- setup variables -------
# Dataset folder name; the alternatives used so far are
# 'dataAllHumans' and 'dataAllMiceTraining'.
dname = 'dataAllHumans'
C = 3  # number of output types/categories
D = 1  # data (observations) dimension
K_vals = [1, 2, 3, 4]
ridge_lambda = [0, 1, 100]  # candidate penalties for ridge regression
N_initializations = 10  # 20
num_folds_training = 5
num_folds_tuning = 5

# Nested structure for behavioral outcomes, in insertion order.
nested_outcome = OrderedDict([("Baseline", [2]), ("Change", [0, 1])])

# One job per (lambda, K, training fold, tuning fold, initialization); the
# right-most index varies fastest, matching the order in which
# `fit_GLMHMM_y` unpacks `params`.
cluster_job_arr = [
    [lam, k_val, train_fold, tune_fold, init_idx]
    for lam in ridge_lambda
    for k_val in K_vals
    for train_fold in range(num_folds_training)
    for tune_fold in range(num_folds_tuning)
    for init_idx in range(N_initializations)
]

N_em_iters = 600  # number of EM iterations
global_fit = True
reguralization = True
paramter_tuning = True
transition_alpha = 1  # perform mle => set transition_alpha to 1
prior_sigma = 100

print('Total number of jobs: {}'.format(len(cluster_job_arr)))

Total number of jobs: 3000


In [5]:
# ------- setup path and load data -------
data_dir = get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster'

# Results directory. It is not created here: the GLM-HMM jobs read their
# initial weights from `results_dir / 'GLM' / ...`, so the GLM step must have
# populated it already. Building a path never raises, so test for the
# directory explicitly instead of wrapping the path arithmetic in try/except.
results_dir = get_file_dir().parents[1] / "results" / "dmdm_global_fit" / dname
if not results_dir.is_dir():
    raise FileNotFoundError('Run GLM first to initialize parameters')

# Read in data and train/test split
animal_file = data_dir / 'all_animals_concat.npz'
session_fold_lookup_table = load_session_fold_lookup(
    data_dir / 'all_animals_concat_session_fold_lookup.npz')

inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

In [6]:
def fit_GLMHMM_y(inpt_y, y, session,
                 session_fold_lookup_table,
                 global_fit,
                 paramter_tuning,
                 reguralization,
                 transition_alpha,
                 prior_sigma,
                 results_dir,
                 params):
    """Prepare inputs and output paths for one fit, then launch it.

    Meant to be wrapped with ``functools.partial`` so that only ``params``
    varies from job to job. Reads ``D``, ``C`` and ``N_em_iters`` from module
    scope and forwards them to ``launch_glm_hmm_job``.

    Args:
        inpt_y: input matrix; a column of ones is appended as the bias
            covariate before fitting.
        y: outcome array; cast to int. Entries equal to 3 are treated as
            violations and excluded through the mask.
        session: session labels, forwarded unchanged (presumably one entry
            per row of ``inpt_y`` -- confirm against ``load_data``).
        session_fold_lookup_table: forwarded to ``launch_glm_hmm_job``, which
            uses it to select the sessions to fit.
        global_fit: must be truthy; this notebook only runs global fitting.
        paramter_tuning: when truthy, results are written under an extra
            ``tuningfold_<fold_tuning>`` directory level.
        reguralization: forwarded to ``launch_glm_hmm_job``.
        transition_alpha: forwarded to ``launch_glm_hmm_job``.
        prior_sigma: forwarded to ``launch_glm_hmm_job``.
        results_dir: path object (must support ``/`` and ``mkdir``) that
            holds the ``GLM`` initialization files and receives the output.
        params: sequence ``[l, K, fold_training, fold_tuning, iter]``.

    Raises:
        NotImplementedError: if ``global_fit`` is falsy.
    """
    # `init_idx` instead of `iter`, to avoid shadowing the builtin.
    l2_penalty, K, fold_training, fold_tuning, init_idx = params

    if not global_fit:
        raise NotImplementedError('This notebook only runs global fitting')

    # Append a column of ones to the inputs to represent the bias covariate.
    inpt_y = np.hstack((inpt_y, np.ones((len(inpt_y), 1))))
    y = y.astype('int')
    # Identify violations (y == 3) for exclusion.
    abort_idx = np.where(y == 3)[0]
    _, mask = create_abort_mask(abort_idx, inpt_y.shape[0])

    # GLM weights used to initialize the GLM-HMM.
    init_param_file = (results_dir / 'GLM' / ('fold_' + str(fold_training))
                       / 'GLM_y_variables_of_interest_iter_0.npz')

    # Save directory for this initialization/fold combination.
    fold_directory = (results_dir / ("GLM_HMM_y_K_" + str(K))
                      / ("fold_" + str(fold_training)))
    if paramter_tuning:
        fold_directory = fold_directory / ("tuningfold_" + str(fold_tuning))
    saving_directory = fold_directory / ('iter_' + str(init_idx))
    # Create the directory in BOTH modes; previously this only happened when
    # `paramter_tuning` was falsy, leaving the tuning directories missing.
    saving_directory.mkdir(parents=True, exist_ok=True)

    launch_glm_hmm_job(inpt_y,
                       y,
                       session,
                       mask,
                       session_fold_lookup_table,
                       K,
                       D,
                       C,
                       N_em_iters,
                       transition_alpha,
                       prior_sigma,
                       l2_penalty,
                       fold_training,
                       fold_tuning,
                       paramter_tuning,
                       init_idx,
                       global_fit,
                       reguralization,
                       init_param_file,
                       saving_directory)


def launch_glm_hmm_job(inpt, y, session, mask, session_fold_lookup_table, K, D,
                       C, N_em_iters, transition_alpha, prior_sigma, l,
                       fold_training, fold_tuning, paramter_tuning, iter, global_fit, reguralization,
                       init_param_file, save_directory):
    """Select the sessions for this fold, partition them and fit the GLM-HMM.

    Args:
        inpt: 2-D input matrix (indexed as ``inpt[rows, :]``).
        y: 2-D outcome array (indexed as ``y[rows, :]``); the value 3 marks a
            violation.
        session: iterable of session labels aligned with the rows of
            ``inpt``; each label is compared as ``str(label)`` against
            column 0 of ``session_fold_lookup_table``.
        mask: 2-D violation mask aligned with ``inpt``.
        session_fold_lookup_table: 2-D table; column 0 holds session labels,
            column 1 the training fold and column 2 the tuning fold (column 2
            is only read when ``paramter_tuning`` is truthy).
        K, D, C, N_em_iters, transition_alpha, prior_sigma, reguralization:
            forwarded unchanged to ``fit_glm_hmm``.
        l: forwarded to ``partition_data_by_session_L2`` and ``fit_glm_hmm``.
        fold_training: sessions assigned to this training fold are left out.
        fold_tuning: when ``paramter_tuning`` is truthy, sessions assigned to
            this tuning fold are left out as well.
        paramter_tuning: selects which folds are held out (see above).
        iter: initialization index; seeds the random generator and is
            appended to the output file name.
        global_fit: must be truthy; this notebook only runs global fitting.
        init_param_file: file passed to ``load_glm_vectors`` to obtain the
            initial parameters.
        save_directory: path object (must support ``/``) for the output.

    Raises:
        NotImplementedError: if ``global_fit`` is falsy.
    """
    sys.stdout.flush()
    if not global_fit:
        raise NotImplementedError('This notebook only runs global fitting')

    # Boolean row selection in both modes, so `sessions_to_keep` is always a
    # 1-D array of session labels.
    keep_rows = session_fold_lookup_table[:, 1] != fold_training
    if paramter_tuning:
        keep_rows = np.logical_and(
            keep_rows, session_fold_lookup_table[:, 2] != fold_tuning)
    sessions_to_keep = session_fold_lookup_table[keep_rows, 0]

    idx_this_fold = [str(sess) in sessions_to_keep for sess in session]
    # Boolean-list indexing returns copies, so the caller's arrays are not
    # modified by the assignment below.
    this_inpt = inpt[idx_this_fold, :]
    this_y = y[idx_this_fold, :]
    this_session = session[idx_this_fold]
    this_mask = mask[idx_this_fold, :]

    # Only do this so that errors are avoided - these y values will not
    # actually be used for anything (due to violation mask).
    # Use only the ROW indices from np.where: passing the whole
    # (rows, cols) tuple as a row index also treated the column indices
    # (all zeros for a single-column y) as rows and so overwrote row 0.
    violation_rows = np.where(this_y == 3)[0]
    this_y[violation_rows, :] = 2

    inputs, datas, masks = partition_data_by_session_L2(
        this_inpt, this_y, this_mask, this_session, l)
    # Read in the GLM fit used for initialization.
    _, params_for_initialization = load_glm_vectors(init_param_file)
    M = this_inpt.shape[1]

    # Seed with the initialization index so each initialization is
    # reproducible.
    npr.seed(iter)
    fit_glm_hmm(datas,
                inputs,
                masks,
                K,
                D,
                M,
                C,
                N_em_iters,
                transition_alpha,
                prior_sigma,
                global_fit,
                reguralization,
                l,
                params_for_initialization,
                save_title=save_directory / ('GLM_HMM_y_raw_parameters_itr_' + str(iter))
                )
    
# Freeze every argument that is identical across jobs, so that each job only
# has to supply its own `params` list.
_shared_job_args = (
    inpt_y, y, session, session_fold_lookup_table,
    global_fit, paramter_tuning, reguralization,
    transition_alpha, prior_sigma, results_dir,
)
fit_GLMHMM_eachparam = partial(fit_GLMHMM_y, *_shared_job_args)
# joblib-cached variant of the same callable.
fit_GLMHMM_eachparam_cached = memory.cache(fit_GLMHMM_eachparam)

In [7]:
# Upload the local helper modules to each worker; per the original note, the
# workers cannot pick them up through sys.path.append / sys.path.insert.
with Client(cluster) as client:
    client.wait_for_workers(n)
    for module_file in ('data_io.py', 'data_labels.py'):
        client.upload_file(str(get_file_dir() / module_file))

In [8]:
%%time

# Cap BLAS at one thread and dispatch one task per entry of
# `cluster_job_arr` through the dask joblib backend.
with threadpool_limits(limits=1, user_api='blas'), \
        parallel_backend('dask', wait_for_workers_timeout=120):
    Parallel(verbose=100)(
        delayed(fit_GLMHMM_eachparam)(job_params)
        for job_params in cluster_job_arr
    )

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 10 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:    5.2s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    5.9s
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:    5.9s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:    6.1s
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:    6.2s
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:    8.5s
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:    8.6s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:    9.0s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:    9.4s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:    9.5s
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:    9.8s
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:    9.8s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:   11.1s
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:   11.4s
[Parallel(n_jobs=-1)]: Done  15 tasks      |

In [17]:
# Once finished, shut down the cluster and the client.
# NOTE(review): `client` may have been rebound (and already closed) by an
# earlier `with Client(cluster) as client:` cell -- confirm which client
# object this closes.
memory.clear(warn=False)  # wipe the joblib cache used for this run
cluster.close()
client.close()

  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # 

## **L2 hyperparameter tuning**
---

We refit the GLM-HMM, this time using the best L2 hyperparameters we just found.

In [8]:
from kfold_cv import KFoldCV
from data_io import load_data, load_session_fold_lookup, load_glmhmm_data, load_best_map_params, \
                     load_cv_arr, get_file_name_for_best_glmhmm_fold
from data_postprocessing_utils import calculate_state_permutation
import json

# Model identifier: passed to KFoldCV and interpolated into the names of the
# result files (e.g. "best_l2_params_model_{}.npz").
model = 'GLM_HMM_y'

In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two mi

In [39]:
# One row per (training fold, K). The three string-typed columns are later
# consumed in the order [l, K, fold_training] when the refit jobs are built.
n_k = len(K_vals)
out_params = np.zeros((n_k * num_folds_training, 3), dtype='<U32')

for fold_training in range(num_folds_training):
    # Get the best lambda for this training fold.
    KFCV = KFoldCV(model, num_folds_tuning, K_vals, global_fit,
                   [transition_alpha], [prior_sigma], ridge_lambda,
                   results_dir)
    out = KFCV.save_best_L2params(inpt_y, inpt_rt, y, session, rt, stim_onset,
                                  session_fold_lookup_table, fold_training, C,
                                  outcome_dict=nested_outcome, save_output=True)

    # Swap the two returned columns, then append the fold index as a third.
    out_arr = np.array(out)[:, [1, 0]]
    fold_column = np.repeat(fold_training, n_k)[:, None]
    out_arr_fold_animal = np.hstack([out_arr, fold_column])

    # Rows belonging to this fold.
    first_row = fold_training * n_k
    out_params[first_row:first_row + n_k, :] = out_arr_fold_animal

np.savez(data_dir / "best_l2_params_model_{}.npz".format(model), out_params)
print("Saved all the L2 params!")
print(out_params)

Retrieving best iter results for model = GLM_HMM_y; fold = 0; num_folds = 5


100%|██████████| 5/5 [00:29<00:00,  6.00s/it]


Calculating best L2 parameters...
[[ 0.14472502  0.13244128  0.15704311  0.15795999  0.15608732]
 [ 0.02798024  0.00978113  0.03673892  0.03818171  0.03482281]
 [-1.18267452 -1.26480882 -1.21006586 -1.19928649 -1.22028852]]
[ 0.15582677  0.03553162 -1.21089071]
(0,)
[[0.13443693 0.12210732 0.13533878 0.14884254 0.1469561 ]
 [0.16609275 0.15499094 0.16360684 0.17175139 0.1739927 ]
 [0.18920514 0.18271039 0.19415814 0.19309855 0.1917307 ]]
[0.15270758 0.18571332 0.20978145]
(2,)
[[0.13903872 0.12033157 0.14359055 0.14647097 0.15997489]
 [0.17377683 0.17277785 0.18040581 0.18988776 0.21171005]
 [0.19775163 0.18574108 0.2053285  0.20992038 0.22729538]]
[0.17121438 0.21533216 0.23487905]
(2,)
[[0.09705548 0.1159574  0.10991336 0.15069447 0.1425253 ]
 [0.1442807  0.16614194 0.15108815 0.18273298 0.18854657]
 [0.1584251  0.18089205 0.17416452 0.19141323 0.1885608 ]]
[0.18143287 0.2310775  0.24316914]
(2,)
Retrieving best iter results for model = GLM_HMM_y; fold = 1; num_folds = 5


100%|██████████| 5/5 [00:26<00:00,  5.36s/it]


Calculating best L2 parameters...
[[ 0.13996656  0.14099673  0.1425175   0.17536849  0.18795254]
 [ 0.09576451  0.10247285  0.09585612  0.13100838  0.14780701]
 [-0.46425504 -0.40826411 -0.51669459 -0.43005396 -0.38938468]]
[ 0.16257166  0.11978131 -0.43613608]
(0,)
[[0.12728986 0.12562639 0.12747354 0.15151613 0.1748036 ]
 [0.17016692 0.16073631 0.16787767 0.1601877  0.19642719]
 [0.17025718 0.18714418 0.19879625 0.19695257 0.22870916]]
[0.16203763 0.19587987 0.22072294]
(2,)
[[0.11254012 0.12275275 0.11896441 0.16458506 0.19695318]
 [0.15763116 0.13602472 0.12669488 0.21146738 0.23083607]
 [0.17255139 0.18289513 0.18503531 0.22788078 0.26470737]]
[0.17958871 0.2096752  0.24261355]
(2,)
[[0.11072063 0.11106552 0.10882694 0.15811463 0.1783986 ]
 [0.12717607 0.15530648 0.15414965 0.17619355 0.20260198]
 [0.15694573 0.17376843 0.18022251 0.1964022  0.23113597]]
[0.19042012 0.22529197 0.24993738]
(2,)
Retrieving best iter results for model = GLM_HMM_y; fold = 2; num_folds = 5


100%|██████████| 5/5 [00:25<00:00,  5.13s/it]


Calculating best L2 parameters...
[[ 0.12334864  0.14519631  0.08431404  0.16116953  0.15158359]
 [ 0.07794528  0.10626226  0.03569439  0.1184042   0.11105168]
 [-0.49655303 -0.41836685 -0.61030483 -0.43218372 -0.43208134]]
[ 0.1433436   0.10045982 -0.46165745]
(0,)
[[0.11061194 0.13375999 0.06398906 0.13796106 0.14212094]
 [0.15231265 0.14923845 0.10734811 0.15889898 0.20206311]
 [0.15232796 0.17695194 0.13723179 0.18384447 0.20219726]]
[0.13979407 0.18200771 0.19790665]
(2,)
[[0.07576808 0.12241588 0.05891756 0.14933893 0.14865568]
 [0.11438103 0.1359023  0.10180049 0.18897234 0.18778207]
 [0.13758527 0.18291408 0.13403715 0.21530028 0.21251568]]
[0.162441   0.19678384 0.22708229]
(2,)
[[0.08854572 0.11739252 0.06341146 0.13942678 0.13029025]
 [0.12006053 0.14785179 0.10747532 0.17467737 0.14455866]
 [0.14515869 0.17749065 0.13946709 0.19579858 0.17282874]]
[0.1754626  0.21022317 0.23721089]
(2,)
Retrieving best iter results for model = GLM_HMM_y; fold = 3; num_folds = 5


100%|██████████| 5/5 [00:27<00:00,  5.45s/it]


Calculating best L2 parameters...
[[ 0.11273522  0.13141968  0.12437471  0.16495903  0.16966463]
 [ 0.06876445  0.09089279  0.07891222  0.12136065  0.13375457]
 [-0.49088348 -0.45784331 -0.53042162 -0.43907493 -0.34805639]]
[ 0.14607035  0.10404458 -0.44890042]
(0,)
[[0.1143993  0.06033946 0.11553721 0.15348658 0.1570248 ]
 [0.14037332 0.09557395 0.15820373 0.19200783 0.18079914]
 [0.16873878 0.12396758 0.18721501 0.19202284 0.20972991]]
[0.14505426 0.18181264 0.20455104]
(2,)
[[0.09679572 0.11380159 0.09703218 0.15623969 0.16759793]
 [0.15291844 0.15832696 0.14518616 0.20410716 0.20780651]
 [0.16460286 0.17536893 0.17010773 0.21695858 0.23049152]]
[0.16191548 0.21123655 0.2289004 ]
(2,)
[[ 0.09299444  0.09932548 -0.02775682  0.1522519   0.1512351 ]
 [ 0.10758332  0.14329085  0.04538316  0.19571292  0.18645725]
 [ 0.14314785  0.16279256  0.04888911  0.19634567  0.20393189]]
[0.16955548 0.21518497 0.23116469]
(2,)
Retrieving best iter results for model = GLM_HMM_y; fold = 4; num_folds =

100%|██████████| 5/5 [00:26<00:00,  5.27s/it]

Calculating best L2 parameters...
[[ 0.14232234  0.14369854  0.15788456  0.16777857  0.17841943]
 [ 0.09600324  0.1010355   0.10926475  0.12189856  0.13671052]
 [-0.48943854 -0.46914952 -0.5273048  -0.45550547 -0.41395422]]
[ 0.16484523  0.11980784 -0.46346317]
(0,)
[[0.13013174 0.13253607 0.15395864 0.14402975 0.17654703]
 [0.16469522 0.17213542 0.19736865 0.16283233 0.21075092]
 [0.18784128 0.19556271 0.22419321 0.18847736 0.2369177 ]]
[0.16475998 0.2006866  0.22568881]
(2,)
[[0.11475202 0.11708422 0.12113346 0.16886509 0.176797  ]
 [0.14804638 0.15486151 0.16856722 0.21907892 0.22765727]
 [0.17449818 0.17809943 0.19323656 0.23359669 0.24416193]]
[0.18166647 0.22672951 0.24771146]
(2,)
[[0.12368911 0.11385109 0.14681786 0.15949131 0.16280666]
 [0.14745824 0.14189196 0.17689528 0.18101746 0.17990057]
 [0.17843745 0.17208708 0.21471464 0.20894779 0.21239375]]
[0.18827718 0.2152455  0.24692028]
(2,)
Saved all the L2 params!





In [9]:
# Refit without the inner tuning folds: only the training fold is held out.
paramter_tuning = False
# Placeholder for the `fold_tuning` slot of each job's parameter list; the
# tuning fold is not used for session selection or output paths when
# `paramter_tuning` is False.
fold_dummy = 100

raw_params = load_best_map_params(data_dir / "best_l2_params_model_{}.npz".format(model))
# it is okay to use load_best_map_params here as there is no difference
# Cast to integers so the entries can be used as lambda / K / fold values.
params = raw_params.astype('int32')

In [10]:
# One job per (row of best parameters, initialization). Each entry follows
# the layout [l, K, fold_training, fold_tuning, iter] that `fit_GLMHMM_y`
# unpacks; `fold_dummy` fills the tuning-fold slot.
cluster_job_arr_2 = [
    [best_row[0], best_row[1], best_row[2], fold_dummy, init_idx]
    for best_row in params
    for init_idx in range(N_initializations)
]

print('Total number of jobs: {}'.format(len(cluster_job_arr_2)))

Total number of jobs: 200


In [11]:
# Same binding as for the tuning run, but created after `paramter_tuning` was
# reassigned, so it captures the current value of that flag.
_shared_job_args_2 = (
    inpt_y, y, session, session_fold_lookup_table,
    global_fit, paramter_tuning, reguralization,
    transition_alpha, prior_sigma, results_dir,
)
fit_GLMHMM_eachparam_2 = partial(fit_GLMHMM_y, *_shared_job_args_2)
# joblib-cached variant of the same callable.
fit_GLMHMM_eachparam_2_cached = memory.cache(fit_GLMHMM_eachparam_2)

In [12]:
# Upload the local helper modules to each worker; per the original note, the
# workers cannot pick them up through sys.path.append / sys.path.insert.
with Client(cluster) as client:
    client.wait_for_workers(n)
    for module_file in ('data_io.py', 'data_labels.py'):
        client.upload_file(str(get_file_dir() / module_file))

In [13]:
%%time

# Cap BLAS at one thread and dispatch one task per entry of
# `cluster_job_arr_2` through the dask joblib backend.
with threadpool_limits(limits=1, user_api='blas'), \
        parallel_backend('dask', wait_for_workers_timeout=120):
    Parallel(verbose=100)(
        delayed(fit_GLMHMM_eachparam_2)(job_params)
        for job_params in cluster_job_arr_2
    )

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 10 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   33.3s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:  1.5min
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:  1.6min
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:  1.6min
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:  1.6min
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:  1.8min
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:  1.9min
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:  1.9min
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:  2.6min
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:  2.7min
[Parallel(n_jobs=-1)]: Done  15 tasks      |

In [15]:
# Once finished, shut down the cluster and the client.
# NOTE(review): `client` may have been rebound (and already closed) by an
# earlier `with Client(cluster) as client:` cell -- confirm which client
# object this closes.
memory.clear(warn=False)  # wipe the joblib cache used for this run
cluster.close()
client.close()

## **Posthoc data processing**
---

We then check the output files to find the initialization with the best log-likelihood for each `K` (because the EM algorithm does not guarantee that we will find the global maximum). At the same time, we look at the different folds from k-fold cross-validation and calculate the normalized log-likelihood. The parameters are saved for the subsequent individual GLM-HMM fitting, and we also plot them for visualization purposes. Below is the workflow from Ashwood et al., 2022, Nat. Neurosci.

![image.png](attachment:image.png)

In [6]:
from kfold_cv import KFoldCV
from data_io import load_data, load_session_fold_lookup, load_glmhmm_data, load_cv_arr, get_file_name_for_best_glmhmm_fold, get_file_name_for_best_glmhmm_fold_l2, load_best_map_params
from data_postprocessing_utils import calculate_state_permutation
sys.path.append('../../../3_make_figures/dmdm/')
from plot_model_perform import plot_states, plot_model_comparison_l2
import json

# Model identifier: passed to KFoldCV and interpolated into the names of the
# result files.
model = 'GLM_HMM_y'
# Labels handed to `plot_states` as `labels_for_plot`.
labels_for_plot_y = ['CSize', 'COnset', 'Outcome +1', 'Outcome +2', 'Outcome +3', 'Outcome +4', 'Outcome +5', 'bias']
# Output folder for the best global parameters and their plots; created if
# it does not exist yet.
saving_directory = data_dir / "best_global_params_L2"
saving_directory.mkdir(parents=True, exist_ok=True)

In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two mi

In [23]:
# Best L2 parameters saved by the hyperparameter-tuning step.
raw_params = load_best_map_params(data_dir / "best_l2_params_model_{}.npz".format(model))

# Reload the concatenated data set and its session/fold assignment.
inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(data_dir / 'all_animals_concat.npz')
session_fold_lookup_table = load_session_fold_lookup(
    data_dir / 'all_animals_concat_session_fold_lookup.npz')

# This instance is built with the number of TRAINING folds; the tuning step
# built its KFoldCV with `num_folds_tuning` instead.
KFCV = KFoldCV(model,num_folds_training, K_vals, global_fit,
               [transition_alpha], [prior_sigma], ridge_lambda,
               results_dir)

# Pick the best initialization using the saved L2 parameters
# (the run prints "Best init saved!" when done).
KFCV.save_best_iter(inpt_y, inpt_rt, y, session, rt, stim_onset,
                     session_fold_lookup_table, C, outcome_dict=nested_outcome,
                     map_params=raw_params)

Retrieving best iter results for model = GLM_HMM_y; num_folds = 5


100%|██████████| 5/5 [01:22<00:00, 16.49s/it]

Best init saved!





In [7]:
# Reload the best L2 parameters so the export cell below can run on its own.
raw_params = load_best_map_params(data_dir / "best_l2_params_model_{}.npz".format(model))

In [8]:
# Cross-validation arrays written by the best-initialization step.
cvbt_folds_model = load_cv_arr(results_dir / "cvbt_folds_model_{}.npz".format(model))
cvbt_train_folds_model = load_cv_arr(results_dir / "cvbt_train_folds_model_{}.npz".format(model))

# The best-initialization lookup does not depend on K, so read it once
# instead of re-opening the file for every K.
with open(results_dir / "best_init_cvbt_dict_{}.json".format(model), 'r') as f:
    best_init_cvbt_dict = json.load(f)

for model_idx, K in enumerate(K_vals):
    print("K = " + str(K))

    # Get the file name corresponding to the best initialization for
    # the given K value.
    raw_file, _ = get_file_name_for_best_glmhmm_fold_l2(
        cvbt_folds_model, model_idx, K, 0, transition_alpha, 0, prior_sigma,
        results_dir, best_init_cvbt_dict, model,
        'GLM_HMM_y_raw_parameters_itr_', map_params=raw_params)
    hmm_params, lls, _, _, _ = load_glmhmm_data(raw_file)

    # Calculate permutation
    init_state_dist, log_transition_matrix, weight_vectors, permutation = \
        calculate_state_permutation(hmm_params, K)

    # With K == 1 only the weights are saved; otherwise the initial-state
    # distribution and the log transition matrix are saved with them.
    if K == 1:
        best_params = weight_vectors
    else:
        best_params = [[init_state_dist], [log_transition_matrix], weight_vectors]

    plot_states(weight_vectors,
                log_transition_matrix,
                saving_directory,
                K,
                save_title='best_params_' + model + '_K_',
                labels_for_plot=labels_for_plot_y)

    # NOTE(review): for K > 1 `best_params` is a ragged nested list; the
    # recorded run emits a `val = np.asanyarray(val)` warning at this call --
    # confirm the behavior on newer NumPy versions.
    np.savez(saving_directory / ('best_params_' + model + '_K_' + str(K) + '.npz'),
             best_params)

plot_model_comparison_l2(cvbt_folds_model,
                         cvbt_train_folds_model,
                         global_fit,
                         K_vals,
                         saving_directory)

K = 1
K = 2


  val = np.asanyarray(val)


K = 3


  val = np.asanyarray(val)


K = 4


  val = np.asanyarray(val)

The `ci` parameter is deprecated. Use `errorbar=('ci', 68)` for the same effect.

  lw=4)

The `ci` parameter is deprecated. Use `errorbar=('ci', 68)` for the same effect.

  lw=4)
