## **Global fitting summary**
---

Here we compare the models again, in more detail and with cleaner figures. Part of this analysis was already done in the GLM-HMM-fitting Jupyter notebooks; the main aim of this notebook is to support more exploratory analysis. In this particular notebook, we compare the performance of the global fit across two species (humans and mice).

In [1]:
# Automatically reload edited project modules (data_io, plot_model_perform, ...)
# before each cell executes, so changes made in ../../2_fit_models/dmdm are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from sklearn import preprocessing
import json
import sys
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import ssm

sys.path.append('../../2_fit_models/dmdm')
from data_io import get_file_dir, load_animal_list, load_cv_arr, load_data, get_file_name_for_best_glmhmm_iter, load_session_fold_lookup, load_glmhmm_data
from data_labels import create_abort_mask, partition_data_by_session
from data_postprocessing_utils import partition_data_by_session
from kfold_cv import prepare_data_for_cv
from plot_model_perform import create_cv_frame_for_plotting, plot_state_Wk, plot_state_dwelltime, calculate_predictive_accuracy
from plotting_utils import load_global_glmhmm_result, calc_dwell_time
from plot_animal_behav import plot_PC, plot_CC, plot_FArate
import matplotlib.ticker as ticker
from plot_model_perform import plot_states, plot_model_comparison, plot_state_Wk_diff

In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.
In /nfs/nhome/live/skuroda/.conda/envs/glmhmm/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two mi

## Model comparison with predictive accuracy
Here we compare the cross-validated predictive accuracy of models with different numbers of states.

In [3]:
# ------- setup variables -------
# Datasets that were fitted globally.
dnames = ['dataAllHumans', 'dataAllMiceTraining']

# Model / observation settings.
model = 'GLM_HMM_y'
C = 3  # number of output types/categories
D = 1  # data (observations) dimension
labels_for_plot_y = ['CSize', 'COnset', 'Outcome +1', 'Outcome +2', 'Outcome +3', 'Outcome +4', 'Outcome +5', 'bias']

# Fitting / cross-validation settings.
num_fold = 5
global_fit = True
transition_alpha = 1  # perform mle => set transition_alpha to 1
prior_sigma = 100

# Candidate numbers of hidden states; K starts at the largest candidate.
K_vals = list(range(1, 5))
K = max(K_vals)

# Where figures are written.
save_figures = True
figure_dir = get_file_dir().parents[1] / 'figures'
figure_dir.mkdir(parents=True, exist_ok=True)

In [4]:
# Dataset analysed below: dnames[0] is 'dataAllHumans', dnames[1] is 'dataAllMiceTraining'.
dname = dnames[0]

In [5]:
# Input data and global-fit results for the selected dataset, both resolved
# relative to two levels above get_file_dir().
data_dir, results_dir = (
    get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster',
    get_file_dir().parents[1] / "results" / "dmdm_global_fit" / dname,
)

In [6]:
# Trials of all subjects concatenated, plus the session -> CV-fold assignment.
inpt_y, inpt_rt, y, session, rt, stim_onset = load_data(data_dir / 'all_animals_concat.npz')
y = y.astype('int')  # y is used as the categorical observation of the GLM-HMM

fold_lookup_path = data_dir / 'all_animals_concat_session_fold_lookup.npz'
session_fold_lookup_table = load_session_fold_lookup(fold_lookup_path)

In [7]:
# Lookup of the best initialisation for each (model, K, fold, hyper-parameter) combination.
best_init_cvbt_dict = json.loads(
    (results_dir / "best_init_cvbt_dict_{}.json".format(model)).read_text()
)

In [8]:
# Inspect the lookup. Keys look like
# '<model>_K_<K>/fold_<f>/alpha_<a>/sigma_<s>/lambda_<l>/fold_tuning_<f>';
# values are presumably the index of the best initialisation (TODO confirm).
best_init_cvbt_dict

{'GLM_HMM_y_K_1/fold_0/alpha_1/sigma_100/lambda_0/fold_tuning_0': 4,
 'GLM_HMM_y_K_2/fold_0/alpha_1/sigma_100/lambda_1000000/fold_tuning_0': 9,
 'GLM_HMM_y_K_3/fold_0/alpha_1/sigma_100/lambda_1000000/fold_tuning_0': 7,
 'GLM_HMM_y_K_4/fold_0/alpha_1/sigma_100/lambda_1000000/fold_tuning_0': 1,
 'GLM_HMM_y_K_1/fold_1/alpha_1/sigma_100/lambda_0/fold_tuning_1': 7,
 'GLM_HMM_y_K_2/fold_1/alpha_1/sigma_100/lambda_1000000/fold_tuning_1': 5,
 'GLM_HMM_y_K_3/fold_1/alpha_1/sigma_100/lambda_1000000/fold_tuning_1': 7,
 'GLM_HMM_y_K_4/fold_1/alpha_1/sigma_100/lambda_1000000/fold_tuning_1': 1,
 'GLM_HMM_y_K_1/fold_2/alpha_1/sigma_100/lambda_0/fold_tuning_2': 4,
 'GLM_HMM_y_K_2/fold_2/alpha_1/sigma_100/lambda_1000000/fold_tuning_2': 6,
 'GLM_HMM_y_K_3/fold_2/alpha_1/sigma_100/lambda_1000000/fold_tuning_2': 4,
 'GLM_HMM_y_K_4/fold_2/alpha_1/sigma_100/lambda_1000000/fold_tuning_2': 9,
 'GLM_HMM_y_K_1/fold_3/alpha_1/sigma_100/lambda_0/fold_tuning_3': 3,
 'GLM_HMM_y_K_2/fold_3/alpha_1/sigma_100/lambda_1

In [9]:
# Cross-validated predictive accuracy of the global GLM-HMM for every K in K_vals.
# Results are stored as (1, 1, len(K_vals), num_fold) arrays, the layout passed
# to plot_model_comparison below.

# The best-initialisation lookup does not depend on K: read it once.
with open(results_dir / "best_init_cvbt_dict_{}.json".format(model), 'r') as f:
    best_init_cvbt_dict = json.load(f)

# File of the best initialisation for each training fold, per K.
# NOTE(review): in the saved run this step raised
# KeyError: 'GLM_HMM_y_K_1/fold_0/alpha_1/sigma_100', whereas the JSON keys also
# contain '/lambda_<...>/fold_tuning_<fold>'. The key built inside
# get_file_name_for_best_glmhmm_iter presumably needs those parts too — confirm.
raw_files_by_K = {}
for K in K_vals:
    print("K = " + str(K))
    raw_files_by_K[K] = get_file_name_for_best_glmhmm_iter(
        K, results_dir, best_init_cvbt_dict, model,
        'GLM_HMM_y_raw_parameters_itr_')

cvpa_folds_model = np.zeros((1, 1, len(K_vals), num_fold))
cvpa_train_folds_model = np.zeros((1, 1, len(K_vals), num_fold))

for fold in range(num_fold):
    # The train/test split depends only on the fold, so it is prepared once per
    # fold and reused for every K (it used to be rebuilt len(K_vals) times).
    # NOTE(review): assumes calculate_predictive_accuracy does not mutate the
    # fold arrays, since they are now shared across K — confirm.
    test_data, train_data, M_y, M_rt, n_test, n_train = \
        prepare_data_for_cv(inpt_y, inpt_rt, y, session, rt, stim_onset,
                            session_fold_lookup_table, fold,
                            paramter_tuning=False)  # kwarg spelling as in kfold_cv

    [test_inpt_y, test_inpt_rt, test_y, test_rt, test_stim_onset, test_mask, test_session] = test_data
    [train_inpt_y, train_inpt_rt, train_y, train_rt, train_stim_onset, train_mask, train_session] = train_data

    # Append a constant column: the last GLM weight is the bias term.
    test_inpt_y = np.hstack((test_inpt_y, np.ones((len(test_inpt_y), 1))))
    train_inpt_y = np.hstack((train_inpt_y, np.ones((len(train_inpt_y), 1))))

    # For GLM-HMM set values of y for violations to 2. This value doesn't
    # matter, as the mask ensures these y values do not contribute to the
    # log-likelihood calculation.
    test_y[np.where(test_mask == 0)[0], :] = 2
    train_y[np.where(train_mask == 0)[0], :] = 2

    # Indices of the trials that are actually scored.
    test_idx = np.where(test_mask)[0]
    train_idx = np.where(train_mask)[0]

    # The GLM-HMM works on per-session sequences. (partition_data_by_session
    # is the data_postprocessing_utils version: the later import wins.)
    this_test_inputs, this_test_datas, this_test_masks = partition_data_by_session(
        test_inpt_y, test_y, test_mask, test_session)
    this_train_inputs, this_train_datas, this_train_masks = partition_data_by_session(
        train_inpt_y, train_y, train_mask, train_session)

    for model_idx, K in enumerate(K_vals):
        this_hmm_params, _, _, _, _ = load_glmhmm_data(raw_files_by_K[K][fold])

        # transition_kwargs=dict(alpha=transition_alpha, kappa=0) is not passed,
        # matching the MLE setting (transition_alpha = 1) of the setup cell.
        this_hmm = ssm.HMM(K,
                           D,
                           M_y,
                           observations="input_driven_obs_multinominal",
                           observation_kwargs=dict(C=C, prior_sigma=prior_sigma),
                           transitions="standard")

        if K == 1:
            # Single-state fit: the saved parameters go to the observation model only.
            this_hmm.observations.params = this_hmm_params
        elif K > 1:
            this_hmm.params = this_hmm_params

        predictive_acc_train = calculate_predictive_accuracy(this_train_inputs,
                                                             this_train_datas,
                                                             this_train_masks,
                                                             this_hmm,
                                                             train_y,
                                                             train_idx)

        predictive_acc_test = calculate_predictive_accuracy(this_test_inputs,
                                                            this_test_datas,
                                                            this_test_masks,
                                                            this_hmm,
                                                            test_y,
                                                            test_idx)

        cvpa_folds_model[:, :, model_idx, fold] = predictive_acc_test
        cvpa_train_folds_model[:, :, model_idx, fold] = predictive_acc_train

K = 1


KeyError: 'GLM_HMM_y_K_1/fold_0/alpha_1/sigma_100'

In [None]:
# Same value as in the setup cell; re-set here right before plotting.
global_fit = True

In [None]:
# Train vs. test predictive accuracy as a function of the number of states.
plot_model_comparison(
    cvpa_folds_model, cvpa_train_folds_model,
    global_fit, K_vals, figure_dir,
    'Pred. Acc. ', 'Pred_acc',
)

## Summary of each state

In [12]:
# ------- setup variables -------
# Per-state summary settings. Only the human dataset is listed; add
# 'dataAllMiceTraining' to dnames to summarise the mice as well.
dnames = ['dataAllHumans']
model = 'GLM_HMM_y'
regularization = 'L2'
K = 2  # number of hidden states of the global fit to summarise (e.g. 4)

C = 3  # number of output types/categories
D = 1  # data (observations) dimension
labels_for_plot_y = ['CSize', 'COnset', 'Outcome +1', 'Outcome +2', 'Outcome +3', 'Outcome +4', 'Outcome +5', 'bias']

save_figures = True
figure_dir = get_file_dir().parents[1] / 'figures'
figure_dir.mkdir(parents=True, exist_ok=True)

In [13]:
for dname in dnames:

    data_dir = get_file_dir().parents[1] / "data" / "dmdm" / dname / 'data_for_cluster'

    # Global GLM-HMM fit for this dataset: per-trial state assignment, session
    # ids, fit mask and model parameters (hmm_params[2] holds the GLM weights).
    states_max_posterior, _, inpt_rt, _, session, _, _, mask, hmm_params \
        = load_global_glmhmm_result(K, model, data_dir, regularization)

    # Unnormalised copy of the data, so stimulus values keep their original scale.
    inpt_y, _, y, _, rt, stim_onset = load_data(data_dir / 'all_animals_concat_unnormalized.npz')

    # One row per trial; outcome flags are derived from the response code y.
    y_flat = np.squeeze(y)
    data = {'session': session,
            'fitted_trials': np.squeeze(mask),
            'state': states_max_posterior,
            'early_report': y_flat == 2,
            'hit': y_flat == 1,
            'miss': y_flat == 0,
            'abort': y_flat == 3,
            'sig': inpt_y[:, 0],
            'rt_change': np.squeeze(rt) - np.squeeze(stim_onset),
            }
    df_all = pd.DataFrame(data)

    # Dwell times are computed for all states at once and filtered per state
    # below, so compute them once per dataset rather than once per state row.
    dwell_across_sessions = calc_dwell_time(df_all)

    fig = plt.figure(constrained_layout=True, figsize=(40/2.54, 4.25/2.54 * K))
    fig.suptitle('GLM_HMM_y: global {}'.format(dname))

    # K x 1 sub-figures, one row per state. For K == 1, subfigures() returns a
    # single SubFigure instead of an array, so wrap it to keep the loop valid.
    subfigs = np.atleast_1d(fig.subfigures(nrows=K, ncols=1))
    for zk, subfig in enumerate(subfigs):
        subfig.suptitle('State Zk = {}'.format(zk))

        # 1 x 7 panels per state. Only the first five are drawn; hide the rest
        # so the row does not show empty axes.
        axs = subfig.subplots(nrows=1, ncols=7)
        for unused_ax in axs[5:]:
            unused_ax.axis('off')

        # Psychometric curve
        plot_PC(df_all, axs[0], label='test data', K=zk)
        axs[0].set_xlabel('Change magnitude \n (octaves)')
        axs[0].set_ylabel('Proportion hits')
        axs[0].set_ylim(0, 1)
        axs[0].set_xlim(0, 2)
        axs[0].yaxis.set_major_locator(ticker.MultipleLocator(0.5))
        axs[0].xaxis.set_major_locator(ticker.MultipleLocator(1))

        # Chronometric curve
        plot_CC(df_all, axs[1], label='test data', K=zk)
        axs[1].set_xlabel('Change magnitude \n (octaves)')
        axs[1].set_ylabel('Reaction time (s)')
        axs[1].axis('tight')
        axs[1].set_xlim(0, 2)
        axs[1].set_ylim(0, 1.5)
        axs[1].yaxis.set_major_locator(ticker.MultipleLocator(0.5))
        axs[1].xaxis.set_major_locator(ticker.MultipleLocator(1))

        # Early-report (false alarm) rate
        plot_FArate(df_all, axs[2], label='test data', K=zk)
        axs[2].set_xlabel('Total')
        axs[2].set_ylabel('Early report rate')
        axs[2].axis('tight')
        axs[2].set_xlim(-1, 1)
        axs[2].set_ylim(0, 0.5)
        axs[2].xaxis.set_major_locator(ticker.MultipleLocator(1))

        # GLM weights of this state
        weight_vectors = hmm_params[2][zk]
        plot_state_Wk_diff(weight_vectors, axs[3])
        axs[3].set_ylabel('Weight')
        axs[3].legend(bbox_to_anchor=(1, 1.35),
                      ncol=3,
                      fontsize=5,
                      labelspacing=0.05,
                      framealpha=0,
                      markerscale=0)
        axs[3].axis('tight')
        axs[3].set_xticks(list(range(0, len(labels_for_plot_y))))
        axs[3].set_xticklabels(list(range(0, len(labels_for_plot_y))),
                               rotation=90)
        axs[3].set_xlim(-1, weight_vectors.shape[1])
        axs[3].set_ylim(-7, 7)

        # Dwell time of this state
        dwell_time_df = dwell_across_sessions[dwell_across_sessions['state'] == zk]
        plot_state_dwelltime(dwell_time_df, axs[4])
        axs[4].set_ylabel('# State changes')
        axs[4].set_xlabel("Dwell time \n (# trials)")
        axs[4].set_ylim(0, 25)
        axs[4].set_xlim(0, 80)
        axs[4].yaxis.set_major_locator(ticker.MultipleLocator(10))
        axs[4].xaxis.set_major_locator(ticker.MultipleLocator(20))

    sns.despine(fig, offset=3, trim=False)

    if save_figures:
        fig.savefig(figure_dir / 'fig_state_summary_global_{}_NumState_{}.pdf'.format(dname, str(K)))

    # Free the figure; it has already been written to disk if requested.
    plt.close(fig)
    

In [None]:
# TODO: consider setting states_max_posterior to NaN for aborted trials

In [None]:
# Leftover inspection of `master_*` arrays. None of these names is defined in
# this notebook, so evaluating them directly raised NameError on a fresh
# kernel. Show only those that actually exist in the namespace.
master_names = ['master_y_inpt', 'master_y', 'master_session', 'master_rt', 'master_stim_onset']
available_master_vars = {name: globals()[name] for name in master_names if name in globals()}
available_master_vars

In [None]:
# Each animal
# Train/test log-likelihood (bits/trial) of the GLM and the GLM-HMMs for each
# animal, shown relative to the GLM's minimum training value.
# NOTE(review): `animal_list` and `results_2_dir` are not defined anywhere in
# this notebook, so this cell raises NameError on a fresh kernel. Define them
# first (presumably via load_animal_list and the per-animal results directory).
for animal in animal_list:

    # Generate figure
    fig = plt.figure(figsize=(6, 6))
    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.95,
                        top=0.95,
                        wspace=0.45,
                        hspace=0.6)
    ax1 = fig.add_subplot()

    this_results_dir = results_2_dir / animal

    # Held-out folds; the GLM is placed at x = -1, left of the GLM-HMMs.
    cv_arr_GLM_y = load_cv_arr(this_results_dir / "cvbt_folds_model_GLM_y.npz")
    df_GLM_y, _, _ = create_cv_frame_for_plotting(cv_arr_GLM_y)
    df_GLM_y['model'] = -1
    cv_arr_GLM_HMM_y = load_cv_arr(this_results_dir / "cvbt_folds_model_GLM_HMM_y.npz")
    df_GLM_HMM_y, _, _ = create_cv_frame_for_plotting(cv_arr_GLM_HMM_y)
    df_test = pd.concat([df_GLM_y, df_GLM_HMM_y])
    df_test['label'] = 'test'

    # Same for the training folds.
    cv_train_arr_GLM_y = load_cv_arr(this_results_dir / "cvbt_train_folds_model_GLM_y.npz")
    df_train_GLM_y, _, _ = create_cv_frame_for_plotting(cv_train_arr_GLM_y)
    df_train_GLM_y['model'] = -1
    cv_train_arr_GLM_HMM_y = load_cv_arr(this_results_dir / "cvbt_train_folds_model_GLM_HMM_y.npz")
    df_train_GLM_HMM_y, _, _ = create_cv_frame_for_plotting(cv_train_arr_GLM_HMM_y)
    df_train = pd.concat([df_train_GLM_y, df_train_GLM_HMM_y])
    df_train['label'] = 'training'

    df_all = pd.concat([df_test, df_train])

    # Baseline: minimum training cv_bit_trial of the GLM (model == -1).
    # .loc makes the label lookup explicit ([-1] can be read as positional).
    y_min = df_train.groupby('model')['cv_bit_trial'].min().loc[-1]
    df_all['cv_bit_trial'] = df_all['cv_bit_trial'] - y_min

    sns.lineplot(data=df_all, x="model", y="cv_bit_trial", hue="label", ax=ax1)

    ax1.set_xticks([-1, 0, 1, 2, 3, 4])
    ax1.set_xticklabels(['GLM', '1', '2', '3', '4', '5'], fontsize=10)
    # Raw string: "\D" is an invalid escape sequence in a normal string literal.
    ax1.set_ylabel(r"$\Delta$ test LL (bits/trial)", fontsize=10, labelpad=0)
    ax1.set_xlabel("# states", fontsize=10, labelpad=0)
    ax1.spines['right'].set_visible(False)
    ax1.spines['top'].set_visible(False)
    ax1.set_ylim((-0.01, 0.24))
    leg = ax1.legend(fontsize=10,
                     labelspacing=0.05,
                     handlelength=1.4,
                     borderaxespad=0.05,
                     borderpad=0.05,
                     framealpha=0,
                     bbox_to_anchor=(1.2, 0.90),
                     loc='lower right',
                     markerscale=0)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles; support both.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for legobj in handles:
        legobj.set_linewidth(1.0)

    fig.suptitle('GLM_HMM_y LL: {}'.format(animal))

    if save_figures:
        fig.savefig(figure_dir / 'fig_nll_{}.pdf'.format(animal))

    plt.close(fig)