# **Fit GLM-HMM to individual DMDM data**
---
After fitting the GLM to all animals, we can finally fit the GLM-HMM to individual animals in the dataset. Just like the global GLM-HMM, we only use the outcomes `y` as the dependent variable.

## **HPC setting**
Ashwood's original code is written as Python scripts. Here, we rewrite it as a Jupyter notebook to make it more user-friendly to run on an HPC cluster with `dask`. [This](https://github.com/pierreglaser/hpc-tutorial/tree/main) is a very useful resource to get familiar with `dask`.

In [1]:
# Reload imported modules automatically before running code, so edits to the
# helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# allocate the computing resources
from dask_jobqueue import SLURMCluster
from distributed import Client
from joblib import Memory, Parallel, delayed, parallel_backend
from threadpoolctl import threadpool_limits
from tqdm import tqdm

# One SLURM job per dask worker (processes=1): 1 core, 64 GB RAM and 1 GPU each.
cluster = SLURMCluster(
    workers=0,      # start with no workers; they are requested "lazily" by cluster.scale()
                    # NOTE(review): dask-jobqueue documents this argument as `n_workers`; confirm `workers` is accepted by the installed version.
    memory='64g',   # total RAM per job (= per worker here, since processes=1)
    processes=1,    # number of worker processes each job is split into
    cores=1,        # total number of cores per job, shared among its processes
    walltime="72:00:00",
    job_extra_directives=[
        "--gres=gpu:1",
        "-p gpu",
    ], # extra #SBATCH lines: one GPU per job on the `gpu` partition (the only way to add GPUs)
    local_directory='/nfs/nhome/live/skuroda/jobs', # worker scratch directory (set your path)
    log_directory='/nfs/nhome/live/skuroda/jobs' # set your path to save log
)   

# joblib on-disk cache directory
memory = Memory('/nfs/nhome/live/skuroda/joblib-cache') # set your path

# Number of workers (= SLURM jobs, since processes=1) to request
n = 60
cluster.scale(n)
client = Client(cluster)
print(client.dashboard_link)

  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import parse_bytes


http://192.168.234.51:8787/status


## **Fit GLM-HMM to individual animals using MAP**
---
We use MAP instead of MLE for the EM algorithm here. To select the hyperparameters σ and α governing the prior, we perform a grid search over σ ∈ {0.5, 1, 1.5, 2, 2.5} and α ∈ {2}, and then select the hyperparameters that result in the best performance **within the training dataset**. (Don't select the ones with the best performance on a held-out validation set, as that can overfit the hyperparameters.) Hence, we first fit the GLM-HMM to tune the hyperparameters, and then repeat the fitting process to retrieve the best GLM-HMM parameters for each number of states `K` and each animal. The first grid search takes quite a long time.

In [3]:
# ------- load modules -------
import autograd.numpy as np
import numpy as onp
import autograd.numpy.random as npr
import sys
import itertools

from glm_hmm_utils import fit_glm_hmm
sys.path.append('../') # a lazy trick to search parent dir
# https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
from data_io import get_file_dir, load_session_fold_lookup, load_data, load_animal_list, load_glmhmm_data, load_global_best_params
from data_labels import create_abort_mask, partition_data_by_session

from functools import partial
from collections import OrderedDict

In [4]:
# ------- setup variables -------
# Dataset and model dimensions
dname = 'dataAllMiceTraining'
C = 3  # number of output types/categories
D = 1  # data (observations) dimension

# Hyperparameter grid for the MAP prior, and the numbers of states to fit
prior_sigma = [0.5, 1, 1.5, 2, 2.5]
transition_alpha = [2]  # alpha = 1 performs MLE
K_vals = [1, 2, 3, 4]
N_initializations = 2  # 20
num_folds_training = num_folds_tuning = 5

# Nested structure for behavioral outcomes: category name -> outcome codes
nested_outcome = OrderedDict([
    ("Baseline", [2]),
    ("Change", [0, 1]),
])

# One job per (K, training fold, tuning fold, initialization, sigma, alpha),
# stored as [sigma, alpha, K, training fold, tuning fold, initialization].
# itertools.product varies the last factor fastest, matching a loop nest in
# this order.
cluster_job_arr = [
    [sigma, alpha, K, i, ii, j]
    for K, i, ii, j, sigma, alpha in itertools.product(
        K_vals,
        range(num_folds_training),
        range(num_folds_tuning),
        range(N_initializations),
        prior_sigma,
        transition_alpha,
    )
]

N_em_iters = 600  # number of EM iterations
global_fit = False
paramter_tuning = True

In [5]:
# ------- setup path and load data -------
data_2_dir = get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster' / "data_by_animal"

# Root directory for the individual fits. The per-animal/fold sub-directories
# are created on demand by `fit_GLMHMM_y` (mkdir(parents=True)).
results_2_dir = get_file_dir().parents[1] / "results" / "dmdm_individual_fit" / dname

# Individual fits are initialized from the best global parameters stored in
# `<data_for_cluster>/best_global_params/` (see `init_param_file` in
# `fit_GLMHMM_y`). Fail fast with a clear message if they are missing, instead
# of wrapping an expression that cannot raise in a bare `except`.
if not (data_2_dir.parents[0] / 'best_global_params').is_dir():
    raise FileNotFoundError('Run GLM first to initialize parameters')

# Read the animal list and pair every hyperparameter/fold job with every animal.
animal_list = load_animal_list(data_2_dir / 'animal_list.npz')
cluster_job_arr_with_animal = list(itertools.product(cluster_job_arr, animal_list))

print('Total number of jobs: {}'.format(len(cluster_job_arr_with_animal)))
print('Animals for individual fitting: {}'.format(animal_list))

Total number of jobs: 14000
Animals for individual fitting: ['M_AK001' 'M_AK004' 'M_AK005' 'M_AK008' 'M_IO125' 'M_IO127' 'M_IO128'
 'M_IO132' 'M_IO135' 'M_IO136' 'M_IO137' 'M_IO138' 'M_ML007' 'M_ML008']


In [6]:
def fit_GLMHMM_y_separetely(data_2_dir, results_2_dir, paramter_tuning, cluster_job_arr_with_animal):
    """Load one animal's data and run one outcome-only GLM-HMM fit for it.

    Parameters
    ----------
    data_2_dir : pathlib.Path
        Directory holding ``<animal>_processed.npz`` and
        ``<animal>_session_fold_lookup.npz``.
    results_2_dir : pathlib.Path
        Root results directory; this job writes under ``results_2_dir / animal``.
    paramter_tuning : bool
        Forwarded to ``fit_GLMHMM_y`` (selects the data split and save layout).
    cluster_job_arr_with_animal : sequence
        ``(job_params, animal)``. ``job_params`` is
        ``[sigma, alpha, K, training fold, tuning fold, initialization]``.
        The last four entries are cast to ``int``.

    Notes
    -----
    The ``global_fit`` flag is read from module (notebook) scope.
    """
    animal = cluster_job_arr_with_animal[1]
    print(animal)

    job_params = cluster_job_arr_with_animal[0]
    sigma, alpha = job_params[0], job_params[1]
    hmm_params = list(map(int, job_params[2:]))

    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    animal_results_dir = results_2_dir / animal

    # Only the outcome design matrix, the outcomes and the session ids are
    # needed for the outcome-only model.
    inpt_y, _, y, session, _, _ = load_data(animal_file)

    fit_GLMHMM_y(inpt_y, y, session, session_fold_lookup_table, global_fit,
                 paramter_tuning, alpha, sigma, animal_results_dir, hmm_params)

def fit_GLMHMM_y(inpt_y, y, session, 
                 session_fold_lookup_table, 
                 global_fit,
                 paramter_tuning,
                 transition_alpha,
                 prior_sigma,
                 results_dir,
                 params):
    """Prepare one animal's design matrix and mask, then launch one GLM-HMM fit.

    Parameters
    ----------
    inpt_y : ndarray
        Outcome design matrix (without a bias column; one is appended here).
    y : ndarray
        Outcomes; cast to ``int``. Code 3 marks aborted trials.
    session : sequence
        Session id of every trial.
    session_fold_lookup_table : ndarray
        Session -> fold lookup table, forwarded to ``launch_glm_hmm_job``.
    global_fit : bool
        Must be false; global fitting is not supported here.
    paramter_tuning : bool
        If true, results go under an extra ``tuningfold_<n>`` directory.
    transition_alpha, prior_sigma : float
        MAP prior hyperparameters.
    results_dir : pathlib.Path
        Per-animal results directory.
    params : sequence of 4 ints
        ``K``, training fold, tuning fold, initialization index.

    Raises
    ------
    NotImplementedError
        If ``global_fit`` is true.

    Notes
    -----
    ``data_2_dir``, ``D``, ``C`` and ``N_em_iters`` are read from module scope.
    """
    K, fold_training, fold_tuning, init_idx = params

    # Bias covariate: append a constant column of ones to the design matrix.
    design = np.hstack((inpt_y, np.ones((len(inpt_y), 1))))
    outcomes = y.astype('int')

    # Trials with outcome code 3 are aborts/violations and are masked out.
    abort_idx = np.where(outcomes == 3)[0]
    _, mask = create_abort_mask(abort_idx, design.shape[0])

    if global_fit == True:
        raise NotImplementedError('This notebook only runs individual fitting')

    # Initialize from the best global GLM-HMM parameters for this K; the same
    # file is used for every fold.
    init_param_file = data_2_dir.parents[0] / 'best_global_params' / ('best_params_GLM_HMM_y_K_' + str(K) + '.npz')

    # Save directory for this K / fold (/ tuning fold) / initialization.
    save_dir = results_dir / ("GLM_HMM_y_K_" + str(K)) / ("fold_" + str(fold_training))
    if paramter_tuning:
        save_dir = save_dir / ("tuningfold_" + str(fold_tuning))
    save_dir = save_dir / ('iter_' + str(init_idx))
    save_dir.mkdir(parents=True, exist_ok=True)

    launch_glm_hmm_job(design, outcomes, session, mask, session_fold_lookup_table,
                       K, D, C, N_em_iters, transition_alpha, prior_sigma,
                       fold_training, fold_tuning, paramter_tuning, init_idx,
                       global_fit, init_param_file, save_dir)


def launch_glm_hmm_job(inpt, y, session, mask, session_fold_lookup_table, K, D,
                       C, N_em_iters, transition_alpha, prior_sigma, fold_training,
                       fold_tuning, paramter_tuning, iter, global_fit, init_param_file, save_directory):
    """Select the training trials for one fold split and run one GLM-HMM fit.

    Parameters
    ----------
    inpt : ndarray, 2-D
        Design matrix (rows = trials); its column count is passed as ``M``.
    y : ndarray, 2-D
        Integer outcomes, rows aligned with ``inpt``; code 3 marks aborts.
    session : ndarray
        Session id of every trial, compared as ``str`` with column 0 of the
        lookup table.
    mask : ndarray, 2-D
        Abort mask, rows aligned with ``inpt``.
    session_fold_lookup_table : ndarray, 2-D
        Column 0 = session id, column 1 = training fold and, when
        ``paramter_tuning`` is true, column 2 = tuning fold.
    K, D, C, N_em_iters, transition_alpha, prior_sigma
        Model size and EM / prior settings, forwarded to ``fit_glm_hmm``.
    fold_training, fold_tuning : int
        Folds to hold out (``fold_tuning`` only when ``paramter_tuning``).
    paramter_tuning : bool
        Also hold out the tuning fold.
    iter : int
        Initialization index; seeds ``npr`` and names the output file.
    global_fit : bool
        Must be false; global fitting is not supported here.
    init_param_file : pathlib.Path
        ``.npz`` file of parameters used to initialize the fit.
    save_directory : pathlib.Path
        Directory receiving ``GLM_HMM_y_raw_parameters_itr_<iter>``.

    Raises
    ------
    NotImplementedError
        If ``global_fit`` is true.
    """
    sys.stdout.flush()
    if paramter_tuning:
        # Hold out both the training fold and the tuning fold.
        split_loc = np.logical_and(session_fold_lookup_table[:, 1] != fold_training, session_fold_lookup_table[:, 2] != fold_tuning)
        sessions_to_keep = session_fold_lookup_table[split_loc, 0]
    else:
        # Hold out only the training fold.
        sessions_to_keep = session_fold_lookup_table[np.where(
            session_fold_lookup_table[:, 1] != fold_training), 0]

    # Build the kept-session set once: O(1) lookups per trial instead of a
    # scan of the array for every trial. ravel() because the np.where-based
    # index above yields a (1, n) array.
    kept_sessions = set(sessions_to_keep.ravel().tolist())
    idx_this_fold = np.array([str(sess) in kept_sessions for sess in session], dtype=bool)
    this_inpt, this_y, this_session, this_mask = inpt[idx_this_fold, :], \
                                                 y[idx_this_fold, :], \
                                                 session[idx_this_fold], \
                                                 mask[idx_this_fold, :]

    # Aborted trials (code 3) are excluded by the mask; relabel them to a valid
    # category only so that the observation model does not error. A boolean
    # mask touches exactly those entries. (Indexing with the tuple returned by
    # np.where would also hit row 0, because its column indices are all 0.)
    this_y[this_y == 3] = 2

    inputs, datas, masks = partition_data_by_session(
        this_inpt, this_y, this_mask, this_session)

    if global_fit == True:
        raise NotImplementedError('This notebook only runs individual fitting')
    params_for_initialization = load_global_best_params(init_param_file)
    M = this_inpt.shape[1]

    # Each initialization uses its own reproducible random seed.
    npr.seed(iter)
    # The positional `False, 0` arguments are forwarded unchanged; their meaning
    # is defined by glm_hmm_utils.fit_glm_hmm.
    fit_glm_hmm(datas,
                inputs,
                masks,
                K,
                D,
                M,
                C,
                N_em_iters,
                transition_alpha,
                prior_sigma,
                global_fit,
                False,
                0,
                params_for_initialization,
                save_title=save_directory / ('GLM_HMM_y_raw_parameters_itr_' + str(iter))
                )

# Bind the fixed arguments (data dir, results dir, tuning flag) so that each
# dask job only needs its (job parameters, animal) tuple.
fit_GLMHMM_separetely_eachparam = partial(fit_GLMHMM_y_separetely, data_2_dir, results_2_dir, 
                                           paramter_tuning)        
# joblib-memoized variant of the same call.
# NOTE(review): the Parallel cell below dispatches the uncached function.
fit_GLMHMM_separetely_eachparam_cached = memory.cache(fit_GLMHMM_separetely_eachparam)

In [7]:
# Ship the local helper modules to the workers: they cannot import them via
# the sys.path trick used in this notebook.
with Client(cluster) as client:
    client.wait_for_workers(n)
    for helper_module in ('data_io.py', 'data_labels.py'):
        client.upload_file(str(get_file_dir() / helper_module))

In [8]:
%%time

# Dispatch every grid-search job through joblib's dask backend, with BLAS
# limited to a single thread per worker.
with threadpool_limits(limits=1, user_api='blas'), \
        parallel_backend('dask', wait_for_workers_timeout=120):
    Parallel(verbose=100)(
        delayed(fit_GLMHMM_separetely_eachparam)(job)
        for job in cluster_job_arr_with_animal
    )

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 20 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:    5.0s
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    6.1s
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:    6.1s
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:    6.3s
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:    6.8s
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:    6.9s
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:    7.0s
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:    7.0s
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:    7.1s
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:    7.2s
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:    7.5s
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:    8.1s
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:    8.3s
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:    8.4s
[Parallel(n_jobs=-1)]: Done  15 tasks      |

  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # 

KilledWorker: ('batch_of_fit_GLMHMM_y_separetely_1_calls-02ca5a1dd4d343c3a67586e5e5148048', <WorkerState 'tcp://192.168.234.56:44989', name: 1, status: closed, memory: 0, processing: 29>)

  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone


In [None]:
# Grid search finished: clear the joblib cache, but keep the SLURM cluster
# running. The refitting section below reconnects to the same `cluster`
# (Client(cluster) + wait_for_workers(n)), which would not work on a closed
# cluster. The client and cluster are shut down in the last cell of the
# refitting section.
memory.clear(warn=False)

## **MAP hyperparameter tuning**
---

We first check the output files and save the best initialization results from the hyperparameter-tuning cross-validation. We then find the best σ and α hyperparameters by their performance averaged over all folds, for each number of states `K`, test fold `k`, and animal `animal`.

In [None]:
from kfold_cv import KFoldCV
from data_io import load_data, load_session_fold_lookup, load_glmhmm_data, load_best_map_params, \
                     load_cv_arr, get_file_name_for_best_glmhmm_fold
from data_postprocessing_utils import calculate_state_permutation
import json

model = 'GLM_HMM_y'
# Every (animal, training fold) pair; hyperparameters are selected separately for each pair.
animal_and_training_folds = list(itertools.product(animal_list, range(num_folds_training)))
print('Animals for individual fitting: {}'.format(animal_list))

In [None]:
# Best MAP hyperparameters, one row per (animal, training fold, K).
# Columns (stored as strings): sigma, alpha, K, fold, animal.
out_params = np.zeros((len(K_vals) * num_folds_training * len(animal_list), 5), dtype='<U32')

for idx, (animal, fold_training) in enumerate(animal_and_training_folds):
    print(animal)
    this_results_dir = results_2_dir / animal

    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

    KFCV = KFoldCV(model, num_folds_tuning, K_vals, global_fit,
                   transition_alpha, prior_sigma,
                   this_results_dir)
    out = KFCV.save_best_MAPparams(inpt_y, inpt_rt, y, session, rt, stim_onset,
                                   session_fold_lookup_table, fold_training, C,
                                   outcome_dict=nested_outcome, save_output=True)

    # Reverse the column order of `out` into the sigma, alpha, K layout above.
    # NOTE(review): presumably `out` rows are (K, alpha, sigma) -- confirm in KFoldCV.
    out_arr = np.array(out)[:, [2, 1, 0]]

    fold_col = np.repeat(fold_training, len(K_vals))[:, None]
    animal_col = np.repeat(animal, len(K_vals))[:, None]
    out_arr_fold_animal = np.hstack([out_arr, fold_col, animal_col])

    # Each (animal, fold) pair owns a contiguous run of len(K_vals) rows.
    first_row = idx * len(K_vals)
    out_params[first_row:first_row + len(K_vals), :] = out_arr_fold_animal

np.savez(data_2_dir / "best_params_model_{}.npz".format(model), out_params)
print("Saved all the MAP params!")

## **Refitting GLM-HMM to individual animals with best MAP hyperparameters**
---

We fit GLM-HMM to separate animals again, but using the best MAP hyperparams we just found.

In [None]:
# Refit with the selected hyperparameters: no tuning-fold split this time.
paramter_tuning = False
# Placeholder for the tuning-fold slot of each job's parameters. With
# paramter_tuning False, it is used neither for the data split nor for the
# save path.
fold_dummy = 100

# Rows: sigma, alpha, K, training fold, animal. Drop the animal column and
# cast the rest to float.
raw_params = load_best_map_params(data_2_dir / "best_params_model_{}.npz".format(model))
params_noanimal = raw_params[:,0:-1].astype('float32')

In [None]:
# One refitting job per (best-hyperparameter row, initialization). Each job is
# ([sigma, alpha, K, training fold, fold_dummy, initialization], animal).
# The animal column is converted to a list once, not once per job.
animal_per_row = raw_params[:, -1].tolist()
cluster_job_arr_2 = []
for row_params, row_animal in zip(params_noanimal, animal_per_row):
    for j in range(N_initializations):
        param_i = [row_params[0], row_params[1], row_params[2], row_params[3],
                   fold_dummy, j]
        cluster_job_arr_2.append((param_i, row_animal))

print('Total number of jobs: {}'.format(len(cluster_job_arr_2)))
print('Animals for individual fitting: {}'.format(animal_list))

In [None]:
# Same binding as for the grid search, but with paramter_tuning now False.
fit_GLMHMM_separetely_eachparam_2 = partial(fit_GLMHMM_y_separetely, data_2_dir, results_2_dir, 
                                           paramter_tuning)        
# joblib-memoized variant.
# NOTE(review): the Parallel cell below dispatches the uncached function.
fit_GLMHMM_separetely_eachparam_2_cached = memory.cache(fit_GLMHMM_separetely_eachparam_2)

In [None]:
# Ship the local helper modules to the workers again: they cannot import them
# via the sys.path trick used in this notebook.
with Client(cluster) as client:
    client.wait_for_workers(n)
    for helper_module in ('data_io.py', 'data_labels.py'):
        client.upload_file(str(get_file_dir() / helper_module))

In [None]:
%%time

# Dispatch every refitting job through joblib's dask backend, with BLAS
# limited to a single thread per worker.
with threadpool_limits(limits=1, user_api='blas'), \
        parallel_backend('dask', wait_for_workers_timeout=120):
    Parallel(verbose=100)(
        delayed(fit_GLMHMM_separetely_eachparam_2)(job)
        for job in cluster_job_arr_2
    )

In [None]:
# Once finished, clear the joblib cache and shut everything down. The client
# is closed before the cluster, so it is not left connected to a scheduler
# that has already been torn down.
memory.clear(warn=False)
client.close()
cluster.close()

## **Posthoc data processing**
---

In [None]:
from kfold_cv import KFoldCV
from data_io import load_data, load_session_fold_lookup, load_glmhmm_data, load_cv_arr, get_file_name_for_best_glmhmm_fold, get_best_map_params
from data_postprocessing_utils import calculate_state_permutation
sys.path.append('../../../3_make_figures/dmdm/')
from plot_model_perform import plot_states, plot_model_comparison, plot_state_occupancy
import json

model = 'GLM_HMM_y'
# Axis labels for the weight plots, one per covariate.
# NOTE(review): presumably in the column order of inpt_y followed by the bias
# column appended in fit_GLMHMM_y -- confirm.
labels_for_plot_y = ['CSize', 'COnset', 'Outcome +1', 'Outcome +2', 'Outcome +3', 'Outcome +4', 'Outcome +5', 'bias']
print('Animals for individual fitting: {}'.format(animal_list))

In [None]:
# Selected MAP hyperparameters from the tuning section
# (rows: sigma, alpha, K, fold, animal).
raw_params = load_best_map_params(data_2_dir / "best_params_model_{}.npz".format(model))

for animal in animal_list:
    print(animal)
    this_results_dir = results_2_dir / animal

    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

    # Post-process this animal's refits over the outer training folds.
    # NOTE(review): presumably save_best_iter picks and stores the best
    # initialization per K/fold for the given map_params -- see kfold_cv.
    KFCV = KFoldCV(model, num_folds_training, K_vals, global_fit, 
                   transition_alpha, prior_sigma,
                   this_results_dir, animal=animal)
    KFCV.save_best_iter(inpt_y, inpt_rt, y, session, rt, stim_onset,
                         session_fold_lookup_table, C, outcome_dict=nested_outcome,
                         map_params=raw_params)


In [None]:
for animal in animal_list:
    print(animal)
    this_results_dir = results_2_dir / animal
    saving_directory = data_2_dir / ("best_params_" + animal)
    saving_directory.mkdir(parents=True, exist_ok=True)

    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

    # Cross-validation results of this animal (test and train folds).
    cvbt_folds_model = load_cv_arr(this_results_dir / "cvbt_folds_model_{}.npz".format(model))
    cvbt_train_folds_model = load_cv_arr(this_results_dir / "cvbt_train_folds_model_{}.npz".format(model))

    # There is one best-initialization file per animal, so read it once here
    # instead of re-opening it for every K.
    with open(this_results_dir / "best_init_cvbt_dict_{}.json".format(model), 'r') as f:
        best_init_cvbt_dict = json.load(f)

    for model_idx, K in enumerate(K_vals):
        print("K = " + str(K))

        # Get the raw parameter file of the best initialization for the given
        # K, alpha, and sigma, and load the fitted parameters from it.
        raw_file_bestmap, _ = get_file_name_for_best_glmhmm_fold(
            cvbt_folds_model, model_idx, K,
            overall_dir=this_results_dir, best_init_cvbt_dict=best_init_cvbt_dict,
            model=model, fname_header='GLM_HMM_y_raw_parameters_itr_',
            global_fit=False, map_params=raw_params, animal=animal)
        hmm_params, lls, _, _, _ = load_glmhmm_data(raw_file_bestmap)

        # Calculate permutation
        init_state_dist, log_transition_matrix, weight_vectors, permutation = \
            calculate_state_permutation(hmm_params, K)

        if K == 1:
            best_params = weight_vectors
        elif K > 1:
            # [[init_state_dist], [log_transition_matrix], weight_vectors] is
            # ragged. NumPy >= 1.24 refuses to convert it implicitly, so
            # np.savez would raise ValueError. Build the 3-element object array
            # that older NumPy created implicitly, explicitly.
            best_params = onp.empty(3, dtype=object)
            best_params[0] = [init_state_dist]
            best_params[1] = [log_transition_matrix]
            best_params[2] = weight_vectors

        plot_states(weight_vectors,
                    log_transition_matrix,
                    saving_directory,
                    K,
                    save_title='best_params_' + model + '_K_',
                    labels_for_plot=labels_for_plot_y)
        if K > 1:
            plot_state_occupancy(inpt_y, inpt_rt, y, session, rt, stim_onset,
                                 K, hmm_params, animal,
                                 saving_directory,
                                 save_title='best_state_occpancy_' + model + '_K_')
        np.savez(saving_directory / ('best_params_' + model + '_K_' + str(K) + '.npz'),
                 best_params)

    plot_model_comparison(cvbt_folds_model,
                          cvbt_train_folds_model,
                          global_fit,
                          K_vals,
                          saving_directory)