# **Fit GLM-HMM to individual DMDM data**
---
After fitting the GLM to all animals, we can finally fit the GLM-HMM to individual animals in the dataset. Just like the global GLM-HMM, we only use outcomes `y` as the dependent variable.

## **HPC setting**
Ashwood's original code is written as Python scripts. Here, we rewrite it as a Jupyter notebook to make it more user-friendly to run on HPC with `dask`. [This tutorial](https://github.com/pierreglaser/hpc-tutorial/tree/main) is a very useful resource for getting familiar with `dask`.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# allocate the computing resources
from dask_jobqueue import SLURMCluster
from distributed import Client
from joblib import Memory, Parallel, delayed, parallel_backend
from threadpoolctl import threadpool_limits
from tqdm import tqdm

# Describe the SLURM job that backs each dask worker, then request `n` of them
# and attach a client to the resulting cluster.
# NOTE(review): the per-argument notes below follow the dask-jobqueue
# documentation as the reviewer understands it; the original comments on
# `processes` and `cores` looked swapped -- confirm against the installed version.
cluster = SLURMCluster(
    workers=0,      # start with no workers; they are requested by cluster.scale() below (NOTE(review): dask-jobqueue documents this option as `n_workers` -- confirm `workers` is honoured)
    memory='64g',   # amount of RAM requested per job
    processes=1,    # presumably the number of worker processes each job is split into
    cores=1,        # presumably the total number of cores per job, shared by its processes
    walltime="48:00:00",
    worker_extra_args=["--resources GPU=1"], # tags each worker with the resource GPU=1 (NOTE(review): presumably a dask-level label only; confirm a GPU is actually requested from SLURM)
    local_directory='/nfs/nhome/live/skuroda/jobs', # set your path to save log
    log_directory='/nfs/nhome/live/skuroda/jobs' # set your path to save log
)   

# On-disk joblib cache used by memory.cache() further down.
memory = Memory('/nfs/nhome/live/skuroda/joblib-cache') # set your path

# Number of workers to request; reused later by client.wait_for_workers(n).
n = 5
cluster.scale(n)
client = Client(cluster)
# Print the dashboard URL so progress can be followed in a browser.
print(client.dashboard_link)

  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import parse_bytes


http://192.168.234.51:8787/status


## **Fit GLM-HMM to individual animals**
---

In [1]:
# ------- load modules -------
import autograd.numpy as np
import numpy as onp
import autograd.numpy.random as npr
import sys
import itertools

from glm_hmm_utils import fit_glm_hmm
sys.path.append('../') # a lazy trick to search parent dir
# https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
from data_io import get_file_dir, load_session_fold_lookup, load_data, load_glm_vectors, load_animal_list, load_global_params
from data_labels import create_abort_mask, partition_data_by_session

from functools import partial
from collections import OrderedDict

In [2]:
# ------- setup variables -------
# Dataset name and model dimensions.
dname = 'dataAllMice'
C = 3  # number of output types/categories
D = 1  # data (observations) dimension

# Grid that is swept: every combination below becomes one fitting job.
prior_sigma = [1, 2]        # candidate sigma values for the prior
transition_alpha = [1, 2]   # candidate alpha values for the prior
K_vals = [2, 3, 4, 5]       # candidate K values
N_initializations = 2 #20
num_folds = 5

N_em_iters = 600  # number of EM iterations
global_fit = False

# Nested structure for behavioral outcomes (insertion order is kept).
nested_outcome = OrderedDict([("Baseline", [2]), ("Change", [0, 1])])

# One entry per job, laid out as [sigma, alpha, K, fold, initialization].
# product() varies its last argument fastest, so alpha changes quickest and
# K slowest -- the same ordering a nested K/fold/init/sigma/alpha loop gives.
cluster_job_arr = [
    [sigma_val, alpha_val, n_states, fold_idx, init_idx]
    for n_states, fold_idx, init_idx, sigma_val, alpha_val in itertools.product(
        K_vals,
        range(num_folds),
        range(N_initializations),
        prior_sigma,
        transition_alpha,
    )
]

To select the hyperparameters σ and α governing the prior, we performed a grid search for σ ∈ {1, 2} and α ∈ {1, 2} and selected the hyperparameters that resulted in the best performance on a held-out validation set.

In [3]:
# ------- setup path and load data -------
# Per-animal inputs: <root>/data/dmdm/<dname>/data_for_cluster/data_by_animal,
# where <root> is get_file_dir().parents[1].
data_2_dir =  get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster' / "data_by_animal"
# Root directory for the individual-fit results (per-job sub-directories are
# created later, when each job runs):
try:
    results_2_dir = get_file_dir().parents[1] / "results" / "dmdm_individual_fit" / dname
except Exception as err:
    # `Exception` instead of a bare `except:` so that KeyboardInterrupt and
    # SystemExit are not converted into FileNotFoundError; the original error
    # is chained so the real cause stays visible.
    # NOTE(review): this try-body only builds a path object, which presumably
    # never touches the filesystem, so this branch is unlikely to run. If the
    # intent is to check that the GLM step has produced its outputs, an
    # explicit existence check on those files is needed -- confirm.
    raise FileNotFoundError('Run GLM first to initialize parameters') from err

#  read in data and train/test split
animal_list = load_animal_list(data_2_dir / 'animal_list.npz')
# Pair every job specification with every animal: one (job, animal) tuple per fit.
cluster_job_arr_with_animal = list(itertools.product(cluster_job_arr, animal_list))

print('Total number of jobs: {}'.format(len(cluster_job_arr_with_animal)))
print('Animals for individual fitting: {}'.format(animal_list))

Total number of jobs: 960
Animals for individual fitting: ['x1108393' 'x1116760' 'x1116765' 'x1117910' 'x1119408' 'x1119541']


In [6]:
def fit_GLMHMM_y_separetely(data_2_dir, results_2_dir, cluster_job_arr_with_animal):
    """Run one GLM-HMM fitting job for a single animal.

    Parameters
    ----------
    data_2_dir
        Directory (supporting the ``/`` operator) that holds
        ``<animal>_processed.npz`` and ``<animal>_session_fold_lookup.npz``.
    results_2_dir
        Root results directory; this job writes below ``results_2_dir / animal``.
    cluster_job_arr_with_animal
        Indexable pair: element 0 is the job specification
        ``[sigma, alpha, ...]`` whose remaining entries are forwarded
        unchanged to ``fit_GLMHMM_y`` as ``params``; element 1 is the animal
        identifier.

    Notes
    -----
    ``global_fit`` is read from the enclosing (notebook) scope.
    """
    animal = cluster_job_arr_with_animal[1]
    print(animal)

    job_spec = cluster_job_arr_with_animal[0]
    sigma = job_spec[0]
    alpha = job_spec[1]
    remaining_spec = job_spec[2:]

    # Load the fold assignment first, then the trial data, for this animal.
    fold_lookup = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    # Only the outcome inputs, outcomes and session labels are needed here.
    inpt_y, _inpt_rt, y, session, _rt, _stim_onset = load_data(
        data_2_dir / (animal + '_processed.npz'))

    fit_GLMHMM_y(
        inpt_y,
        y,
        session,
        fold_lookup,
        global_fit,
        alpha,
        sigma,
        results_2_dir / animal,
        remaining_spec,
    )

def fit_GLMHMM_y(inpt_y, y, session,
                 session_fold_lookup_table,
                 global_fit,
                 transition_alpha,
                 prior_sigma,
                 results_dir,
                 params):
    """Prepare inputs and output directory for one fit, then launch it.

    Parameters
    ----------
    inpt_y
        2-D input array; a column of ones (bias covariate) is appended here.
    y
        Outcome array; cast to ``int``. Entries equal to 3 are treated as
        violations and excluded through the mask.
    session
        Per-trial session labels, forwarded to ``launch_glm_hmm_job``.
    session_fold_lookup_table
        Session-to-fold table, forwarded to ``launch_glm_hmm_job``.
    global_fit
        Must not equal ``True``; only individual fitting is supported.
    transition_alpha, prior_sigma
        Prior hyperparameters, forwarded to ``launch_glm_hmm_job``.
    results_dir
        Directory under which
        ``GLM_HMM_y_K_<K>/fold_<fold>/iter_<iter>`` is created.
    params
        Three-element sequence ``[K, fold, iter]``.

    Raises
    ------
    NotImplementedError
        If ``global_fit == True``.

    Notes
    -----
    ``data_2_dir``, ``D``, ``C`` and ``N_em_iters`` are read from the
    enclosing (notebook) scope.
    """
    n_states, fold, init_idx = params

    # Bias covariate: one extra column of ones on the right of the inputs.
    bias_column = np.ones((len(inpt_y), 1))
    inpt_y = np.hstack((inpt_y, bias_column))
    y = y.astype('int')

    # Rows with outcome 3 are violations; build the mask that excludes them.
    abort_idx = np.where(y == 3)[0]
    _, mask = create_abort_mask(abort_idx, inpt_y.shape[0])

    if global_fit == True:
        raise NotImplementedError('This notebook only runs individual fitting')
    # Parameters used to initialize this fit, stored next to the per-animal data.
    init_param_file = (
        data_2_dir.parents[0]
        / 'best_global_params'
        / ('best_params_GLM_HMM_y_K_' + str(n_states) + '.npz')
    )

    # One output directory per K / fold / initialization combination.
    saving_directory = (
        results_dir
        / ("GLM_HMM_y_K_" + str(n_states))
        / ("fold_" + str(fold))
        / ('iter_' + str(init_idx))
    )
    saving_directory.mkdir(parents=True, exist_ok=True)

    launch_glm_hmm_job(inpt_y, y, session, mask, session_fold_lookup_table,
                       n_states, D, C, N_em_iters, transition_alpha,
                       prior_sigma, fold, init_idx, global_fit,
                       init_param_file, saving_directory)


def launch_glm_hmm_job(inpt, y, session, mask, session_fold_lookup_table, K, D,
                       C, N_em_iters, transition_alpha, prior_sigma, fold,
                       iter, global_fit, init_param_file, save_directory):
    """Fit a GLM-HMM on every session that is not assigned to ``fold``.

    Parameters
    ----------
    inpt, y, mask
        2-D arrays with one row per trial (they are indexed as ``[rows, :]``).
    session
        Per-trial session labels; compared as ``str(label)`` against column 0
        of ``session_fold_lookup_table``.
    session_fold_lookup_table
        2-D table whose column 0 holds session labels and column 1 the fold
        each session belongs to.
    K, D, C, N_em_iters, transition_alpha, prior_sigma
        Forwarded unchanged to ``fit_glm_hmm``.
    fold
        Fold to leave out; sessions whose fold differs from it are used.
    iter
        Initialization index: seeds the random generator and is embedded in
        the output file name. (The name shadows the builtin ``iter``; it is
        kept because it is part of the signature.)
    global_fit
        Must not equal ``True``; only individual fitting is supported.
    init_param_file
        File read with ``load_global_params`` to initialize the fit.
    save_directory
        Directory that receives
        ``glm_hmm_y_raw_parameters_itr_<iter>.npz``.

    Raises
    ------
    NotImplementedError
        If ``global_fit == True``.
    """
    sys.stdout.flush()
    # Sessions used for fitting: every session not assigned to `fold`.
    sessions_to_keep = session_fold_lookup_table[np.where(
        session_fold_lookup_table[:, 1] != fold), 0]
    idx_this_fold = [str(sess) in sessions_to_keep for sess in session]
    # Indexing with a boolean list yields copies, so the caller's arrays are
    # not modified by the in-place edit of `this_y` below.
    this_inpt, this_y, this_session, this_mask = inpt[idx_this_fold, :], \
                                                 y[idx_this_fold, :], \
                                                 session[idx_this_fold], \
                                                 mask[idx_this_fold, :]

    # Only do this so that errors are avoided - these y values will not
    # actually be used for anything (due to violation mask).
    # A boolean mask is used on purpose: the former
    # `this_y[np.where(this_y == 3), :] = 2` passed the (rows, cols) tuple of
    # np.where as one row index, so the column indices (all 0) were treated as
    # rows too and row 0 was overwritten whenever any violation existed.
    this_y[this_y == 3] = 2

    inputs, datas, masks = partition_data_by_session(
        this_inpt, this_y, this_mask, this_session)
    # Global fitting is not supported here; otherwise read the parameters
    # that initialize this fit:
    if global_fit == True:
         raise NotImplementedError('This notebook only runs individual fitting')
    else:
        params_for_initialization = load_global_params(init_param_file)
    M = this_inpt.shape[1]

    # Seed with the initialization index so each initialization is reproducible.
    npr.seed(iter)
    fit_glm_hmm(datas,
                inputs,
                masks,
                K,
                D,
                M,
                C,
                N_em_iters,
                transition_alpha,
                prior_sigma,
                global_fit,
                params_for_initialization,
                save_title=save_directory / ('glm_hmm_y_raw_parameters_itr_' + str(iter) + '.npz')
                )

# Pre-bind the two directory arguments so that each job only has to supply its
# (job specification, animal) pair.
fit_GLMHMM_separetely_eachparam = partial(fit_GLMHMM_y_separetely, data_2_dir, results_2_dir)        
# joblib-cached wrapper around the same callable.
# NOTE(review): the Parallel(...) call in this notebook dispatches the uncached
# `fit_GLMHMM_separetely_eachparam`, so this wrapper appears to be unused --
# confirm whether that is intended.
fit_GLMHMM_separetely_eachparam_cached = memory.cache(fit_GLMHMM_separetely_eachparam)

In [7]:
# Upload the local helper modules to the workers; they cannot pick them up
# through the sys.path tweak made in this notebook.
# NOTE(review): `as client` rebinds the notebook-level `client`, and the context
# manager presumably closes this Client on exit -- confirm later cells do not
# rely on it still being open.
with Client(cluster) as client: # upload local functions to each worker. They cannot read them with sys.append or sys.insert.
    client.wait_for_workers(n)  # wait for the n workers requested via cluster.scale(n)
    client.upload_file(str(get_file_dir() / 'data_io.py'))
    client.upload_file(str(get_file_dir() / 'data_labels.py'))

In [8]:
%%time

# Run one fit per (job specification, animal) pair on the dask cluster.
# BLAS is limited to a single thread inside this block.
with threadpool_limits(limits=1, user_api='blas'):
    # Route joblib through dask; wait up to 120 (presumably seconds) for workers.
    with parallel_backend('dask', wait_for_workers_timeout=120):
        Parallel(verbose=100)(
            # NOTE(review): the uncached callable is dispatched here, not
            # `fit_GLMHMM_separetely_eachparam_cached` -- confirm this is intended.
            delayed(fit_GLMHMM_separetely_eachparam)(allparams) for allparams in cluster_job_arr_with_animal
            )

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 5 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   22.2s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:   25.2s
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:   32.4s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:   36.0s
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:   36.0s
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:   39.9s
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:   52.9s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:   53.8s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:  1.1min
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:  1.1min
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:  1.3min
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:  1.3min
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:  1.3min
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:  1.4min
[Parallel(n_jobs=-1)]: Done  15 tasks      | 

In [9]:
# Once finished, clear the joblib cache and shut down the cluster and the client.
memory.clear(warn=False)
# NOTE(review): the cluster is closed before the client here; closing the client
# first may avoid the connection errors logged at shutdown -- confirm.
cluster.close()
client.close()

  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone


tornado.application - ERROR - Exception in callback <bound method Cluster._sync_cluster_info of SLURMCluster(9ff27757, 'tcp://192.168.234.51:37643', workers=0, threads=0, memory=0 B)>
Traceback (most recent call last):
  File "/nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/distributed/comm/core.py", line 286, in connect
    timeout=min(intermediate_cap, time_left()),
  File "/nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/asyncio/tasks.py", line 449, in wait_for
    raise futures.TimeoutError()
concurrent.futures._base.TimeoutError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/tornado/ioloop.py", line 923, in _run
    await val
  File "/nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/distributed/deploy/cluster.py", line 107, in _sync_cluster_info
    value=copy.copy(self._cluster_info),
  File "

## **Posthoc data processing**
---

We next check the output files.

In [4]:
from kfold_cv import get_best_iter
from data_io import load_data, load_session_fold_lookup, load_glmhmm_data, load_cv_arr, get_file_name_for_best_model_fold
from data_postprocessing_utils import permute_transition_matrix, calculate_state_permutation
from glm_hmm_utils import plot_states, plot_model_comparison, plot_state_occupancy
import json

# Model tag used to build the names of the result files read below.
model = 'GLM_HMM_y'
# Axis labels passed to plot_states; the last entry is the bias column.
# NOTE(review): presumably one label per input column, in column order -- confirm.
labels_for_plot_y = ['CSize', 'COnset', 'PrevRewarded?', 'bias']
print('Animals for individual fitting: {}'.format(animal_list))

Animals for individual fitting: ['x1108393' 'x1116760' 'x1116765' 'x1117910' 'x1119408' 'x1119541']


In [11]:
# For every animal, reload its data and fold assignment and let get_best_iter
# process the fits stored under that animal's results directory.
# NOTE(review): get_best_iter presumably writes the cross-validation arrays and
# the best-initialization JSON that the next cell reads -- confirm in kfold_cv.
for animal in animal_list:
    print(animal)
    this_results_dir = results_2_dir / animal

    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

    # The three return values are not needed here and are discarded.
    _,_,_ = get_best_iter(model, C, num_folds,
                          inpt_y, inpt_rt, y, session, rt, stim_onset,
                          session_fold_lookup_table,
                          this_results_dir, 
                          outcome_dict=nested_outcome,
                          K_vals=K_vals)


x1108393
Retrieving best iter results for model = GLM_HMM_y; num_folds = 5


100%|████| 5/5 [00:07<00:00,  1.50s/it]


[[0.66511948 0.62549102 0.61666093 0.52696882 0.47195031]
 [0.66755955 0.6480227  0.64787212 0.53419323 0.50289026]
 [0.67985898 0.64256434 0.64228998 0.56437512 0.50693933]
 [0.67295974 0.64806751 0.63262338 0.55625596 0.50839125]]
[[0.55779769 0.56978993 0.56810907 0.59599135 0.60537124]
 [0.59188934 0.599317   0.59524129 0.6276701  0.63262652]
 [0.60383042 0.60882276 0.61137018 0.63670074 0.64202524]
 [0.60510805 0.60535401 0.60971387 0.63285674 0.63743462]]
Best iter saved!
x1116760
Retrieving best iter results for model = GLM_HMM_y; num_folds = 5


100%|████| 5/5 [00:04<00:00,  1.23it/s]


[[0.90560415 0.97552247 0.99873148 0.95619476 0.99227798]
 [0.91024536 0.96909238 0.99094464 0.96229239 1.02878745]
 [0.91817621 1.01212903 1.01841828 0.97365759 1.01799427]
 [0.91042584 1.01579536 1.01009229 0.98431596 1.00075481]]
[[0.98420831 0.96069573 0.95879356 0.96961557 0.9617694 ]
 [1.00209878 0.97485647 0.97111627 0.97911277 0.96706376]
 [1.01460951 0.98400032 0.98625787 0.99459053 0.98009619]
 [1.00909131 0.98898865 0.98490879 0.99871536 0.97777748]]
Best iter saved!
x1116765
Retrieving best iter results for model = GLM_HMM_y; num_folds = 5


100%|████| 5/5 [00:04<00:00,  1.12it/s]


[[1.18909472 0.96782084 1.01676972 0.97281583 1.05981937]
 [1.1905263  0.97413761 1.03176391 0.97749544 1.04595651]
 [1.21259443 0.97418557 1.05089256 1.0265197  1.05794262]
 [1.19431326 0.97408747 1.02972398 0.98844349 1.05646195]]
[[1.00260093 1.04740529 1.03481362 1.04679072 1.02410126]
 [1.00463815 1.07808866 1.04310196 1.0476037  1.03929239]
 [1.02384539 1.08125871 1.05529004 1.06649371 1.05557349]
 [1.00764724 1.07787047 1.03844854 1.05230426 1.05403272]]
Best iter saved!
x1117910
Retrieving best iter results for model = GLM_HMM_y; num_folds = 5


100%|████| 5/5 [00:04<00:00,  1.13it/s]


[[0.80421686 0.82626779 0.75443432 0.78835859 0.96371695]
 [0.82978862 0.85611875 0.77515803 0.84647677 1.00827809]
 [0.84835542 0.86871717 0.79671616 0.84531261 0.99640836]
 [0.85478281 0.86851405 0.81459498 0.863344   0.99940813]]
[[0.82819129 0.82387599 0.84373805 0.83288361 0.79256931]
 [0.87108652 0.86576123 0.88681624 0.86812215 0.83067306]
 [0.89007012 0.87786645 0.89588033 0.88161789 0.84826149]
 [0.89203325 0.89102392 0.90727935 0.89360047 0.86088276]]
Best iter saved!
x1119408
Retrieving best iter results for model = GLM_HMM_y; num_folds = 5


100%|████| 5/5 [00:04<00:00,  1.11it/s]


[[0.89885893 0.88291852 1.01414393 0.92611417 0.85998572]
 [0.90385423 0.88519332 1.01388846 0.92799249 0.85948038]
 [0.90228241 0.87794499 1.01326014 0.9278806  0.8592228 ]
 [0.8988221  0.87740664 1.01115991 0.92681926 0.86030263]]
[[0.91882227 0.92640153 0.89708405 0.9103732  0.92978811]
 [0.92236939 0.92836638 0.90135153 0.91426882 0.93186533]
 [0.92312881 0.93066062 0.90204974 0.91642136 0.932001  ]
 [0.92105978 0.9293492  0.89982193 0.91704085 0.9313396 ]]
Best iter saved!
x1119541
Retrieving best iter results for model = GLM_HMM_y; num_folds = 5


100%|████| 5/5 [00:04<00:00,  1.17it/s]

[[0.98249045 0.88263499 0.95113559 0.89769353 0.83583312]
 [0.98020548 0.91212243 0.96418088 0.93497133 0.87732067]
 [0.99290797 0.9326275  0.98352675 0.95785024 0.89550552]
 [0.99209088 0.93953548 0.97580835 0.98500025 0.9076574 ]]
[[0.89638975 0.92286997 0.90456295 0.92070678 0.93222462]
 [0.93771817 0.95517798 0.94181816 0.93855035 0.96449532]
 [0.95696102 0.97103867 0.95910083 0.96587034 0.97862177]
 [0.96708462 0.97850156 0.96917966 0.97048487 0.98758509]]
Best iter saved!





In [8]:
# For every animal: pick the best fit for each K, reorder its states, plot the
# result and save the reordered parameters next to the animal's data.
for animal in animal_list:
    print(animal)
    this_results_dir = results_2_dir / animal
    # Plots and best parameters are written to <data_2_dir>/best_params_<animal>.
    saving_directory = data_2_dir / ("best_params_" + animal)
    saving_directory.mkdir(parents=True, exist_ok=True)

    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

    # Cross-validation arrays stored in this animal's results directory.
    cvbt_folds_model = load_cv_arr(this_results_dir / "cvbt_folds_model_{}.npz".format(model))
    cvbt_train_folds_model = load_cv_arr(this_results_dir / "cvbt_train_folds_model_{}.npz".format(model))

    for K in K_vals:
        print("K = " + str(K))
        # The JSON does not depend on K, so it is re-read unchanged on every pass.
        with open(this_results_dir  / "best_init_cvbt_dict_{}.json".format(model), 'r') as f:
            best_init_cvbt_dict = json.load(f)

        # Get the file name corresponding to the best initialization for
        # given K value
        raw_file = get_file_name_for_best_model_fold(
            cvbt_folds_model, K, this_results_dir,  best_init_cvbt_dict, model,
            'glm_hmm_y_raw_parameters_itr_')
        hmm_params, lls = load_glmhmm_data(raw_file)

        # Calculate permutation
        permutation = calculate_state_permutation(hmm_params)
        print(permutation)

        # Save best parameters
        # hmm_params is indexed as [0] initial-state distribution, [1] log
        # transition matrix, [2] weight vectors; each is reordered by `permutation`.
        weight_vectors = hmm_params[2][permutation]
        log_transition_matrix = permute_transition_matrix(
            hmm_params[1][0], permutation)
        init_state_dist = hmm_params[0][0][permutation]
        best_params = [[init_state_dist], [log_transition_matrix], weight_vectors]

        plot_states(weight_vectors,
                    log_transition_matrix,
                    saving_directory,
                    K,
                    save_title='best_params_' + model + '_K_',
                    labels_for_plot=labels_for_plot_y)
        # NOTE(review): the unpermuted `hmm_params` is passed here, while
        # plot_states above receives the permuted parameters -- confirm that the
        # state order in the two plots is meant to differ.
        plot_state_occupancy(inpt_y, inpt_rt, y, session, rt, stim_onset,
                             K, hmm_params, animal,
                             saving_directory,
                             save_title='best_state_occpancy_' + model + '_K_')
    
        # NOTE(review): `best_params` is a nested list whose elements presumably
        # differ in shape; confirm np.savez stores it in a form that the loader
        # can read back.
        np.savez(saving_directory / ('best_params_' + model + '_K_' + str(K) + '.npz'),
                best_params)
        
    # One model-comparison plot per animal, across all K values.
    plot_model_comparison(cvbt_folds_model,
                        cvbt_train_folds_model,
                        K_vals,
                        saving_directory)

x1108393
K = 2
3
[0 1]
2
[array([[0],
       [1],
       [1],
       [0],
       [1],
       [1],
       [1],
       [1],
       [0],
       [0],
       [0],
       [1],
       [1],
       [0],
       [1],
       [0],
       [1],
       [1],
       [0],
       [0],
       [0],
       [0],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [0],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [0],
       [1],
       [1],
       [1],
       [1],
       [0],
       [1],
       [0],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [0],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
       [1],
  

AssertionError: 

TODO: also save sigma and alpha in the npz file.