# **Fit GLM to dmdm data**
---
We first fit a standard GLM to the dataset. One difference between our dmdm dataset and the IBL dataset is that we have two observations of animal behavior: the choice outcome `y` and the reaction time `rt`. Here, we fit these two observations independently, using two separate GLMs.

## **HPC setting**
Ashwood's original code is written as Python scripts. Here, we rewrite it as a Jupyter notebook to make it easier to run on HPC with `dask`. [This](https://github.com/pierreglaser/hpc-tutorial/tree/main) is a very useful resource for getting familiar with `dask`. You can skip these cells if you are only using one node (which is usually the case when nodes are occupied).

In [1]:
%load_ext autoreload
%autoreload 2

In [15]:
# Allocate the computing resources
from dask_jobqueue import SLURMCluster
from distributed import Client
from joblib import Memory, Parallel, delayed, parallel_backend
from threadpoolctl import threadpool_limits

cluster = SLURMCluster(
    workers=0,      # create the workers "lazily" (upon cluster.scale)
    memory='32g',   # amount of RAM per worker
    processes=1,    # among those execution units, number of processes
    cores=1,        # number of execution units per worker (threads and processes)
    # a lazy trick to avoid matplotlib crash with parallel plotting
    walltime="48:00:00",
    worker_extra_args=["--resources GPU=1"], # the only way to add GPUs
    local_directory='/nfs/nhome/live/skuroda/jobs', # set your path to save log
    log_directory='/nfs/nhome/live/skuroda/jobs' # set your path to save log
)   

memory = Memory('/nfs/nhome/live/skuroda/joblib-cache') # set your path

n = 4
cluster.scale(n)
client = Client(cluster)
print(client.dashboard_link)

http://192.168.234.51:8787/status


## **GLM fitting**
---
At this step, we remove abort trials from the model. Abort trials are hard to predict and often lead to unstable weights, which is the last thing we want. Hence the number of behavioral outcomes is `C = 3`.
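As a minimal sketch of what excluding abort trials amounts to, the snippet below builds a 0/1 mask from a vector of outcome codes. It assumes, as the fitting code later in this notebook does, that aborts are coded as `3`. The helper `make_abort_mask_sketch` is hypothetical; it is not the `create_abort_mask` function from `data_labels`, whose implementation is not shown here.

```python
import numpy as np

def make_abort_mask_sketch(y, abort_code=3):
    """Return (kept_idx, mask): indices of non-abort trials and a 0/1 mask.

    Hypothetical helper for illustration only.
    """
    y = np.asarray(y).ravel()
    mask = (y != abort_code).astype(int)   # 1 = keep trial, 0 = abort trial
    kept_idx = np.flatnonzero(mask)        # positions of the kept trials
    return kept_idx, mask

y_demo = np.array([0, 3, 2, 1, 3, 2])      # toy outcomes; 3 marks an abort
kept_idx, mask = make_abort_mask_sketch(y_demo)

# Number of distinct outcome categories left once aborts are masked out
n_categories = len(np.unique(y_demo[kept_idx]))
```

With aborts masked out, only three outcome categories remain in the toy data, which mirrors the `C = 3` setting used below.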

In [3]:
# ------- load modules -------
import autograd.numpy as np
import autograd.numpy.random as npr

from glm_utils import fit_glm, fit_glm_runml, fit_RT_glm, \
                      plot_input_vectors, plot_logOR_hit_vs_miss, plot_rt_weights, plot_lls, plot_logOR_hit_vs_FA
import sys
sys.path.insert(0, '../') # a lazy trick to search parent dir
# https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
from data_io import get_file_dir, load_session_fold_lookup, load_data, load_animal_list
from data_labels import create_abort_mask
from functools import partial
from collections import OrderedDict


In [4]:
# ------- setup variables -------
dname = 'dataAllMiceTraining'   # 'dataAllHumans' 'dataAllMiceTraining' 
num_folds = num_folds_training = num_folds_tuning = 5

C = 3  # number of output types/categories. Hit/FA/Miss
nested_outcome = OrderedDict() # define nested structure for behavioral outcomes
nested_outcome["Baseline"] = [2]
nested_outcome["Change"] = [0, 1]
ridge_lambda = [0, 0.5, 1, 5, 10, 50, 100] # for ridge regression

N_initializations = 10
labels_for_plot_y = ['CSize', 'TempExp', 'HzrdBlck', 'PrevHit +1', 'PrevFA +1', 'bias']
labels_for_plot_rt = ['CSize', 'COnset', 
                      'PrevHowDeviant?',
                      'PrevCOnset', 'PrevRT','bias']
labels_for_plot = {'y':labels_for_plot_y, 'rt':labels_for_plot_rt}
regularization = 'L2'
npr.seed(65)  # set seed in case of randomization

## **Fit GLM to all animals**
---
We fit GLMs to both `y` (outcomes) and `rt` (reaction times). In the outcome-predicting GLM, the dependent variable is assumed to follow a multinomial distribution. The reaction-time-predicting GLM, on the other hand, assumes that the dependent variable follows a Gaussian distribution.
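To make the two observation models concrete, here is a rough sketch of the log-likelihoods they imply. It is not the code in `glm_utils`: it assumes a plain softmax over the `C` outcomes (ignoring the nested Baseline/Change structure defined in `nested_outcome`) and a linear-Gaussian model with a fixed noise scale `sigma`. The function names and toy arrays are illustrative.

```python
import numpy as np

def multinomial_glm_loglik(X, y, W, mask):
    """Masked log-likelihood of a softmax (multinomial) GLM.

    X: (T, M) inputs, y: (T,) integer outcomes in 0..C-1,
    W: (C, M) weights, mask: (T,) with 1 for trials to include.
    """
    logits = X @ W.T                                      # (T, C)
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    per_trial = log_probs[np.arange(len(y)), y]           # log p(y_t | x_t)
    return float(np.sum(mask * per_trial))

def gaussian_glm_loglik(X, rt, w, sigma):
    """Log-likelihood of a linear-Gaussian GLM for reaction times."""
    resid = rt - X @ w
    return float(np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
                        - resid**2 / (2 * sigma**2)))

# Toy check: zero weights give uniform outcome probabilities (1/3 each)
X_demo = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
y_demo = np.array([0, 1, 2, 2])
mask_demo = np.array([1, 1, 0, 1])
ll_y = multinomial_glm_loglik(X_demo, y_demo, np.zeros((3, 2)), mask_demo)

# Toy check: perfectly predicted reaction times leave only the normalizer
w_demo = np.array([0.3, 0.5])
ll_rt = gaussian_glm_loglik(X_demo, X_demo @ w_demo, w_demo, sigma=1.0)
```

The mask enters the outcome likelihood as a per-trial multiplier, so masked trials contribute nothing to the fit regardless of the label they carry.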

In [5]:
# ------- setup path and load data -------
data_dir =  get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster'
# Create directory for results:
results_dir = get_file_dir().parents[1] / "results" / "dmdm_global_fit" / dname
results_dir.mkdir(parents=True, exist_ok=True)

animal_file = data_dir / 'all_animals_concat.npz'
inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)
session_fold_lookup_table = load_session_fold_lookup(
    data_dir / 'all_animals_concat_session_fold_lookup.npz')

In [6]:
def fit_GLM_y(inpt_y, y, session, session_fold_lookup_table, results_dir, labels_for_plot, 
               nested_outcome, regularization, ridge_lambda, num_folds_tuning, fold):
    # Subset to relevant covariates for covar set of interest:
    y = y.astype('int')
    figure_directory = results_dir / "GLM" / ("fold_" + str(fold)) 
    figure_directory.mkdir(parents=True, exist_ok=True)

    if regularization is None:
        num_folds_tuning = 1

    for fold_tuning in range(num_folds_tuning):
        # Subset to sessions of interest for fold
        if regularization == 'L2':
            split_loc = np.logical_and(session_fold_lookup_table[:, 1] != fold, session_fold_lookup_table[:, 2] != fold_tuning)
            sessions_to_keep = session_fold_lookup_table[split_loc, 0]
        elif regularization is None:
            sessions_to_keep = session_fold_lookup_table[np.where(
                session_fold_lookup_table[:, 1] != fold), 0]
        
        idx_this_fold = [
            str(sess) in sessions_to_keep for id, sess in enumerate(session)
        ]
        this_inpt_y, this_y, this_session = inpt_y[idx_this_fold, :], y[
            idx_this_fold, :], session[idx_this_fold]
        train_size = this_inpt_y.shape[0]

        # Identify abort trials for exclusion:
        abort_idx = np.where(this_y == 3)[0]
        nonviolation_idx, mask = create_abort_mask(abort_idx, this_inpt_y.shape[0])

        M = this_inpt_y.shape[1]
        loglikelihood_train_vector = []

        tuning_directory = figure_directory / ("foldtuning_" + str(fold_tuning)) 
        tuning_directory.mkdir(parents=True, exist_ok=True)

        for l2_penalty in ridge_lambda:
            for iter in range(N_initializations):  # GLM fitting should be
                # independent of initialization, so fitting multiple
                # initializations is a good way to check that everything is
                # working correctly

                # Only do this so that errors are avoided - these y values will not
                # actually be used for anything (due to violation mask)
                this_y[np.where(this_y == 3), :] = 2

                loglikelihood_train, recovered_weights, fit_ll = fit_glm([this_inpt_y], # runml
                                                                        [this_y], 
                                                                        M, 
                                                                        C,
                                                                        [mask],
                                                                        nested_outcome,
                                                                        regularization=regularization,
                                                                        l2_penalty=l2_penalty)
                plot_input_vectors(recovered_weights,
                                tuning_directory,
                                title="GLM fit; Final LL = " +
                                str(loglikelihood_train),
                                save_title='init' + str(iter) + 'l' + str(l2_penalty),
                                labels_for_plot=labels_for_plot)
                plot_logOR_hit_vs_miss(recovered_weights,
                                    tuning_directory,
                                    title="GLM fit; Final LL = " +
                                    str(loglikelihood_train),
                                    save_title='init' + str(iter) + 'l' + str(l2_penalty),
                                    labels_for_plot=labels_for_plot)
                plot_logOR_hit_vs_FA(recovered_weights,
                                    tuning_directory,
                                    title="GLM fit; Final LL = " +
                                    str(loglikelihood_train),
                                    save_title='init' + str(iter) + 'l' + str(l2_penalty),
                                    labels_for_plot=labels_for_plot)

                plot_lls(fit_ll, tuning_directory, save_title='y_init' + str(iter) + 'l' + str(l2_penalty))
                loglikelihood_train_vector.append(loglikelihood_train)
                np.savez(
                    tuning_directory / ('GLM_y_variables_of_interest_iter_' + str(iter) + \
                                        '_l' + str(l2_penalty)+ '.npz'), 
                    loglikelihood_train, recovered_weights)
            

fit_GLM_y_eachfold = partial(fit_GLM_y, inpt_y, y, session, session_fold_lookup_table, results_dir, labels_for_plot['y'],
                              nested_outcome,regularization, ridge_lambda, num_folds_tuning)        
fit_GLM_y_eachfold_cached = memory.cache(fit_GLM_y_eachfold)
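The nested session split inside `fit_GLM_y` can be illustrated on toy data. The sketch below assumes, based on how the lookup table is indexed above, that its columns hold the session id, the outer (training) fold, and the inner (tuning) fold. For simplicity it keeps them as three separate arrays, and all values are made up.

```python
import numpy as np

# Toy stand-ins for the three columns of the session/fold lookup table
sessions = np.array(["s0", "s1", "s2", "s3", "s4", "s5"])
outer = np.array([0, 0, 1, 1, 2, 2])   # outer fold of each session
inner = np.array([0, 1, 0, 1, 0, 1])   # tuning fold of each session

def training_sessions(sessions, outer, inner, fold, fold_tuning):
    """Sessions that are in neither the held-out outer fold nor the tuning fold."""
    keep = np.logical_and(outer != fold, inner != fold_tuning)
    return sessions[keep]

train = training_sessions(sessions, outer, inner, fold=0, fold_tuning=1)

# Trial-level selection: keep trials whose session is in the training set
trial_session = np.array(["s0", "s2", "s2", "s3", "s4", "s5", "s4"])
trial_idx = np.isin(trial_session, train)
```

Here `np.isin` plays the role of the list comprehension over `session` in `fit_GLM_y`; both yield one boolean per trial.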

In [7]:
with Client(cluster) as client: # upload local modules to each worker; workers cannot find them via sys.path.append or sys.path.insert
    client.wait_for_workers(n)
    client.upload_file(str(get_file_dir() / 'data_labels.py'))

In [8]:
%%time

with threadpool_limits(limits=1, user_api='blas'):
    with parallel_backend('dask', wait_for_workers_timeout=120):
        Parallel(verbose=100)(delayed(fit_GLM_y_eachfold)(fold) for fold in range(num_folds))
        
        # using cached function should improve the performance to some extent...
        # I am not using it because it hides the name of ongoing process in Dask's progress bar

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed: 85.9min
[Parallel(n_jobs=-1)]: Done   2 out of   5 | elapsed: 86.4min remaining: 129.5min
[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed: 86.5min remaining: 57.6min
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed: 130.7min remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed: 130.7min finished
CPU times: user 6min 35s, sys: 41.1 s, total: 7min 16s
Wall time: 2h 10min 40s


In [9]:
# Once finished, shut down the cluster and the client
cluster.close()
client.close()

  with ignoring(RuntimeError):  # deleting job when job already gone


distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError


## **L2 hyperparameter tuning**
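The selection step in this section can be sketched as an argmax over the candidate penalties. The snippet assumes that the array printed for each fold in the output of this section holds one held-out score per entry of `ridge_lambda`, and that larger is better. That reading is consistent with the values saved for every fold, but `KFoldCV.save_best_L2params` is not shown here, so treat it as an assumption. The scores are copied from the fold 0 output.

```python
ridge_lambda = [0, 0.5, 1, 5, 10, 50, 100]

# Scores printed for fold 0 in this section's output (one per candidate penalty)
scores = [0.15816259, 0.15793576, 0.15775197, 0.15657726,
          0.15537008, 0.14849962, 0.14255741]

# Pick the candidate with the highest score (assumed: higher is better)
best_idx = max(range(len(scores)), key=scores.__getitem__)
best_lambda = ridge_lambda[best_idx]
```

Under this reading the scores decrease monotonically with the penalty, so the unregularized candidate is selected.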

In [10]:
from kfold_cv import KFoldCV
from data_io import load_data, load_session_fold_lookup, load_glmhmm_data, load_cv_arr, get_file_name_for_best_glmhmm_fold, get_best_map_params
from data_postprocessing_utils import calculate_state_permutation
sys.path.append('../../../3_make_figures/dmdm/')
from plot_model_perform import plot_states, plot_model_comparison, plot_state_occupancy
import json

In [11]:
result = []

for fold in range(num_folds_training):
    # get best lambda given fold for training

    results_dir_this_fold = results_dir / 'GLM' / ('fold_' + str(fold))

    KFCV = KFoldCV('GLM_y', num_folds_tuning, Lambda_vals=ridge_lambda,
                   results_dir=results_dir_this_fold)
    out = KFCV.save_best_L2params(inpt_y, inpt_rt, y, session, rt, stim_onset,
                                    session_fold_lookup_table, fold, C, 
                                    outcome_dict=nested_outcome, save_output=True)
    
    result.append([fold, out[0][0]])

Retrieving best iter results for model = GLM_y; fold = 0; num_folds = 5


100%|██████████| 5/5 [01:08<00:00, 13.72s/it]
Calculating best L2 parameters...
[0.15816259 0.15793576 0.15775197 0.15657726 0.15537008 0.14849962
 0.14255741]
(0,)
Retrieving best iter results for model = GLM_y; fold = 1; num_folds = 5


100%|██████████| 5/5 [01:12<00:00, 14.41s/it]
Calculating best L2 parameters...
[0.15844515 0.15821207 0.15802335 0.15681838 0.15558214 0.14857927
 0.14255464]
(0,)
Retrieving best iter results for model = GLM_y; fold = 2; num_folds = 5


100%|██████████| 5/5 [01:03<00:00, 12.72s/it]
Calculating best L2 parameters...
[0.1575055  0.15727521 0.15708753 0.15587857 0.15462763 0.14746109
 0.14124663]
(0,)
Retrieving best iter results for model = GLM_y; fold = 3; num_folds = 5


100%|██████████| 5/5 [01:02<00:00, 12.49s/it]
Calculating best L2 parameters...
[0.16010783 0.15988597 0.15970595 0.15855306 0.15736639 0.15060722
 0.14476039]
(0,)
Retrieving best iter results for model = GLM_y; fold = 4; num_folds = 5


100%|██████████| 5/5 [01:14<00:00, 14.94s/it]
Calculating best L2 parameters...
[0.15858716 0.15836562 0.15818556 0.15703031 0.15583932 0.14903737
 0.14313403]
(0,)





In [12]:
np.savez(data_dir / "best_l2_params_model_GLM_y.npz", np.array(result))
print("Saved all the L2 params!")
print(np.array(result))

Saved all the L2 params!
[[0 0]
 [1 0]
 [2 0]
 [3 0]
 [4 0]]


In [13]:
best_l2_params = np.array(result)

## **Refit with the best L2 hyperparameter**

In [14]:
def fit_GLM_y_bestL2(inpt_y, y, session, session_fold_lookup_table, results_dir, labels_for_plot, 
               nested_outcome, best_l2_params, fold):
    # Subset to relevant covariates for covar set of interest:
    y = y.astype('int')
    figure_directory = results_dir / "GLM" / ("fold_" + str(fold)) 

    sessions_to_keep = session_fold_lookup_table[np.where(
        session_fold_lookup_table[:, 1] != fold), 0]
    
    idx_this_fold = [
        str(sess) in sessions_to_keep for id, sess in enumerate(session)
    ]
    this_inpt_y, this_y, this_session = inpt_y[idx_this_fold, :], y[
        idx_this_fold, :], session[idx_this_fold]
    train_size = this_inpt_y.shape[0]

    # Identify abort trials for exclusion:
    abort_idx = np.where(this_y == 3)[0]
    nonviolation_idx, mask = create_abort_mask(abort_idx, this_inpt_y.shape[0])

    M = this_inpt_y.shape[1]
    loglikelihood_train_vector = []
    l2_penalty = best_l2_params[fold,1]

    for iter in range(N_initializations):  # GLM fitting should be
        # independent of initialization, so fitting multiple
        # initializations is a good way to check that everything is
        # working correctly

        # Only do this so that errors are avoided - these y values will not
        # actually be used for anything (due to violation mask)
        this_y[np.where(this_y == 3), :] = 2

        loglikelihood_train, recovered_weights, fit_ll = fit_glm([this_inpt_y], # runml
                                                                [this_y], 
                                                                M, 
                                                                C,
                                                                [mask],
                                                                nested_outcome,
                                                                regularization='L2',
                                                                l2_penalty=l2_penalty)
        plot_input_vectors(recovered_weights,
                        figure_directory,
                        title="GLM fit; Final LL = " +
                        str(loglikelihood_train),
                        save_title='init' + str(iter) + 'l' + str(l2_penalty),
                        labels_for_plot=labels_for_plot)
        plot_logOR_hit_vs_miss(recovered_weights,
                            figure_directory,
                            title="GLM fit; Final LL = " +
                            str(loglikelihood_train),
                            save_title='init' + str(iter) + 'l' + str(l2_penalty),
                            labels_for_plot=labels_for_plot)
        plot_logOR_hit_vs_FA(recovered_weights,
                            figure_directory,
                            title="GLM fit; Final LL = " +
                            str(loglikelihood_train),
                            save_title='init' + str(iter) + 'l' + str(l2_penalty),
                            labels_for_plot=labels_for_plot)

        plot_lls(fit_ll, figure_directory, save_title='y_init' + str(iter) + 'l' + str(l2_penalty))
        loglikelihood_train_vector.append(loglikelihood_train)
        np.savez(
            figure_directory / ('GLM_y_variables_of_interest_iter_' + str(iter) + \
                                '_l' + str(l2_penalty)+ '.npz'), 
            loglikelihood_train, recovered_weights)
            

fit_GLM_y_bestL2_eachfold = partial(fit_GLM_y_bestL2, inpt_y, y, session, session_fold_lookup_table, results_dir, labels_for_plot['y'],
                              nested_outcome,best_l2_params)        
fit_GLM_y_bestL2_eachfold_cached = memory.cache(fit_GLM_y_bestL2_eachfold)

In [16]:
with Client(cluster) as client: # upload local modules to each worker; workers cannot find them via sys.path.append or sys.path.insert
    client.wait_for_workers(n)
    client.upload_file(str(get_file_dir() / 'data_labels.py'))

In [17]:
%%time

with threadpool_limits(limits=1, user_api='blas'):
    with parallel_backend('dask', wait_for_workers_timeout=120):
        Parallel(verbose=100)(delayed(fit_GLM_y_bestL2_eachfold_cached)(fold) for fold in range(num_folds))
        
        # the cached function is used here, which should improve the performance to some extent;
        # note that it hides the name of the ongoing process in Dask's progress bar

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:  3.4min
[Parallel(n_jobs=-1)]: Done   2 out of   5 | elapsed:  3.4min remaining:  5.1min
[Parallel(n_jobs=-1)]: Done   3 out of   5 | elapsed:  3.6min remaining:  2.4min
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:  6.2min remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:  6.2min finished
CPU times: user 14.5 s, sys: 4.78 s, total: 19.3 s
Wall time: 6min 10s


In [19]:
# Once finished, shut down the cluster and the client
cluster.close()
client.close()

## **Fit GLM to each animal separately**
---
We next fit a GLM to each animal separately. Behavior differs from animal to animal, so the fitted GLM weights may differ between animals as well.

In [8]:
# ------- setup path and load data -------
data_2_dir =  get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster' / "data_by_animal"
# Create directory for results:
results_2_dir = get_file_dir().parents[1] / "results" / "dmdm_individual_fit" / dname
results_2_dir.mkdir(parents=True, exist_ok=True)

animal_list = load_animal_list(data_2_dir / 'animal_list.npz')

In [9]:
def fit_GLM_separately(data_2_dir, results_2_dir, labels_for_plot: dict, nested_outcome, num_folds, animal):
    # Fit GLM to data from single animal:
    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    this_results_dir = results_2_dir / animal

    # Load data
    print(str(animal_file))
    inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

    for fold in range(num_folds):
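        # fit_GLM_y also needs the regularization settings (taken from the notebook globals here)
        fit_GLM_y(inpt_y, y, session, session_fold_lookup_table, this_results_dir, labels_for_plot['y'],
                  nested_outcome, regularization, ridge_lambda, num_folds_tuning, fold)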
        # fit_GLM_rt(inpt_rt, rt, stim_onset, y, session, session_fold_lookup_table, this_results_dir, labels_for_plot['rt'], nested_outcome, fold)
            
fit_GLM_separately_eachanimal = partial(fit_GLM_separately, data_2_dir, results_2_dir, labels_for_plot, nested_outcome, num_folds)     
fit_GLM_separately_eachanimal_cached = memory.cache(fit_GLM_separately_eachanimal)
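The `partial` pattern used above binds every leading argument so that the parallel loop only has to supply the last one (the animal, or the fold in the earlier cells). A stripped-down sketch follows; the function `fit_one` and the animal names are placeholders, not part of the real pipeline.

```python
from functools import partial

def fit_one(data_dir, num_folds, animal):
    """Placeholder for a per-animal fit; returns a label instead of fitting."""
    return f"{data_dir}/{animal}: {num_folds} folds"

# Bind the shared arguments once; only `animal` is left free
fit_each_animal = partial(fit_one, "data_by_animal", 5)

# The remaining argument is what the parallel loop iterates over
results = [fit_each_animal(animal) for animal in ["mouse_a", "mouse_b"]]
```

Because the bound arguments are positional, the free argument has to come last in the signature of the wrapped function, as `fold` and `animal` do in this notebook.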


In [10]:
with Client(cluster) as client: # upload local modules to each worker; workers cannot find them via sys.path.append or sys.path.insert
    client.wait_for_workers(n)
    client.upload_file(str(get_file_dir() / 'data_io.py'))
    client.upload_file(str(get_file_dir() / 'data_labels.py'))

In [None]:
%%time
with threadpool_limits(limits=1, user_api='blas'):
    with parallel_backend('dask', wait_for_workers_timeout=120):
        Parallel(verbose=100)(delayed(fit_GLM_separately_eachanimal_cached)(animal) for animal in animal_list)

In [None]:
# Once finished, shut down the cluster and the client
cluster.close()
client.close()

In [None]:
from kfold_cv import KFoldCV
from data_io import load_data, load_session_fold_lookup, load_glmhmm_data, load_cv_arr, get_file_name_for_best_glmhmm_fold, get_best_map_params
from data_postprocessing_utils import calculate_state_permutation
sys.path.append('../../../3_make_figures/dmdm/')
from plot_model_perform import plot_states, plot_model_comparison, plot_state_occupancy
import json

model = 'GLM_y'
print('Animals for individual fitting: {}'.format(animal_list))

In [None]:
for animal in animal_list:
    print(animal)
    this_results_dir = results_2_dir / animal

    animal_file = data_2_dir / (animal + '_processed.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_2_dir / (animal + '_session_fold_lookup.npz'))
    inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(animal_file)

    KFCV = KFoldCV(model, num_folds, results_dir=this_results_dir, animal=animal)
    KFCV.save_best_iter(inpt_y, inpt_rt, y, session, rt, stim_onset,
                         session_fold_lookup_table, C, outcome_dict=nested_outcome)
