# **Fit GLM to IBL data**
---
We first fit a standard GLM to the dataset.

## **HPC setting**
Ashwood's original code is written as Python scripts. Here, we rewrite it as a Jupyter notebook to make it easier to run on an HPC cluster with `dask`. [This tutorial](https://github.com/pierreglaser/hpc-tutorial/tree/main) is a very useful resource for getting familiar with `dask`.

In [40]:
# allocate the computing resources
from dask_jobqueue import SLURMCluster
from distributed import Client
from joblib import Memory, Parallel, delayed, parallel_backend
from threadpoolctl import threadpool_limits

# Describe the SLURM job that backs each dask worker. Nothing is submitted
# here; jobs are only requested by cluster.scale(...) below.
cluster = SLURMCluster(
    workers=0,      # create the workers "lazily" (upon cluster.scale)
    memory='32g',   # total amount of RAM per job
    processes=1,    # number of worker processes each job is split into
    cores=4,        # total number of cores per job (shared by its processes)
    # Extra command-line flag handed to every dask worker: it declares a
    # worker resource called GPU with capacity 2.
    # NOTE(review): this looks like a dask-level resource label only; confirm
    # whether a SLURM GPU request (e.g. a --gres directive) is also needed.
    worker_extra_args=["--resources GPU=2"],
    local_directory='/nfs/nhome/live/skuroda/jobs', # set your path: worker scratch directory
    log_directory='/nfs/nhome/live/skuroda/jobs' # set your path: where job logs are written
)   

# On-disk joblib cache; used further down via memory.cache(fit_GLM).
memory = Memory('/nfs/nhome/live/skuroda/joblib-cache') # set your path

# Ask for 2 workers, then connect a client to the cluster's scheduler.
cluster.scale(2)
client = Client(cluster)
# Last expression of the cell: displays the client/cluster summary.
client

  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import format_bytes, parse_bytes, tmpfile, get_ip_interface
  from distributed.utils import parse_bytes
Perhaps you already have a cluster running?
Hosting the HTTP server on port 39885 instead
  f"Port {expected} is already in use.\n"


0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://192.168.234.51:39885/status,

0,1
Dashboard: http://192.168.234.51:39885/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://192.168.234.51:35345,Workers: 0
Dashboard: http://192.168.234.51:39885/status,Total threads: 0
Started: Just now,Total memory: 0 B


## **Fit GLM to all animals**
---

In [41]:
# ------- load modules -------
import autograd.numpy as np
import autograd.numpy.random as npr
from glm_utils import load_session_fold_lookup, load_data, fit_glm, \
    plot_input_vectors, append_zeros
import os

In [42]:
# ------- setup variables -------
C = 2  # number of output types/categories (passed to fit_glm by fit_GLM)
N_initializations = 10  # how many times fit_GLM repeats the fit for one fold
num_folds = 5  # one fit_GLM call is dispatched per fold in range(num_folds)
npr.seed(65)  # set seed in case of randomization
# NOTE(review): this seeds the RNG of the notebook kernel only. The fits are
# dispatched through the dask backend further down; confirm whether the
# workers need their own seed for reproducible initializations.

In [43]:
# ------- setup path and load data -------
# Both directories are relative paths, so they resolve against the
# notebook's current working directory.
data_dir = '../../data/ibl/data_for_cluster/'
# Create directory for results. exist_ok=True makes this a no-op when the
# directory is already there and avoids the check-then-create race of
# `if not os.path.exists(...)`.
results_dir = '../../results/ibl_global_fit/'
os.makedirs(results_dir, exist_ok=True)

# Concatenated data for all animals: inputs, choices and per-trial session.
animal_file = data_dir + 'all_animals_concat.npz'
inpt, y, session = load_data(animal_file)
# Lookup table read by fit_GLM: column 0 is compared with the session of
# each trial, column 1 with the fold number.
session_fold_lookup_table = load_session_fold_lookup(
    data_dir + 'all_animals_concat_session_fold_lookup.npz')

In [44]:
# Show the working directory: the relative data/results paths defined above
# are resolved against it.
os.getcwd()

'/nfs/nhome/live/skuroda/Workstation2023/glm-hmm/2_fit_models/fit_glm'

In [45]:
def fit_GLM(fold, inpt, y, session, session_fold_lookup_table):
    """Fit a GLM on all sessions that are not assigned to ``fold``.

    The fit is repeated ``N_initializations`` times. Every repetition saves a
    weight plot and a ``variables_of_interest_iter_<i>.npz`` file (log
    likelihood and recovered weights, stored positionally) under
    ``results_dir + "GLM/fold_<fold>/"``.

    Reads the notebook-level names ``results_dir``, ``N_initializations`` and
    ``C`` in addition to its arguments.

    Parameters
    ----------
    fold
        Fold to leave out; compared with column 1 of the lookup table.
    inpt
        2-D array of inputs, one row per trial (indexed as ``inpt[mask, :]``).
    y
        2-D array of choices, cast to int here. Trials whose first column
        equals -1 are dropped.
    session
        One session identifier per trial; compared, after ``str()``, with
        column 0 of the lookup table.
    session_fold_lookup_table
        2-D array with session identifiers in column 0 and fold in column 1.

    Returns
    -------
    None
        All results are written to disk.

    Raises
    ------
    AssertionError
        If the retained choices do not contain exactly two distinct values.
    """
    # Subset to relevant covariates for covar set of interest:
    labels_for_plot = ['stim', 'P_C', 'WSLS', 'bias']
    y = y.astype('int')
    figure_directory = results_dir + "GLM/fold_" + str(fold) + '/'
    # exist_ok avoids a check-then-create race when folds run in parallel.
    os.makedirs(figure_directory, exist_ok=True)

    # Sessions of interest for this fold: everything outside the left-out
    # fold. A set gives constant-time membership tests in the per-trial loop
    # below instead of scanning a numpy array for every trial.
    outside_fold = session_fold_lookup_table[:, 1] != fold
    sessions_to_keep = set(
        session_fold_lookup_table[outside_fold, 0].tolist())

    # Keep a trial when its session is retained and its choice is not -1.
    idx_this_fold = [
        str(sess) in sessions_to_keep and y[trial, 0] != -1
        for trial, sess in enumerate(session)
    ]
    this_inpt, this_y = inpt[idx_this_fold, :], y[idx_this_fold, :]
    assert len(
        np.unique(this_y)
    ) == 2, "choice vector should only include 2 possible values"

    M = this_inpt.shape[1]  # number of input columns

    # GLM fitting should be independent of initialization, so fitting
    # multiple initializations is a good way to check that everything is
    # working correctly.
    for init in range(N_initializations):
        loglikelihood_train, recovered_weights = fit_glm([this_inpt],
                                                         [this_y], M, C)
        weights_for_plotting = append_zeros(recovered_weights)
        plot_input_vectors(weights_for_plotting,
                           figure_directory,
                           title="GLM fit; Final LL = " +
                           str(loglikelihood_train),
                           save_title='init' + str(init),
                           labels_for_plot=labels_for_plot)
        np.savez(
            figure_directory + 'variables_of_interest_iter_' + str(init) +
            '.npz', loglikelihood_train, recovered_weights)

In [46]:
%%time
fit_GLM_cached = memory.cache(fit_GLM)
with threadpool_limits(limits=1, user_api='blas'):
    with parallel_backend('dask'):
        Parallel()(delayed(fit_GLM_cached)(fold,inpt,y,session,session_fold_lookup_table) for fold in range(num_folds))

CPU times: user 6.16 s, sys: 1.26 s, total: 7.42 s
Wall time: 2min 2s


  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone
  with ignoring(RuntimeError):  # deleting job when job already gone


In [None]:
# Once finished, shut everything down. Close the client first so that it
# disconnects while the scheduler is still alive, then close the cluster.
client.close()
cluster.close()