# **Fit GLM-HMM to all DMDM data**
---
We next fit the GLM-HMM to all the animals in the dataset. This step is important because it searches through the parameter space and finds candidate parameters that are needed later to fit the GLM-HMM to each individual animal.

## **HPC setting**
Ashwood's original code is written as Python scripts. Here, we rewrite it as a Jupyter notebook to make it more user-friendly to run on HPC with `dask`. [This](https://github.com/pierreglaser/hpc-tutorial/tree/main) is a very useful resource for getting familiar with `dask`.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# allocate the computing resources
from dask_jobqueue import SLURMCluster
from distributed import Client
from joblib import Memory, Parallel, delayed, parallel_backend
from threadpoolctl import threadpool_limits
from tqdm import tqdm

# Describe the SLURM job that backs each dask worker. Nothing is submitted
# here; jobs are only requested by cluster.scale() below.
cluster = SLURMCluster(
    workers=0,      # create the workers "lazily" (upon cluster.scale)
    memory='32g',   # amount of RAM requested per job
    processes=1,    # number of worker processes each job is split into
    cores=1,        # total number of cores requested per job
    # A lazy trick to avoid matplotlib crash with parallel plotting
    # NOTE(review): presumably refers to the single-process/single-core
    # setting above rather than to walltime — confirm.
    walltime="24:00:00",  # maximum run time requested per job (HH:MM:SS)
    worker_extra_args=["--resources GPU=2"], # the only way to add GPUs (declares a dask worker resource named GPU)
    local_directory='/nfs/nhome/live/skuroda/jobs', # set your path (dask worker scratch/spill directory)
    log_directory='/nfs/nhome/live/skuroda/jobs' # set your path to save the job scheduler logs
)   

# joblib on-disk cache used later to memoize the fitting function
memory = Memory('/nfs/nhome/live/skuroda/joblib-cache') # set your path

# Request 10 workers, connect a client to the cluster and print the URL of
# the dask dashboard.
cluster.scale(10)
client = Client(cluster)
print(client.dashboard_link)

  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import parse_bytes
Perhaps you already have a cluster running?
Hosting the HTTP server on port 35465 instead
  f"Port {expected} is already in use.\n"


http://192.168.234.51:35465/status


## **Fit GLM-HMM to all animals**
---

In [3]:
# ------- load modules -------
import autograd.numpy as np
import numpy as onp
import autograd.numpy.random as npr
from glm_hmm_utils import get_file_dir, load_session_fold_lookup, \
     load_data, create_violation_mask, fit_glm_hmm, partition_data_by_session, \
     load_glm_vectors, load_global_params
import os
import sys
from functools import partial
from collections import OrderedDict

In [4]:
# ------- setup variables -------
dname = 'dataAllMice'  # name of the dataset folder used to build the paths
C = 4  # number of output types/categories
D = 1  # data (observations) dimension
K_vals = [2, 3, 4, 5]  # numbers of states to fit
N_initializations = 10 #20
num_folds = 5
npr.seed(65)  # set seed in case of randomization

# Nested structure for behavioral outcomes (insertion order is kept).
nested_outcome = OrderedDict([
    ("Baseline", [2, 3]),
    ("Change", [0, 1]),
])

# One job per [K, fold, initialization] combination, ordered with K varying
# slowest and the initialization index varying fastest.
cluster_job_arr = [
    [n_states, fold_idx, init_idx]
    for n_states in K_vals
    for fold_idx in range(num_folds)
    for init_idx in range(N_initializations)
]

N_em_iters = 300  # number of EM iterations
global_fit = True
transition_alpha = 1 # perform mle => set transition_alpha to 1
prior_sigma = 100

In [5]:
# ------- setup path and load data -------
# Both folders are located relative to the third ancestor of the path
# returned by get_file_dir().
data_dir = get_file_dir().parents[2] / "data" / "dmdm" / dname / 'data_for_cluster'
results_dir = get_file_dir().parents[2] / "results" / "dmdm_global_fit" / dname

# The GLM-HMM fits below are initialized from files stored under
# results_dir / 'GLM', so stop early with a clear message when that folder is
# missing. (Building a path never raises, so the previous try/except could
# not detect this.)
if not (results_dir / 'GLM').is_dir():
    raise FileNotFoundError('Run GLM First to initialize parameters')


#  read in data and train/test split
animal_file = data_dir / 'all_animals_concat.npz'
session_fold_lookup_table = load_session_fold_lookup(
    data_dir / 'all_animals_concat_session_fold_lookup.npz')

# Only the first three values returned by load_data are used.
inpt, y, session, _, _ = load_data(animal_file)

In [6]:
def fit_GLM_HMM(inpt, y, session,
                session_fold_lookup_table,
                global_fit,
                transition_alpha,
                prior_sigma,
                params):
    """Prepare the inputs for one job and hand them to launch_glm_hmm_job.

    Args:
        inpt: 2-D input array; a column of ones (bias) is appended to it.
        y: observation array; cast to int, entries equal to -1 are treated
            as violations.
        session: per-trial session labels, forwarded unchanged.
        session_fold_lookup_table: forwarded unchanged.
        global_fit: forwarded unchanged.
        transition_alpha: forwarded unchanged.
        prior_sigma: forwarded unchanged.
        params: sequence of three items ``[K, fold, iter]``.

    Returns:
        None. The output folder for this job is created under the
        module-level ``results_dir``; ``D``, ``C`` and ``N_em_iters`` are
        also read from module level.
    """
    n_states, fold, init_idx = params

    # Bias covariate: one extra column of ones on the right of the inputs.
    bias_column = np.ones((len(inpt), 1))
    design = np.hstack((inpt, bias_column))
    choices = y.astype('int')

    # Violations (y == -1) are excluded through the mask; the list of
    # non-violation indices returned alongside it is not needed here.
    violation_idx = np.where(choices == -1)[0]
    _, mask = create_violation_mask(violation_idx, design.shape[0])

    # GLM weights used to initialize the GLM-HMM.
    init_param_file = (results_dir / 'GLM' / f'fold_{fold}'
                       / 'variables_of_interest_iter_0.npz')

    # One output folder per K / fold / initialization combination.
    saving_directory = (results_dir / f"GLM_HMM_K_{n_states}"
                        / f"fold_{fold}" / f"iter_{init_idx}")
    saving_directory.mkdir(parents=True, exist_ok=True)

    launch_glm_hmm_job(design, choices, session, mask,
                       session_fold_lookup_table,
                       n_states, D, C, N_em_iters,
                       transition_alpha, prior_sigma,
                       fold, init_idx, global_fit,
                       init_param_file, saving_directory)

def launch_glm_hmm_job(inpt, y, session, mask, session_fold_lookup_table, K, D,
                       C, N_em_iters, transition_alpha, prior_sigma, fold,
                       iter, global_fit, init_param_file, save_directory):
    """Fit a GLM-HMM on every session that is not assigned to ``fold``.

    Args:
        inpt: 2-D input array, one row per trial.
        y: 2-D observation array, one row per trial.
        session: per-trial session labels (compared as ``str``).
        mask: per-trial mask, indexed like ``session``.
        session_fold_lookup_table: 2-D array; column 0 is read as the
            session identifier and column 1 as its fold.
        K, D, C, N_em_iters, transition_alpha, prior_sigma: forwarded
            unchanged to ``fit_glm_hmm``.
        fold: sessions whose fold equals this value are left out.
        iter: initialization index; used as the random seed and in the
            output file name.
        global_fit: if truthy, initial parameters are read with
            ``load_glm_vectors``; otherwise with ``load_global_params``.
        init_param_file: file the initial parameters are read from.
        save_directory: path object of the folder the output file name is
            built in.

    Returns:
        None. The output path is passed to ``fit_glm_hmm`` as ``save_title``.
    """
    sys.stdout.flush()
    # Sessions whose fold (column 1) differs from the requested fold.
    sessions_to_keep = session_fold_lookup_table[np.where(
        session_fold_lookup_table[:, 1] != fold), 0]
    idx_this_fold = [str(sess) in sessions_to_keep for sess in session]
    this_inpt = inpt[idx_this_fold, :]
    this_y = y[idx_this_fold, :]
    this_session = session[idx_this_fold]
    this_mask = mask[idx_this_fold]
    # NOTE(review): a comment used to sit here saying "Only do this so that
    # errors are avoided - these y values will not actually be used for
    # anything (due to violation mask)", but no statement was attached to it.
    # Presumably a line overwriting the violation entries of this_y was
    # removed — confirm that passing y == -1 through unchanged is intended.
    inputs, datas, masks = partition_data_by_session(
        this_inpt, this_y, this_mask, this_session)
    # Read in GLM fit if global_fit is set, otherwise the global parameters:
    if global_fit:
        _, params_for_initialization = load_glm_vectors(init_param_file)
    else:
        params_for_initialization = load_global_params(init_param_file)
    M = this_inpt.shape[1]  # number of input columns
    npr.seed(iter)  # seed with the initialization index
    save_title = save_directory / ('glm_hmm_raw_parameters_itr_' + str(iter)
                                   + '.npz')
    fit_glm_hmm(datas,
                inputs,
                masks,
                K,
                D,
                M,
                C,
                N_em_iters,
                transition_alpha,
                prior_sigma,
                global_fit,
                params_for_initialization,
                save_title=save_title)
    
# Bind the shared data and hyperparameters so that each job only has to pass
# its own [K, fold, iter] triple as the remaining `params` argument.
fit_GLM_eachparam = partial(fit_GLM_HMM, inpt, y, session, session_fold_lookup_table, 
                            global_fit, transition_alpha, prior_sigma)        
# Wrap the bound function with the joblib on-disk cache created earlier.
# NOTE(review): the data arrays are bound inside the partial rather than
# passed per call; presumably only `params` enters the cache key, so cached
# entries may need clearing when the data change — confirm.
fit_GLM_eachparam_cached = memory.cache(fit_GLM_eachparam)

In [7]:
%%time

# Limit BLAS to one thread and dispatch one cached fitting call per entry of
# cluster_job_arr through the joblib 'dask' backend (waiting up to 600 s for
# workers to become available).
with threadpool_limits(limits=1, user_api='blas'), \
        parallel_backend('dask', wait_for_workers_timeout=600):
    Parallel(verbose=100)(
        map(delayed(fit_GLM_eachparam_cached), cluster_job_arr)
    )

100%|███████████████████████| 200/200 [00:00<00:00, 201.44it/s]

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 1 concurrent workers.





In [None]:
# Once finished, disconnect the client first and then shut down the cluster;
# closing the cluster first would leave the client attached to a scheduler
# that no longer exists.
client.close()
cluster.close()

## **Posthoc data processing**
---

Create a matrix of size num_models x num_folds containing normalized loglikelihood for both train and test splits

In [None]:
sys.path.insert(0, '../') # a lazy trick to search parent dir
# https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
from kfold_cv import get_best_iter

In [None]:
model = 'GLM_HMM'
# The three returned values are discarded.
# NOTE(review): presumably the call is made for the files it writes (the
# cell below reads cross-validation files from a results folder) — confirm.
# K_max=5 equals the largest entry of K_vals defined above.
_,_,_ = get_best_iter(model, C, num_folds, data_dir, 
                      results_dir, 
                      outcome_dict=nested_outcome,
                      K_max=5)


In [None]:
# Save best parameters from IBL global fits (for K = 2 to 5) to initialize
# each animal's model
import json
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from post_processing_utils import load_glmhmm_data, load_cv_arr, \
    create_cv_frame_for_plotting, get_file_name_for_best_model_fold, \
    permute_transition_matrix, calculate_state_permutation


if __name__ == '__main__':
    # For each K in 2..5: load the best global fit, reorder its states, save
    # the parameters for initializing individual fits and plot a summary
    # figure (weights, transition matrix, model comparison).

    # NOTE(review): these string paths replace the `data_dir` / `results_dir`
    # path objects defined earlier in this file for the dmdm data and point
    # at IBL folders instead — confirm that this is intended.
    data_dir = '../../data/ibl/data_for_cluster/'
    results_dir = '../../results/ibl_global_fit/'
    save_directory = data_dir + "best_global_params/"

    os.makedirs(save_directory, exist_ok=True)

    # NOTE(review): covariate names used as x tick labels of the weight
    # plot; presumably they must match the input columns — confirm.
    labels_for_plot = ['stim', 'pc', 'wsls', 'bias']

    cv_file = results_dir + "/cvbt_folds_model.npz"
    cvbt_folds_model = load_cv_arr(cv_file)

    # The dictionary does not depend on K, so read it once.
    with open(results_dir + "/best_init_cvbt_dict.json", 'r') as f:
        best_init_cvbt_dict = json.load(f)

    for K in range(2, 6):
        print("K = " + str(K))

        # Get the file name corresponding to the best initialization for
        # given K value
        raw_file = get_file_name_for_best_model_fold(
            cvbt_folds_model, K, results_dir, best_init_cvbt_dict)
        hmm_params, lls = load_glmhmm_data(raw_file)

        # Calculate permutation
        permutation = calculate_state_permutation(hmm_params)
        print(permutation)

        # Save parameters for initializing individual fits
        weight_vectors = hmm_params[2][permutation]
        log_transition_matrix = permute_transition_matrix(
            hmm_params[1][0], permutation)
        init_state_dist = hmm_params[0][0][permutation]
        params_for_individual_initialization = [[init_state_dist],
                                                [log_transition_matrix],
                                                weight_vectors]

        # The three entries have different shapes. Recent NumPy raises when
        # asked to build such a ragged array implicitly, so fill an object
        # array element by element before saving.
        params_to_save = np.empty(len(params_for_individual_initialization),
                                  dtype=object)
        for idx, entry in enumerate(params_for_individual_initialization):
            params_to_save[idx] = entry
        np.savez(
            save_directory + 'best_params_K_' + str(K) + '.npz',
            params_to_save)

        # Plot these too:
        cols = ["#e74c3c", "#15b01a", "#7e1e9c", "#3498db", "#f97306"]
        fig = plt.figure(figsize=(4 * 8, 10),
                         dpi=80,
                         facecolor='w',
                         edgecolor='k')
        plt.subplots_adjust(left=0.1,
                            bottom=0.24,
                            right=0.95,
                            top=0.7,
                            wspace=0.8,
                            hspace=0.5)

        # --- Panel 1: weights of each state ---
        plt.subplot(1, 3, 1)
        M = weight_vectors.shape[2] - 1
        for k in range(K):
            # The sign of the stored weights is flipped for plotting.
            plt.plot(range(M + 1),
                     -weight_vectors[k][0],
                     marker='o',
                     label='State ' + str(k + 1),
                     color=cols[k],
                     lw=4)
        # rotation must be a number (or 'vertical'/'horizontal'), not '20'.
        plt.xticks(list(range(0, len(labels_for_plot))),
                   labels_for_plot,
                   rotation=20,
                   fontsize=24)
        plt.yticks(fontsize=30)
        plt.legend(fontsize=30)
        plt.axhline(y=0, color="k", alpha=0.5, ls="--")
        # plt.ylim((-3, 14))
        plt.ylabel("Weight", fontsize=30)
        plt.xlabel("Covariate", fontsize=30, labelpad=20)
        plt.title("GLM Weights: Choice = R", fontsize=40)

        # --- Panel 2: transition matrix ---
        plt.subplot(1, 3, 2)
        transition_matrix = np.exp(log_transition_matrix)
        plt.imshow(transition_matrix, vmin=0, vmax=1)
        for i in range(transition_matrix.shape[0]):
            for j in range(transition_matrix.shape[1]):
                text = plt.text(j,
                                i,
                                np.around(transition_matrix[i, j],
                                          decimals=3),
                                ha="center",
                                va="center",
                                color="k",
                                fontsize=30)
        plt.ylabel("Previous State", fontsize=30)
        plt.xlabel("Next State", fontsize=30)
        plt.xlim(-0.5, K - 0.5)
        plt.ylim(-0.5, K - 0.5)
        # 1-based state labels; the previous hard-coded tuple listed '4'
        # twice, so the fifth state was labelled '4'.
        state_labels = [str(state + 1) for state in range(K)]
        plt.xticks(range(0, K), state_labels, fontsize=30)
        plt.yticks(range(0, K), state_labels, fontsize=30)
        plt.title("Retrieved", fontsize=40)

        # --- Panel 3: train/test model comparison ---
        plt.subplot(1, 3, 3)
        cols = [
            "#7e1e9c", "#0343df", "#15b01a", "#bf77f6", "#95d0fc",
            "#96f97b"
        ]
        cv_file = results_dir + "/cvbt_folds_model.npz"
        data_for_plotting_df, loc_best, best_val, glm_lapse_model = \
            create_cv_frame_for_plotting(
            cv_file)
        cv_file_train = results_dir + "/cvbt_train_folds_model.npz"
        train_data_for_plotting_df, train_loc_best, train_best_val, \
        train_glm_lapse_model = create_cv_frame_for_plotting(
            cv_file_train)

        glm_lapse_model_cvbt_means = np.mean(glm_lapse_model, axis=1)
        train_glm_lapse_model_cvbt_means = np.mean(train_glm_lapse_model,
                                                   axis=1)
        # x and y are passed by keyword: current seaborn no longer accepts
        # them positionally.
        g = sns.lineplot(
            x=data_for_plotting_df['model'],
            y=data_for_plotting_df['cv_bit_trial'],
            err_style="bars",
            mew=0,
            color=cols[0],
            marker='o',
            ci=68,
            label="test",
            alpha=1,
            lw=4)
        sns.lineplot(
            x=train_data_for_plotting_df['model'],
            y=train_data_for_plotting_df['cv_bit_trial'],
            err_style="bars",
            mew=0,
            color=cols[1],
            marker='o',
            ci=68,
            label="train",
            alpha=1,
            lw=4)
        plt.xlabel("Model", fontsize=30)
        plt.ylabel("Normalized LL", fontsize=30)
        plt.xticks([0, 1, 2, 3, 4],
                   ['1 State', '2 State', '3 State', '4 State', '5 State'],
                   rotation=45,
                   fontsize=24)
        plt.yticks(fontsize=15)
        plt.axhline(y=glm_lapse_model_cvbt_means[2],
                    color=cols[2],
                    label="Lapse (test)",
                    alpha=0.9,
                    lw=4)
        plt.legend(loc='upper right', fontsize=30)
        plt.tick_params(axis='y')
        plt.yticks([0.2, 0.3, 0.4, 0.5], fontsize=30)
        plt.ylim((0.2, 0.55))
        plt.title("Model Comparison", fontsize=40)
        fig.tight_layout()

        fig.savefig(results_dir + 'best_params_cross_validation_K_' +
                    str(K) + '.png')