# **Fit GLM-HMM to all DMDM data**
---
We next fit the GLM-HMM to all the animals in the dataset. This step is important because it searches through the parameter space and finds candidate parameters that are later used to initialize the GLM-HMM fit for each individual animal.

## **HPC setting**
Ashwood's original code is written as Python scripts. Here, we rewrite it as a Jupyter notebook to make it more user-friendly to run on an HPC cluster with `dask`. [This](https://github.com/pierreglaser/hpc-tutorial/tree/main) is a very useful resource for getting familiar with `dask`.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# allocate the computing resources
from dask_jobqueue import SLURMCluster
from distributed import Client
from joblib import Memory, Parallel, delayed, parallel_backend
from threadpoolctl import threadpool_limits
from tqdm import tqdm

# Configure the SLURM-backed Dask cluster. Workers are requested further down
# with cluster.scale(n), and a Client is then connected to the cluster.
cluster = SLURMCluster(
    workers=0,      # create the workers "lazily" (upon cluster.scale)
    memory='64g',   # amount of RAM per worker
    processes=1,    # NOTE(review): dask-jobqueue documents `processes` as the number of worker processes each job is split into - confirm
    cores=1,        # NOTE(review): dask-jobqueue documents `cores` as the total number of cores per job - confirm
    walltime="48:00:00",
    worker_extra_args=["--resources GPU=1"], # the only way to add GPUs
    local_directory='/nfs/nhome/live/skuroda/jobs', # set your path to save log
    log_directory='/nfs/nhome/live/skuroda/jobs' # set your path to save log
)   

# On-disk joblib cache; it is used later in the notebook to wrap the fitting function.
memory = Memory('/nfs/nhome/live/skuroda/joblib-cache') # set your path

n = 10  # number of workers to request; reused later in wait_for_workers(n)
cluster.scale(n)
client = Client(cluster)
print(client.dashboard_link)  # URL of the Dask dashboard for monitoring

  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import parse_bytes


http://192.168.234.11:8787/status


## **Fit GLM-HMM to all animals**
---

In [3]:
# ------- load modules -------
import autograd.numpy as np
import numpy as onp
import autograd.numpy.random as npr
import os
import sys

from glm_hmm_utils import fit_glm_hmm
sys.path.append('../') # a lazy trick to search parent dir
# https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
from data_io import get_file_dir, load_session_fold_lookup, load_data, load_glm_vectors
from data_labels import create_violation_mask, partition_data_by_session

from functools import partial
from collections import OrderedDict

In [4]:
# ------- setup variables -------
# Dataset name and model dimensions
dname = 'dataAllMice'
C = 4  # number of output types/categories
D = 1  # data (observations) dimension

# Grid that is searched: number of states x CV folds x random initializations
K_vals = [2, 3, 4, 5]
N_initializations = 10 #20
num_folds = 5

# Fix the random seed so any randomization is reproducible
npr.seed(65)

# Nested structure for behavioral outcomes (insertion order is preserved)
nested_outcome = OrderedDict([
    ("Baseline", [2, 3]),
    ("Change", [0, 1]),
])

# Names of the regressors, in plotting order
labels_for_plot = [
    'CSize',
    'COnset',
    'PrevMiss?',
    'PrevHit?',
    'PrevFA?',
    'PrevAbort?',
    'bias',
]

# One [K, fold, initialization] entry per job, with K varying slowest
cluster_job_arr = [
    [n_states, fold_idx, init_idx]
    for n_states in K_vals
    for fold_idx in range(num_folds)
    for init_idx in range(N_initializations)
]

# Fitting hyperparameters
N_em_iters = 400  # number of EM iterations
global_fit = True
transition_alpha = 1  # perform mle => set transition_alpha to 1
prior_sigma = 100

In [5]:
# ------- setup path and load data -------
data_dir = get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster'

# The GLM-HMM fits are initialized from GLM weights stored under
# results_dir / 'GLM' (see init_param_file in fit_GLMHMM). Building the path
# can never fail, so check explicitly that the GLM results exist and stop
# early with a clear message instead of failing later on the workers.
results_dir = get_file_dir().parents[1] / "results" / "dmdm_global_fit" / dname
if not (results_dir / 'GLM').is_dir():
    raise FileNotFoundError('Run GLM first to initialize parameters')


#  read in data and train/test split
animal_file = data_dir / 'all_animals_concat.npz'
session_fold_lookup_table = load_session_fold_lookup(
    data_dir / 'all_animals_concat_session_fold_lookup.npz')

# Only the first three values returned by load_data are used in this notebook.
inpt, y, session, _, _ = load_data(animal_file)

In [6]:
def fit_GLMHMM(inpt, y, session, 
                session_fold_lookup_table, 
                global_fit,
                transition_alpha,
                prior_sigma,
                params):
    """Prepare inputs and output folder for one GLM-HMM fit, then launch it.

    Reads the notebook-level names ``results_dir``, ``D``, ``C`` and
    ``N_em_iters`` in addition to its arguments.

    Args:
        inpt: Design matrix without the bias column; a column of ones is
            appended here.
        y: Observed outcomes; cast to int. Entries equal to -1 are treated
            as violations when building the mask.
        session: Per-trial session identifiers, forwarded unchanged.
        session_fold_lookup_table: Session-to-fold assignment, forwarded
            unchanged.
        global_fit: Forwarded to ``launch_glm_hmm_job``.
        transition_alpha: Forwarded to ``launch_glm_hmm_job``.
        prior_sigma: Forwarded to ``launch_glm_hmm_job``.
        params: Sequence of exactly three items,
            ``[K, fold, initialization index]``.

    Returns:
        None. Results are written by the launched job.
    """
    n_states, fold_idx, init_idx = params

    # Bias covariate: one extra column of ones on the right of the design matrix
    bias_column = np.ones((len(inpt), 1))
    inpt = np.hstack((inpt, bias_column))
    y = y.astype('int')

    # Trials with y == -1 are violations; only the mask is needed downstream
    violation_idx = np.where(y == -1)[0]
    _, mask = create_violation_mask(violation_idx, inpt.shape[0])

    # GLM weights used to initialize the GLM-HMM for this fold
    fold_tag = 'fold_' + str(fold_idx)
    init_param_file = (results_dir / 'GLM' / fold_tag
                       / 'variables_of_interest_iter_0.npz')

    # Output folder specific to this K / fold / initialization
    saving_directory = (results_dir
                        / ("GLM_HMM_K_" + str(n_states))
                        / fold_tag
                        / ('iter_' + str(init_idx)))
    saving_directory.mkdir(parents=True, exist_ok=True)

    launch_glm_hmm_job(inpt=inpt,
                       y=y,
                       session=session,
                       mask=mask,
                       session_fold_lookup_table=session_fold_lookup_table,
                       K=n_states,
                       D=D,
                       C=C,
                       N_em_iters=N_em_iters,
                       transition_alpha=transition_alpha,
                       prior_sigma=prior_sigma,
                       fold=fold_idx,
                       iter=init_idx,
                       global_fit=global_fit,
                       init_param_file=init_param_file,
                       save_directory=saving_directory)

def launch_glm_hmm_job(inpt, y, session, mask, session_fold_lookup_table, K, D,
                       C, N_em_iters, transition_alpha, prior_sigma, fold,
                       iter, global_fit, init_param_file, save_directory):
    """Select the training trials for one fold and fit a GLM-HMM on them.

    Trials are kept when their session's fold in
    ``session_fold_lookup_table`` (column 1) differs from ``fold``; column 0
    holds the session identifiers that are compared against ``str(sess)``.

    Args:
        inpt, y, session, mask: Per-trial arrays indexed along the first axis.
        session_fold_lookup_table: Two-column table of (session, fold).
        K, D, M (derived), C: Passed through to ``fit_glm_hmm``; ``M`` is the
            number of columns of the selected inputs.
        N_em_iters, transition_alpha, prior_sigma: Passed through to
            ``fit_glm_hmm``.
        fold: Fold that is excluded from the selected trials.
        iter: Initialization index; seeds the RNG and names the output file.
        global_fit: Must equal ``True``.
        init_param_file: File read with ``load_glm_vectors`` for the
            initial parameters.
        save_directory: Folder in which the ``.npz`` output is placed.

    Raises:
        NotImplementedError: If ``global_fit`` does not equal ``True``.
    """
    sys.stdout.flush()

    # Sessions assigned to any fold other than `fold`
    rows_other_folds = np.where(session_fold_lookup_table[:, 1] != fold)
    sessions_to_keep = session_fold_lookup_table[rows_other_folds, 0]

    # Boolean selector over trials
    keep_trial = [str(sess) in sessions_to_keep for sess in session]
    this_inpt = inpt[keep_trial, :]
    this_y = y[keep_trial, :]
    this_session = session[keep_trial]
    this_mask = mask[keep_trial]

    # Split by session. The per-session masks returned here are discarded:
    # fit_glm_hmm is called with masks=None.
    inputs, datas, _ = partition_data_by_session(
        this_inpt, this_y, this_mask, this_session)
    masks = None

    # Equality test (not truthiness) is kept on purpose to match the
    # original contract of this function.
    if global_fit != True:
        raise NotImplementedError('This notebook only runs global fitting')
    _, params_for_initialization = load_glm_vectors(init_param_file)

    M = this_inpt.shape[1]
    npr.seed(iter)
    output_file = save_directory / ('glm_hmm_raw_parameters_itr_' + str(iter) + '.npz')
    fit_glm_hmm(datas,
                inputs,
                masks,
                K,
                D,
                M,
                C,
                N_em_iters,
                transition_alpha,
                prior_sigma,
                global_fit,
                params_for_initialization,
                save_title=output_file)
    
# Freeze the shared data and hyperparameters so that each job only needs its
# [K, fold, initialization] triple, then wrap the result in the joblib cache.
fit_GLMHMM_eachparam = partial(fit_GLMHMM, inpt, y, session, session_fold_lookup_table, 
                               global_fit, transition_alpha, prior_sigma)        
fit_GLMHMM_eachparam_cached = memory.cache(fit_GLMHMM_eachparam)

# Upload local modules to each worker. They cannot read them with sys.append or sys.insert.
# The short-lived client gets its own name: binding it to `client` would
# replace the notebook-level client with one that is closed when this block
# exits, leaving the original client unreferenced and never closed.
with Client(cluster) as upload_client:
    upload_client.wait_for_workers(n)
    upload_client.upload_file(str(get_file_dir() / 'data_io.py'))
    upload_client.upload_file(str(get_file_dir() / 'data_labels.py'))

In [7]:
%%time

# Limit BLAS to one thread and run every [K, fold, initialization] job through
# joblib's dask backend.
# NOTE(review): threadpool_limits is entered in this notebook process; whether
# it also constrains BLAS inside the remote workers should be confirmed.
with threadpool_limits(limits=1, user_api='blas'), \
        parallel_backend('dask', wait_for_workers_timeout=60):
    fit_tasks = (delayed(fit_GLMHMM_eachparam_cached)(params)
                 for params in cluster_job_arr)
    Parallel(verbose=100)(fit_tasks)

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 10 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:  3.6min
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:  3.7min
[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:  4.1min
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:  4.3min
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:  5.1min
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:  5.3min
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed:  6.2min
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed:  6.7min
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed:  6.8min
[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed:  7.1min
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed:  7.4min
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed:  9.1min
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed:  9.1min
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed:  9.6min
[Parallel(n_jobs=-1)]: Done  15 tasks      |

tornado.application - ERROR - Uncaught exception GET /status/ws (172.24.170.111)
HTTPServerRequest(protocol='http', host='192.168.234.11:8787', method='GET', uri='/status/ws', version='HTTP/1.1', remote_ip='172.24.170.111')
Traceback (most recent call last):
  File "/nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/tornado/websocket.py", line 942, in _accept_connection
    open_result = handler.open(*handler.open_args, **handler.open_kwargs)
  File "/nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/tornado/web.py", line 3208, in wrapper
    return method(self, *args, **kwargs)
  File "/nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/bokeh/server/views/ws.py", line 149, in open
    raise ProtocolError("Token is expired.")
bokeh.protocol.exceptions.ProtocolError: Token is expired.


[Parallel(n_jobs=-1)]: Done  90 tasks      | elapsed: 213.5min
[Parallel(n_jobs=-1)]: Done  91 tasks      | elapsed: 213.8min
[Parallel(n_jobs=-1)]: Done  92 tasks      | elapsed: 214.7min
[Parallel(n_jobs=-1)]: Done  93 tasks      | elapsed: 215.3min
[Parallel(n_jobs=-1)]: Done  94 tasks      | elapsed: 217.0min
[Parallel(n_jobs=-1)]: Done  95 tasks      | elapsed: 224.1min
[Parallel(n_jobs=-1)]: Done  96 tasks      | elapsed: 226.8min
[Parallel(n_jobs=-1)]: Done  97 tasks      | elapsed: 227.3min
[Parallel(n_jobs=-1)]: Done  98 tasks      | elapsed: 231.2min
[Parallel(n_jobs=-1)]: Done  99 tasks      | elapsed: 239.1min
[Parallel(n_jobs=-1)]: Done 100 tasks      | elapsed: 259.1min
[Parallel(n_jobs=-1)]: Done 101 tasks      | elapsed: 268.2min
[Parallel(n_jobs=-1)]: Done 102 tasks      | elapsed: 276.7min
[Parallel(n_jobs=-1)]: Done 103 tasks      | elapsed: 285.7min
[Parallel(n_jobs=-1)]: Done 104 tasks      | elapsed: 300.1min
[Parallel(n_jobs=-1)]: Done 105 tasks      | elapsed: 3

  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # 

KeyboardInterrupt: 

In [None]:
# Once finished, clear the joblib cache and shut everything down.
# The client is closed before the cluster so that it disconnects while the
# scheduler is still alive, rather than being closed against a scheduler that
# has already gone away.
memory.clear(warn=False)
client.close()
cluster.close()

  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # 

## **Posthoc data processing**
---

We first create a matrix of size num_models x num_folds containing the normalized log-likelihood for both the train and test splits. Then, we save the best parameters from the global fits so they can be used to initialize each animal's model.

In [None]:
from kfold_cv import get_best_iter
# The loaders come from the project's data_io module (the same module imported
# earlier in this notebook). A bare `io` resolves to the standard-library
# module, which has neither name, so the import would fail.
from data_io import load_glmhmm_data, load_cv_arr
from data_postprocessing_utils import get_file_name_for_best_model_fold, \
    permute_transition_matrix, calculate_state_permutation
from glm_hmm_utils import plot_states
import json

# Model tag used in the names of the cross-validation result files
model = 'GLM_HMM'

In [None]:
# Run the cross-validation summary; its three return values are not used here.
# NOTE(review): presumably this writes the cvbt_* and best_init_cvbt_dict_*
# files that the next cell reads from results_dir - confirm in kfold_cv.
_, _, _ = get_best_iter(
    model,
    C,
    num_folds,
    data_dir,
    results_dir,
    K_vals=K_vals,
    outcome_dict=nested_outcome,
)


In [None]:
# For every K, pick the best global fit, reorder its states, plot it and save
# its parameters for initializing the individual-animal fits.
saving_directory = data_dir / "best_global_params"
saving_directory.mkdir(parents=True, exist_ok=True)

cvbt_folds_model = load_cv_arr(results_dir / "cvbt_folds_model_{}.npz".format(model))

# The best-initialization lookup does not depend on K, so read it once.
# The file name must be relative: a leading "/" would make pathlib drop
# results_dir and look for the file at the filesystem root.
with open(results_dir / "best_init_cvbt_dict_{}.json".format(model), 'r') as f:
    best_init_cvbt_dict = json.load(f)

# NOTE(review): these two names do not carry the "_{model}" suffix used for
# cvbt_folds_model above - confirm they match the files that actually exist.
cv_file = results_dir / "cvbt_folds_model.npz"
cv_file_train = results_dir / "cvbt_train_folds_model.npz"

for K in K_vals:
    print("K = " + str(K))

    # Get the file name corresponding to the best initialization for
    # given K value
    raw_file = get_file_name_for_best_model_fold(
        cvbt_folds_model, K, results_dir, best_init_cvbt_dict)
    hmm_params, lls = load_glmhmm_data(raw_file)

    # Calculate permutation
    permutation = calculate_state_permutation(hmm_params)
    print(permutation)

    # Apply the permutation to the weights, the transition matrix and the
    # initial state distribution so all three use the same state order.
    weight_vectors = hmm_params[2][permutation]
    log_transition_matrix = permute_transition_matrix(
        hmm_params[1][0], permutation)
    init_state_dist = hmm_params[0][0][permutation]
    params_for_individual_initialization = [[init_state_dist],
                                            [log_transition_matrix],
                                            weight_vectors]

    plot_states(weight_vectors,
                log_transition_matrix,
                cv_file,
                cv_file_train,
                saving_directory,
                K,
                labels_for_plot=labels_for_plot)

    # saving_directory is a Path, so join with "/" ("+" with a str raises
    # TypeError) and hand numpy a plain string.
    np.savez(str(saving_directory / ('best_params_K_' + str(K) + '.npz')),
             params_for_individual_initialization)