In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import numpy.random as npr
import pandas as pd
import matplotlib.pyplot as plt
import ssm
from sklearn import preprocessing
from sklearn.model_selection import KFold
from scipy import stats

from notebooks.imports import *
from config import dir_config, main_config
from src.utils.glm_hmm_utils import *
import pickle
import copy

### Configuration

In [3]:
from config import dir_config, main_config

raw_dir = Path(dir_config.data.raw)
processed_dir = Path(dir_config.data.processed)

metadata = pd.read_csv(Path(processed_dir, "processed_metadata_accu_60.csv"))
data = pd.read_csv(Path(processed_dir, "processed_data_accu_60_all.csv"))

In [4]:
experiment_sites = ["Stanford"]

metadata = metadata[metadata['experiment_site'].isin(experiment_sites)].reset_index(drop=True)
data = data[data['subject_id'].isin(metadata['subject_id'])].reset_index(drop=True)

# add session_id to data with matching subject_id and medication
metadata['session_id'] = metadata[['subject_id', 'treatment']].apply(lambda x: '_'.join(x.astype(str).str.upper()), axis=1)
data['session_id'] = data[['subject_id', 'medication']].apply(lambda x: '_'.join(x.astype(str).str.upper()), axis=1)
data.choice = data.choice.fillna(-1).astype(int)
data.target = data.target.fillna(-1).astype(int)
data.outcome = data.outcome.fillna(-1).astype(int)

### Helper Functions

In [5]:
def extract_previous_data(trial_data, invalid_idx):
    npr.seed(1)
    prev_choice = np.hstack([trial_data.choice[0], trial_data.choice[:-1]])  # 0:awayPrior, 1:toPrior of previous valid trial
    prev_target = np.hstack([trial_data.target[0], trial_data.target[:-1]]) * 2 - 1  # 0:awayPrior, 1:toPrior of previous valid trial
    prev_color = np.hstack([trial_data.color[0], trial_data.color[:-1]])  # 0:equalPrior, 1:UnequalPrior of previous valid trial

    # indices where the previous trial is invalid/valid
    prev_invalid_idx = np.array(invalid_idx) + 1
    if 0 in invalid_idx:
        prev_invalid_idx = np.append(0, prev_invalid_idx)
    prev_valid_idx = np.setdiff1d(np.arange(len(trial_data)), prev_invalid_idx)

    for i in prev_invalid_idx[prev_invalid_idx < len(trial_data)]:
        if i < prev_valid_idx[0]: #randomly sample if no previous valid trials
            prev_choice[i] = np.random.binomial(1,0.5)
            prev_target[i] = np.random.binomial(1,0.5) * 2 - 1
            prev_color[i]  = np.random.binomial(1,0.5)
        else:
            last_valid =  np.where(prev_valid_idx<i)[0][-1]
            prev_choice[i] = prev_choice[prev_valid_idx[last_valid]]
            prev_target[i] = prev_target[prev_valid_idx[last_valid]]
            prev_color[i] = prev_color[prev_valid_idx[last_valid]]

    prev_choice = (prev_choice * 2) - 1 # -1:awayPrior, 1:toPrior of previous valid trial
    return prev_choice.astype(int), prev_target.astype(int), prev_color.astype(int)

def prepare_input_data(data, input_dim, invalid_idx):
    X = np.ones((1, data.shape[0], input_dim))

    X[0,:,0] = data.signed_coherence / 100
    X[0,:,1] = data.color
    X[0,:,3], X[0,:,4], X[0,:,5] = extract_previous_data(data, invalid_idx)
    return list(X)

### Data processing

In [6]:
print("------------- info ----------------")
print(data.info())
print("------------- Head ----------------")
print(data.head())
print("\n\n------------- describe ----------------\n\n")
print(data.describe())
print("------------- nan counts ----------------")
print(data.isnull().sum())
print("\n\n------------- dtypes ----------------\n\n")
print(data.dtypes)
print("\n\n------------- shape ----------------\n\n")
print(data.shape)

------------- info ----------------
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20585 entries, 0 to 20584
Data columns (total 15 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   subject_id        20585 non-null  object 
 1   medication        20585 non-null  object 
 2   prior             20585 non-null  object 
 3   prior_direction   20585 non-null  object 
 4   prior_color       20585 non-null  object 
 5   color             20307 non-null  float64
 6   coherence         20307 non-null  float64
 7   target            20585 non-null  int64  
 8   is_valid          20585 non-null  bool   
 9   outcome           20585 non-null  int64  
 10  choice            20585 non-null  int64  
 11  reaction_time     20585 non-null  float64
 12  session_filename  20585 non-null  object 
 13  signed_coherence  20307 non-null  float64
 14  session_id        20585 non-null  object 
dtypes: bool(1), float64(4), int64(3), object(7)
memory 

#### Data preparation

In [7]:
off_med_sessions = metadata[metadata['treatment'] == 'OFF'].session_id.unique()
on_med_sessions = metadata[metadata['treatment'] == 'ON'].session_id.unique()
off_med_sessions, on_med_sessions

(array(['P1_OFF', 'P3_OFF', 'P4_OFF', 'P6_OFF', 'P7_OFF', 'P9_OFF',
        'P11_OFF', 'P12_OFF', 'P13_OFF', 'P17_OFF', 'P18_OFF', 'P19_OFF',
        'P20_OFF', 'P22_OFF', 'P23_OFF', 'P24_OFF'], dtype=object),
 array(['P1_ON', 'P3_ON', 'P4_ON', 'P6_ON', 'P7_ON', 'P9_ON', 'P11_ON',
        'P12_ON', 'P13_ON', 'P17_ON', 'P18_ON', 'P19_ON', 'P20_ON',
        'P22_ON', 'P23_ON', 'P24_ON'], dtype=object))

#### Create design matrix (input, output, mask)

In [8]:
n_states = 2       # number of discrete states
obs_dim = 1           # number of observed dimensions: choice(toPrior/awayPrior)
num_categories = 2    # number of categories for output
input_dim = 6        # input dimensions: current signed coherence, current stimulus color, 1(bias), previous choice(toPrior/awayPrior), previous target side(toPrior/awayPrior), previous color(toPrior/awayPrior)

### Off medication sessions

In [9]:
inputs_session_wise = []
choices_session_wise = []
invalid_idx_session_wise = []
masks_session_wise = []
reaction_time_session_wise = []

# off medication sessions
for session in off_med_sessions:
    session_data = data[data['session_id'] == session].reset_index(drop=True)


    invalid_idx = np.where(session_data.outcome < 0)[0]
    valid_idx = np.where(session_data.outcome >= 0)[0]

    inputs = prepare_input_data(session_data, input_dim, invalid_idx)
    choices = session_data.choice.values.reshape(-1,1).astype('int')

    # for training, replace -1 with random sample from 0,1
    choices[choices == -1] = npr.choice([0,1],invalid_idx.shape[0])
    mask = np.ones_like(choices, dtype=bool)
    mask[invalid_idx] = 0
    reaction_time = np.array(session_data.reaction_time)

    masks_session_wise.append(mask)
    inputs_session_wise += inputs
    choices_session_wise.append(choices)
    reaction_time_session_wise.append(reaction_time)

off_med_inputs_aggregate, off_med_choices_aggregate, off_med_masks_aggregate = [], [], []
off_med_inputs_aggregate.append(np.vstack(inputs_session_wise))
off_med_choices_aggregate.append(np.vstack(choices_session_wise))
off_med_masks_aggregate.append(np.vstack(masks_session_wise))

unnormalized_off_med_inputs = copy.deepcopy(inputs_session_wise)

# scaling signed coherence
off_med_inputs_aggregate[0][off_med_masks_aggregate[0][:,0],0] = preprocessing.scale(off_med_inputs_aggregate[0][off_med_masks_aggregate[0][:,0],0], axis=0)
for idx in range(len(off_med_sessions)):
    inputs_session_wise[idx][masks_session_wise[idx][:,0]] = preprocessing.scale(inputs_session_wise[idx][masks_session_wise[idx][:,0]], axis=0)


In [None]:
models_glm_hmm_off_med, fit_lls_glm_hmm_off_med = global_fit(off_med_choices_aggregate, off_med_inputs_aggregate, masks=off_med_masks_aggregate, n_iters= 1000, n_initializations=20)

Fitting GLM globally...


  0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 2 states...


In [None]:
# get best model of 20 initializations for each state
init_params = {
    'glm_weights': {},
    'transition_matrices': {}
}
for n_states in np.arange(2,6):
    best_idx = fit_lls_glm_hmm_off_med[n_states].index(max(fit_lls_glm_hmm_off_med[n_states]))
    init_params['glm_weights'][n_states] = models_glm_hmm_off_med[n_states][best_idx].observations.params
    init_params['transition_matrices'][n_states] = models_glm_hmm_off_med[n_states][best_idx].transitions.params

In [None]:
# session-wise fitting with 5 fold cross-validation
models_session_state_fold_off_med, train_ll_session_off_med, test_ll_session_off_med = session_wise_fit_cv(choices_session_wise, inputs_session_wise, masks=masks_session_wise,
                                                                                    n_sessions=len(off_med_sessions), init_params=init_params, n_iters= 1000)

Fitting session 0...
Fitting 2 states...


Converged to LP: -318.9:   2%|▏         | 20/1000 [00:00<00:05, 185.92it/s]
Converged to LP: -319.2:   2%|▎         | 25/1000 [00:00<00:05, 168.07it/s]
Converged to LP: -310.4:   2%|▎         | 25/1000 [00:00<00:05, 187.86it/s]
Converged to LP: -315.1:   4%|▍         | 41/1000 [00:00<00:03, 240.75it/s]
Converged to LP: -316.8:   9%|▉         | 88/1000 [00:00<00:04, 212.52it/s]


Fitting 3 states...


Converged to LP: -356.6:   2%|▏         | 24/1000 [00:00<00:06, 142.74it/s]
Converged to LP: -361.6:   4%|▎         | 36/1000 [00:00<00:05, 171.42it/s]
Converged to LP: -353.5:   7%|▋         | 74/1000 [00:00<00:04, 223.96it/s]
Converged to LP: -346.6:  11%|█         | 110/1000 [00:00<00:04, 200.32it/s]
Converged to LP: -353.5:  11%|█         | 106/1000 [00:00<00:04, 188.61it/s]


Fitting 4 states...


Converged to LP: -397.2:   8%|▊         | 75/1000 [00:00<00:05, 178.03it/s]
Converged to LP: -396.7:   8%|▊         | 80/1000 [00:00<00:05, 159.70it/s]
Converged to LP: -397.8:   9%|▉         | 94/1000 [00:00<00:04, 182.88it/s]
Converged to LP: -398.5:  16%|█▋        | 163/1000 [00:00<00:04, 202.11it/s]
Converged to LP: -388.2:  16%|█▌        | 157/1000 [00:00<00:04, 180.34it/s]


Fitting 5 states...


Converged to LP: -458.8:   4%|▍         | 38/1000 [00:00<00:06, 137.44it/s]
Converged to LP: -437.0:  12%|█▏        | 122/1000 [00:00<00:05, 158.63it/s]
Converged to LP: -455.1:  17%|█▋        | 170/1000 [00:00<00:04, 176.91it/s]
Converged to LP: -439.5:  19%|█▉        | 192/1000 [00:01<00:05, 160.50it/s]
Converged to LP: -429.2:  33%|███▎      | 328/1000 [00:01<00:03, 186.56it/s]


Fitting session 1...
Fitting 2 states...


Converged to LP: -301.8:   1%|          | 6/1000 [00:00<00:11, 84.22it/s]
Converged to LP: -308.5:   1%|          | 9/1000 [00:00<00:09, 104.00it/s]
Converged to LP: -314.0:   2%|▏         | 22/1000 [00:00<00:05, 177.74it/s]
Converged to LP: -316.7:   3%|▎         | 32/1000 [00:00<00:05, 192.69it/s]
Converged to LP: -314.0:   1%|          | 8/1000 [00:00<00:09, 109.60it/s]


Fitting 3 states...


Converged to LP: -357.5:   1%|▏         | 13/1000 [00:00<00:09, 109.50it/s]
Converged to LP: -343.4:   1%|          | 10/1000 [00:00<00:09, 101.22it/s]
Converged to LP: -351.3:   1%|          | 11/1000 [00:00<00:08, 114.64it/s]
Converged to LP: -348.3:   2%|▏         | 16/1000 [00:00<00:07, 132.71it/s]
Converged to LP: -354.8:   7%|▋         | 73/1000 [00:00<00:05, 169.09it/s]
LP: -396.3:   2%|▏         | 15/1000 [00:00<00:06, 148.02it/s]

Fitting 4 states...


Converged to LP: -392.3:   4%|▍         | 38/1000 [00:00<00:06, 141.63it/s]
Converged to LP: -393.3:  15%|█▌        | 151/1000 [00:00<00:04, 192.94it/s]
Converged to LP: -387.6:  15%|█▌        | 151/1000 [00:00<00:04, 189.07it/s]
Converged to LP: -385.0:   2%|▏         | 24/1000 [00:00<00:07, 132.15it/s]
Converged to LP: -399.0:   5%|▌         | 50/1000 [00:00<00:05, 165.88it/s]
LP: -435.4:   1%|          | 7/1000 [00:00<00:15, 65.26it/s]s]

Fitting 5 states...


Converged to LP: -431.8:   6%|▌         | 59/1000 [00:00<00:06, 149.07it/s]
Converged to LP: -442.2:   8%|▊         | 80/1000 [00:00<00:06, 144.20it/s]
Converged to LP: -437.1:   9%|▉         | 92/1000 [00:00<00:05, 160.65it/s]
Converged to LP: -416.5:  13%|█▎        | 130/1000 [00:00<00:04, 196.37it/s]
Converged to LP: -432.6:  15%|█▍        | 149/1000 [00:00<00:04, 177.34it/s]
Converged to LP: -332.9:   1%|          | 12/1000 [00:00<00:06, 153.57it/s]
Converged to LP: -343.9:   2%|▏         | 20/1000 [00:00<00:04, 218.84it/s]
LP: -356.9:   0%|          | 0/1000 [00:00<?, ?it/s]71.80it/s]

Fitting session 2...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -350.1:   2%|▏         | 18/1000 [00:00<00:05, 178.39it/s]
Converged to LP: -355.6:   3%|▎         | 27/1000 [00:00<00:04, 212.21it/s]
Converged to LP: -353.5:   4%|▎         | 37/1000 [00:00<00:04, 205.12it/s]
Converged to LP: -370.7:   2%|▏         | 21/1000 [00:00<00:05, 191.06it/s]
Converged to LP: -402.0:   2%|▎         | 25/1000 [00:00<00:04, 195.32it/s]
Converged to LP: -397.8:   3%|▎         | 31/1000 [00:00<00:04, 217.04it/s]
Converged to LP: -384.8:   6%|▋         | 65/1000 [00:00<00:04, 213.59it/s]
Converged to LP: -378.1:  12%|█▏        | 118/1000 [00:00<00:03, 227.71it/s]


Fitting 4 states...


Converged to LP: -417.9:   5%|▌         | 50/1000 [00:00<00:05, 186.26it/s]
Converged to LP: -436.0:   7%|▋         | 68/1000 [00:00<00:05, 169.47it/s]
Converged to LP: -440.9:   8%|▊         | 81/1000 [00:00<00:05, 181.41it/s]
Converged to LP: -431.3:  10%|▉         | 98/1000 [00:00<00:04, 183.00it/s]
Converged to LP: -417.0:   8%|▊         | 82/1000 [00:00<00:06, 132.61it/s]
LP: -485.2:   1%|          | 11/1000 [00:00<00:09, 105.49it/s]

Fitting 5 states...


Converged to LP: -475.8:   8%|▊         | 84/1000 [00:00<00:06, 145.91it/s]
Converged to LP: -459.3:   8%|▊         | 82/1000 [00:00<00:07, 125.34it/s]
Converged to LP: -460.7:  13%|█▎        | 128/1000 [00:01<00:06, 127.55it/s]
Converged to LP: -483.2:  20%|█▉        | 195/1000 [00:01<00:04, 173.41it/s]
Converged to LP: -470.8:  30%|██▉       | 296/1000 [00:01<00:04, 172.41it/s]
LP: -293.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 3...
Fitting 2 states...


Converged to LP: -297.3:   4%|▍         | 38/1000 [00:00<00:03, 254.22it/s]
Converged to LP: -275.2:   4%|▎         | 35/1000 [00:00<00:04, 235.53it/s]
Converged to LP: -287.6:   9%|▉         | 90/1000 [00:00<00:02, 314.98it/s]
Converged to LP: -294.5:  12%|█▏        | 117/1000 [00:00<00:02, 333.99it/s]
Converged to LP: -286.3:  12%|█▎        | 125/1000 [00:00<00:03, 230.98it/s]
LP: -338.7:   1%|▏         | 13/1000 [00:00<00:07, 125.76it/s]

Fitting 3 states...


Converged to LP: -335.3:   3%|▎         | 29/1000 [00:00<00:05, 190.22it/s]
Converged to LP: -335.8:   4%|▍         | 45/1000 [00:00<00:06, 149.88it/s]
Converged to LP: -330.8:   6%|▌         | 62/1000 [00:00<00:05, 173.93it/s]
Converged to LP: -333.4:  10%|█         | 105/1000 [00:00<00:04, 218.42it/s]
Converged to LP: -310.7:  24%|██▍       | 239/1000 [00:01<00:03, 222.85it/s]
LP: -392.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 4 states...


Converged to LP: -374.8:   5%|▍         | 46/1000 [00:00<00:06, 139.15it/s]
Converged to LP: -375.1:   8%|▊         | 79/1000 [00:00<00:04, 185.45it/s]
Converged to LP: -374.5:  28%|██▊       | 280/1000 [00:01<00:03, 232.14it/s]
Converged to LP: -362.3:  29%|██▊       | 287/1000 [00:01<00:03, 212.81it/s]
Converged to LP: -365.1:  61%|██████    | 606/1000 [00:02<00:01, 232.93it/s]
LP: -430.2:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -412.1:   6%|▌         | 58/1000 [00:00<00:06, 155.74it/s]
Converged to LP: -410.3:  13%|█▎        | 132/1000 [00:00<00:05, 161.34it/s]
Converged to LP: -413.7:  15%|█▌        | 153/1000 [00:00<00:04, 179.07it/s]
Converged to LP: -406.5:  20%|██        | 201/1000 [00:01<00:05, 153.07it/s]
Converged to LP: -400.8:  26%|██▋       | 263/1000 [00:01<00:04, 180.85it/s]
LP: -409.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 4...
Fitting 2 states...


Converged to LP: -378.7:   1%|▏         | 13/1000 [00:00<00:03, 257.81it/s]
Converged to LP: -382.0:   2%|▏         | 18/1000 [00:00<00:03, 303.93it/s]
Converged to LP: -376.4:   5%|▌         | 50/1000 [00:00<00:03, 241.27it/s]
Converged to LP: -384.4:   7%|▋         | 68/1000 [00:00<00:03, 253.07it/s]
Converged to LP: -388.6:   9%|▉         | 90/1000 [00:00<00:03, 301.27it/s]
Converged to LP: -416.8:   1%|▏         | 13/1000 [00:00<00:05, 185.30it/s]
LP: -429.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...
Fitting 4 states...


Converged to LP: -421.3:   2%|▏         | 16/1000 [00:00<00:06, 153.59it/s]
Converged to LP: -426.7:   2%|▎         | 25/1000 [00:00<00:05, 183.62it/s]
Converged to LP: -421.6:   4%|▎         | 37/1000 [00:00<00:04, 229.17it/s]
Converged to LP: -428.7:   4%|▍         | 39/1000 [00:00<00:04, 209.02it/s]
Converged to LP: -460.6:   4%|▍         | 43/1000 [00:00<00:05, 185.32it/s]
Converged to LP: -455.8:   4%|▍         | 44/1000 [00:00<00:05, 173.80it/s]
Converged to LP: -463.3:  13%|█▎        | 133/1000 [00:00<00:04, 185.93it/s]
Converged to LP: -453.1:  21%|██        | 212/1000 [00:00<00:03, 220.84it/s]
Converged to LP: -469.9:  21%|██        | 210/1000 [00:00<00:03, 212.43it/s]
LP: -517.5:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -499.7:   5%|▌         | 52/1000 [00:00<00:07, 133.86it/s]
Converged to LP: -494.2:   7%|▋         | 68/1000 [00:00<00:06, 155.15it/s]
Converged to LP: -491.9:  10%|▉         | 97/1000 [00:00<00:05, 170.48it/s]
Converged to LP: -496.5:   9%|▉         | 89/1000 [00:00<00:05, 153.90it/s]
Converged to LP: -508.3:  15%|█▌        | 150/1000 [00:00<00:05, 160.25it/s]
Converged to LP: -392.2:   1%|          | 7/1000 [00:00<00:05, 185.83it/s]
Converged to LP: -387.3:   1%|          | 11/1000 [00:00<00:04, 233.71it/s]
Converged to LP: -392.8:   1%|          | 11/1000 [00:00<00:04, 205.18it/s]
Converged to LP: -398.1:   1%|          | 11/1000 [00:00<00:04, 205.80it/s]
Converged to LP: -383.3:   2%|▏         | 17/1000 [00:00<00:04, 230.85it/s]
Converged to LP: -418.3:   1%|          | 8/1000 [00:00<00:06, 162.22it/s]
Converged to LP: -429.6:   1%|          | 9/1000 [00:00<00:05, 167.24it/s]
Converged to LP: -430.2:   1%|          | 10/1000 [00:00<00:05, 183.06it/s]
Converged to L

Fitting session 5...
Fitting 2 states...
Fitting 3 states...
Fitting 4 states...


Converged to LP: -423.9:   1%|          | 12/1000 [00:00<00:06, 157.54it/s]
Converged to LP: -434.7:   4%|▍         | 44/1000 [00:00<00:05, 167.87it/s]
Converged to LP: -430.9:   4%|▍         | 39/1000 [00:00<00:06, 143.53it/s]
Converged to LP: -441.8:   6%|▌         | 59/1000 [00:00<00:04, 195.13it/s]
Converged to LP: -437.1:   7%|▋         | 66/1000 [00:00<00:04, 192.25it/s]
Converged to LP: -450.2:   8%|▊         | 75/1000 [00:00<00:05, 178.09it/s]
LP: -498.0:   1%|          | 9/1000 [00:00<00:12, 80.88it/s]s]

Fitting 5 states...


Converged to LP: -492.0:   4%|▍         | 41/1000 [00:00<00:06, 143.66it/s]
Converged to LP: -498.4:   5%|▌         | 52/1000 [00:00<00:05, 161.64it/s]
Converged to LP: -483.2:   5%|▍         | 47/1000 [00:00<00:06, 140.45it/s]
Converged to LP: -511.3:   5%|▌         | 51/1000 [00:00<00:07, 125.99it/s]
Converged to LP: -490.5:   7%|▋         | 68/1000 [00:00<00:07, 126.63it/s]
LP: -277.2:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 6...
Fitting 2 states...


Converged to LP: -256.5:   1%|          | 11/1000 [00:00<00:03, 273.62it/s]
Converged to LP: -269.5:   1%|          | 11/1000 [00:00<00:03, 262.28it/s]
Converged to LP: -270.3:   2%|▎         | 25/1000 [00:00<00:04, 243.65it/s]
Converged to LP: -270.0:   6%|▋         | 65/1000 [00:00<00:03, 262.44it/s]
Converged to LP: -266.6:  13%|█▎        | 133/1000 [00:00<00:02, 319.80it/s]
Converged to LP: -314.6:   2%|▏         | 19/1000 [00:00<00:05, 192.12it/s]
LP: -313.8:   1%|          | 12/1000 [00:00<00:08, 114.72it/s]

Fitting 3 states...


Converged to LP: -297.4:   4%|▍         | 42/1000 [00:00<00:04, 213.48it/s]
Converged to LP: -310.1:   4%|▎         | 36/1000 [00:00<00:06, 138.15it/s]
Converged to LP: -312.5:   4%|▍         | 44/1000 [00:00<00:06, 151.67it/s]
Converged to LP: -296.9:  19%|█▊        | 186/1000 [00:00<00:03, 225.25it/s]
LP: -357.7:   1%|          | 10/1000 [00:00<00:10, 98.22it/s]]

Fitting 4 states...


Converged to LP: -350.4:   6%|▌         | 57/1000 [00:00<00:05, 172.57it/s]
Converged to LP: -356.0:   7%|▋         | 71/1000 [00:00<00:05, 156.29it/s]
Converged to LP: -341.1:   7%|▋         | 68/1000 [00:00<00:06, 142.54it/s]
Converged to LP: -350.9:  17%|█▋        | 171/1000 [00:00<00:04, 189.62it/s]
Converged to LP: -354.7:  24%|██▍       | 244/1000 [00:01<00:03, 202.66it/s]
LP: -399.0:   1%|▏         | 14/1000 [00:00<00:07, 134.80it/s]

Fitting 5 states...


Converged to LP: -400.2:   5%|▌         | 51/1000 [00:00<00:05, 164.93it/s]
Converged to LP: -402.7:   7%|▋         | 66/1000 [00:00<00:05, 184.83it/s]
Converged to LP: -403.8:   5%|▌         | 54/1000 [00:00<00:06, 136.23it/s]
Converged to LP: -385.0:  23%|██▎       | 229/1000 [00:01<00:04, 183.73it/s]
Converged to LP: -385.5:  28%|██▊       | 284/1000 [00:01<00:04, 156.62it/s]
Converged to LP: -284.4:   2%|▏         | 17/1000 [00:00<00:03, 286.24it/s]
Converged to LP: -289.2:   2%|▎         | 25/1000 [00:00<00:03, 268.31it/s]
Converged to LP: -282.7:   2%|▏         | 17/1000 [00:00<00:05, 177.44it/s]
Converged to LP: -276.0:   3%|▎         | 31/1000 [00:00<00:04, 223.29it/s]
LP: -284.2:   2%|▏         | 17/1000 [00:00<00:05, 168.39it/s]

Fitting session 7...
Fitting 2 states...


Converged to LP: -281.6:  10%|▉         | 95/1000 [00:00<00:03, 271.58it/s]


Fitting 3 states...


Converged to LP: -328.8:   2%|▏         | 22/1000 [00:00<00:07, 139.30it/s]
Converged to LP: -326.7:   4%|▍         | 42/1000 [00:00<00:04, 232.65it/s]
Converged to LP: -326.5:   4%|▍         | 43/1000 [00:00<00:04, 233.46it/s]
Converged to LP: -322.3:   4%|▍         | 39/1000 [00:00<00:05, 187.75it/s]
Converged to LP: -332.3:   5%|▍         | 49/1000 [00:00<00:04, 226.94it/s]
LP: -376.6:   1%|▏         | 14/1000 [00:00<00:07, 139.85it/s]

Fitting 4 states...


Converged to LP: -364.9:   9%|▉         | 94/1000 [00:00<00:04, 181.35it/s]
Converged to LP: -360.6:  20%|█▉        | 196/1000 [00:01<00:04, 171.56it/s]
Converged to LP: -363.2:  23%|██▎       | 230/1000 [00:01<00:03, 199.11it/s]
Converged to LP: -351.5:  34%|███▍      | 342/1000 [00:01<00:03, 194.56it/s]
Converged to LP: -369.1:  62%|██████▏   | 622/1000 [00:02<00:01, 212.03it/s]
LP: -431.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -408.8:   8%|▊         | 83/1000 [00:00<00:05, 154.38it/s]
Converged to LP: -418.8:  12%|█▏        | 120/1000 [00:00<00:06, 130.86it/s]
Converged to LP: -408.4:  27%|██▋       | 267/1000 [00:01<00:03, 192.59it/s]
Converged to LP: -417.9:  36%|███▋      | 365/1000 [00:01<00:03, 191.55it/s]
Converged to LP: -403.0:  36%|███▋      | 365/1000 [00:02<00:03, 159.31it/s]
Converged to LP: -373.1:   1%|          | 10/1000 [00:00<00:03, 273.24it/s]
Converged to LP: -360.0:   1%|          | 10/1000 [00:00<00:03, 247.88it/s]
Converged to LP: -363.9:   1%|          | 12/1000 [00:00<00:03, 266.08it/s]
Converged to LP: -373.7:   1%|▏         | 13/1000 [00:00<00:03, 260.19it/s]
Converged to LP: -367.7:   2%|▏         | 19/1000 [00:00<00:05, 170.50it/s]
LP: -418.1:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 8...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -418.1:   2%|▏         | 15/1000 [00:00<00:04, 218.80it/s]
Converged to LP: -403.7:   1%|          | 12/1000 [00:00<00:05, 189.79it/s]
Converged to LP: -416.0:   2%|▏         | 19/1000 [00:00<00:04, 230.96it/s]
Converged to LP: -409.4:   2%|▏         | 15/1000 [00:00<00:06, 143.23it/s]
Converged to LP: -407.7:   6%|▌         | 57/1000 [00:00<00:04, 219.63it/s]


Fitting 4 states...


Converged to LP: -447.4:   5%|▍         | 47/1000 [00:00<00:05, 177.08it/s]
Converged to LP: -451.3:   8%|▊         | 85/1000 [00:00<00:06, 144.01it/s]
Converged to LP: -454.9:  11%|█         | 112/1000 [00:00<00:05, 172.10it/s]
Converged to LP: -443.4:  13%|█▎        | 134/1000 [00:00<00:04, 187.44it/s]
Converged to LP: -447.7:  25%|██▌       | 251/1000 [00:01<00:03, 228.92it/s]
LP: -502.9:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -480.3:  10%|▉         | 95/1000 [00:00<00:06, 144.96it/s]
Converged to LP: -481.8:  11%|█         | 112/1000 [00:00<00:05, 156.45it/s]
Converged to LP: -494.1:  12%|█▏        | 122/1000 [00:00<00:06, 144.88it/s]
Converged to LP: -499.1:  13%|█▎        | 129/1000 [00:00<00:05, 149.97it/s]
Converged to LP: -478.8:  21%|██        | 210/1000 [00:01<00:05, 154.56it/s]
Converged to LP: -346.5:   1%|          | 6/1000 [00:00<00:03, 264.47it/s]
Converged to LP: -348.6:   1%|          | 7/1000 [00:00<00:03, 286.68it/s]
Converged to LP: -346.4:   1%|          | 8/1000 [00:00<00:03, 269.52it/s]
Converged to LP: -338.2:   1%|          | 6/1000 [00:00<00:05, 173.64it/s]
Converged to LP: -339.2:   1%|          | 11/1000 [00:00<00:04, 211.68it/s]
LP: -392.5:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 9...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -391.4:   2%|▏         | 21/1000 [00:00<00:03, 249.77it/s]
Converged to LP: -382.0:   2%|▏         | 18/1000 [00:00<00:04, 226.30it/s]
Converged to LP: -392.5:   3%|▎         | 30/1000 [00:00<00:03, 253.96it/s]
Converged to LP: -381.7:   4%|▍         | 39/1000 [00:00<00:04, 229.47it/s]
Converged to LP: -392.0:   6%|▋         | 65/1000 [00:00<00:04, 233.58it/s]


Fitting 4 states...


Converged to LP: -436.2:   3%|▎         | 33/1000 [00:00<00:05, 189.28it/s]
Converged to LP: -423.9:   4%|▍         | 42/1000 [00:00<00:06, 149.06it/s]
Converged to LP: -419.5:   8%|▊         | 81/1000 [00:00<00:04, 205.78it/s]
Converged to LP: -402.2:  15%|█▌        | 151/1000 [00:00<00:04, 197.51it/s]
Converged to LP: -431.2:  15%|█▍        | 146/1000 [00:00<00:04, 189.90it/s]
LP: -486.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -454.2:   4%|▎         | 37/1000 [00:00<00:09, 96.32it/s]
Converged to LP: -475.4:   8%|▊         | 83/1000 [00:00<00:05, 175.39it/s]
Converged to LP: -460.6:  10%|█         | 105/1000 [00:00<00:05, 165.25it/s]
Converged to LP: -459.7:   9%|▉         | 90/1000 [00:00<00:06, 135.34it/s]
Converged to LP: -461.5:  12%|█▏        | 117/1000 [00:00<00:05, 163.43it/s]
Converged to LP: -337.3:   1%|▏         | 14/1000 [00:00<00:03, 247.28it/s]
LP: -330.4:   3%|▎         | 27/1000 [00:00<00:03, 263.91it/s]

Fitting session 10...
Fitting 2 states...


Converged to LP: -330.4:   4%|▍         | 40/1000 [00:00<00:03, 290.45it/s]
Converged to LP: -335.8:  10%|▉         | 95/1000 [00:00<00:02, 307.29it/s]
Converged to LP: -336.7:  11%|█         | 108/1000 [00:00<00:03, 255.25it/s]
Converged to LP: -346.6:  30%|███       | 300/1000 [00:01<00:02, 285.56it/s]
LP: -380.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...


Converged to LP: -385.8:   8%|▊         | 77/1000 [00:00<00:05, 184.35it/s]
Converged to LP: -373.5:  12%|█▎        | 125/1000 [00:00<00:04, 218.31it/s]
Converged to LP: -377.8:  12%|█▏        | 121/1000 [00:00<00:04, 191.05it/s]
Converged to LP: -376.2:  15%|█▍        | 149/1000 [00:00<00:03, 228.31it/s]
Converged to LP: -375.7:  16%|█▋        | 164/1000 [00:00<00:03, 247.07it/s]


Fitting 4 states...


Converged to LP: -413.9:   7%|▋         | 69/1000 [00:00<00:05, 178.64it/s]
Converged to LP: -439.2:   9%|▉         | 91/1000 [00:00<00:04, 199.06it/s]
Converged to LP: -418.3:   9%|▉         | 93/1000 [00:00<00:05, 172.73it/s]
Converged to LP: -419.3:  19%|█▉        | 192/1000 [00:00<00:03, 207.94it/s]
Converged to LP: -413.3:  32%|███▏      | 319/1000 [00:01<00:03, 199.15it/s]
LP: -486.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -451.5:  11%|█▏        | 113/1000 [00:00<00:05, 165.88it/s]
Converged to LP: -449.9:  14%|█▍        | 140/1000 [00:00<00:05, 153.69it/s]
Converged to LP: -471.6:  19%|█▉        | 193/1000 [00:01<00:04, 175.57it/s]
Converged to LP: -471.9:  25%|██▍       | 247/1000 [00:01<00:04, 177.51it/s]
Converged to LP: -453.4:  30%|██▉       | 299/1000 [00:01<00:03, 175.49it/s]
Converged to LP: -387.5:   1%|          | 12/1000 [00:00<00:05, 182.25it/s]
LP: -374.7:   2%|▏         | 24/1000 [00:00<00:04, 235.42it/s]

Fitting session 11...
Fitting 2 states...


Converged to LP: -374.7:   3%|▎         | 26/1000 [00:00<00:04, 237.90it/s]
Converged to LP: -377.5:   3%|▎         | 33/1000 [00:00<00:03, 275.19it/s]
Converged to LP: -372.8:   4%|▎         | 37/1000 [00:00<00:03, 273.81it/s]
Converged to LP: -376.7:  10%|▉         | 97/1000 [00:00<00:03, 286.06it/s]
LP: -419.9:   2%|▏         | 20/1000 [00:00<00:05, 192.09it/s], 207.87it/s]
Converged to LP: -419.8:   3%|▎         | 31/1000 [00:00<00:04, 208.15it/s]
Converged to LP: -411.8:   2%|▎         | 25/1000 [00:00<00:05, 163.33it/s]
LP: -407.0:   2%|▏         | 19/1000 [00:00<00:05, 183.10it/s]

Fitting 3 states...


Converged to LP: -404.2:   6%|▋         | 64/1000 [00:00<00:04, 207.26it/s]
Converged to LP: -431.1:  10%|▉         | 97/1000 [00:00<00:03, 235.36it/s]
LP: -464.5:   1%|▏         | 13/1000 [00:00<00:07, 127.15it/s]

Fitting 4 states...


Converged to LP: -464.9:   7%|▋         | 72/1000 [00:00<00:05, 177.51it/s]
Converged to LP: -437.7:  10%|▉         | 97/1000 [00:00<00:05, 176.66it/s]
Converged to LP: -449.5:   9%|▉         | 94/1000 [00:00<00:05, 151.06it/s]
Converged to LP: -451.5:  13%|█▎        | 127/1000 [00:00<00:05, 145.79it/s]
Converged to LP: -459.6:  18%|█▊        | 183/1000 [00:01<00:04, 178.24it/s]
LP: -504.2:   0%|          | 0/1000 [00:00<?, ?it/s]5.36it/s]]

Fitting 5 states...


Converged to LP: -503.3:  10%|█         | 102/1000 [00:00<00:05, 151.53it/s]
Converged to LP: -490.0:  12%|█▏        | 122/1000 [00:00<00:05, 156.86it/s]
Converged to LP: -492.2:  13%|█▎        | 131/1000 [00:00<00:05, 153.17it/s]
Converged to LP: -502.5:  17%|█▋        | 166/1000 [00:01<00:05, 163.76it/s]
Converged to LP: -497.3:  27%|██▋       | 272/1000 [00:01<00:04, 173.78it/s]
Converged to LP: -354.0:   1%|▏         | 14/1000 [00:00<00:04, 213.30it/s]
Converged to LP: -358.6:   3%|▎         | 26/1000 [00:00<00:03, 261.83it/s]
Converged to LP: -359.1:   2%|▏         | 18/1000 [00:00<00:05, 169.97it/s]
Converged to LP: -357.8:   3%|▎         | 31/1000 [00:00<00:04, 210.42it/s]
Converged to LP: -368.4:   4%|▍         | 39/1000 [00:00<00:03, 257.76it/s]


Fitting session 12...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -404.8:   2%|▏         | 20/1000 [00:00<00:04, 220.59it/s]
Converged to LP: -406.1:   2%|▎         | 25/1000 [00:00<00:04, 242.38it/s]
Converged to LP: -400.1:   2%|▏         | 24/1000 [00:00<00:04, 196.82it/s]
Converged to LP: -409.1:   4%|▍         | 38/1000 [00:00<00:05, 178.15it/s]
Converged to LP: -388.8:   8%|▊         | 79/1000 [00:00<00:04, 185.33it/s]
LP: -442.8:   2%|▏         | 15/1000 [00:00<00:06, 141.47it/s]

Fitting 4 states...


Converged to LP: -443.4:   8%|▊         | 79/1000 [00:00<00:04, 198.67it/s]
Converged to LP: -422.3:   7%|▋         | 67/1000 [00:00<00:06, 146.20it/s]
Converged to LP: -429.3:   9%|▉         | 93/1000 [00:00<00:05, 172.19it/s]
Converged to LP: -434.1:  11%|█         | 108/1000 [00:00<00:04, 196.02it/s]
Converged to LP: -432.9:  22%|██▏       | 224/1000 [00:01<00:03, 207.52it/s]
LP: -496.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -474.7:   6%|▌         | 61/1000 [00:00<00:07, 131.24it/s]
Converged to LP: -482.5:   5%|▌         | 50/1000 [00:00<00:09, 104.99it/s]
Converged to LP: -476.6:   6%|▋         | 64/1000 [00:00<00:08, 108.87it/s]
Converged to LP: -477.8:   9%|▉         | 92/1000 [00:00<00:06, 132.06it/s]
Converged to LP: -477.3:  20%|█▉        | 197/1000 [00:01<00:04, 162.96it/s]
Converged to LP: -381.2:   2%|▏         | 15/1000 [00:00<00:03, 261.61it/s]
Converged to LP: -378.1:   1%|▏         | 14/1000 [00:00<00:04, 200.94it/s]
Converged to LP: -377.5:   2%|▏         | 15/1000 [00:00<00:05, 165.68it/s]
Converged to LP: -373.0:   2%|▏         | 17/1000 [00:00<00:05, 183.14it/s]


Fitting session 13...
Fitting 2 states...


Converged to LP: -368.6:  20%|█▉        | 198/1000 [00:00<00:02, 292.52it/s]
LP: -448.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...


Converged to LP: -426.0:   2%|▏         | 18/1000 [00:00<00:06, 153.58it/s]
Converged to LP: -411.8:   2%|▎         | 25/1000 [00:00<00:05, 163.07it/s]
Converged to LP: -414.3:  10%|▉         | 95/1000 [00:00<00:04, 223.43it/s]
Converged to LP: -412.1:  20%|██        | 200/1000 [00:00<00:03, 219.32it/s]
Converged to LP: -412.1:  24%|██▍       | 241/1000 [00:01<00:03, 217.58it/s]
LP: -452.5:   1%|▏         | 14/1000 [00:00<00:07, 137.25it/s]

Fitting 4 states...


Converged to LP: -440.9:   7%|▋         | 67/1000 [00:00<00:05, 158.81it/s]
Converged to LP: -455.8:   8%|▊         | 80/1000 [00:00<00:07, 127.89it/s]
Converged to LP: -446.1:  17%|█▋        | 168/1000 [00:00<00:03, 209.99it/s]
Converged to LP: -443.8:  16%|█▌        | 156/1000 [00:00<00:04, 171.09it/s]
Converged to LP: -451.9:  21%|██        | 211/1000 [00:01<00:04, 179.01it/s]
LP: -515.2:   1%|          | 8/1000 [00:00<00:13, 73.04it/s]s]

Fitting 5 states...


Converged to LP: -500.0:   6%|▌         | 57/1000 [00:00<00:06, 157.11it/s]
Converged to LP: -502.4:   9%|▊         | 87/1000 [00:00<00:08, 113.39it/s]
Converged to LP: -476.3:  13%|█▎        | 130/1000 [00:00<00:06, 135.34it/s]
Converged to LP: -492.8:  24%|██▍       | 241/1000 [00:01<00:04, 155.85it/s]
Converged to LP: -497.6:  26%|██▋       | 265/1000 [00:01<00:04, 167.05it/s]
Converged to LP: -343.4:   2%|▏         | 17/1000 [00:00<00:03, 265.59it/s]
Converged to LP: -341.9:   2%|▏         | 23/1000 [00:00<00:03, 285.98it/s]
LP: -335.3:   2%|▏         | 22/1000 [00:00<00:04, 210.84it/s]

Fitting session 14...
Fitting 2 states...


Converged to LP: -340.1:   3%|▎         | 30/1000 [00:00<00:05, 192.35it/s]
Converged to LP: -343.5:   4%|▍         | 42/1000 [00:00<00:03, 259.15it/s]
Converged to LP: -329.5:  13%|█▎        | 127/1000 [00:00<00:03, 260.90it/s]
LP: -385.5:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...


Converged to LP: -385.6:   3%|▎         | 26/1000 [00:00<00:04, 200.01it/s]
Converged to LP: -386.1:   5%|▌         | 51/1000 [00:00<00:04, 225.47it/s]
Converged to LP: -379.7:   5%|▌         | 52/1000 [00:00<00:04, 201.03it/s]
Converged to LP: -369.9:   8%|▊         | 75/1000 [00:00<00:03, 238.00it/s]
Converged to LP: -385.0:   6%|▌         | 62/1000 [00:00<00:05, 184.34it/s]
LP: -429.9:   1%|▏         | 13/1000 [00:00<00:07, 129.36it/s]

Fitting 4 states...


Converged to LP: -425.0:   8%|▊         | 81/1000 [00:00<00:04, 185.27it/s]
Converged to LP: -426.9:   7%|▋         | 68/1000 [00:00<00:06, 133.91it/s]
Converged to LP: -419.3:  13%|█▎        | 126/1000 [00:00<00:04, 186.72it/s]
Converged to LP: -434.3:  18%|█▊        | 185/1000 [00:00<00:03, 225.17it/s]
Converged to LP: -428.8:  19%|█▊        | 186/1000 [00:00<00:03, 216.80it/s]
LP: -489.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -457.0:   6%|▋         | 64/1000 [00:00<00:08, 116.31it/s]
Converged to LP: -453.0:  14%|█▎        | 135/1000 [00:00<00:05, 168.88it/s]
Converged to LP: -465.1:  16%|█▌        | 157/1000 [00:01<00:05, 142.85it/s]
Converged to LP: -465.8:  21%|██        | 208/1000 [00:01<00:04, 184.95it/s]
Converged to LP: -463.8:  28%|██▊       | 275/1000 [00:01<00:04, 177.02it/s]
Converged to LP: -311.2:   2%|▏         | 15/1000 [00:00<00:03, 277.34it/s]
LP: -298.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 15...
Fitting 2 states...


Converged to LP: -302.7:   3%|▎         | 26/1000 [00:00<00:03, 257.15it/s]
Converged to LP: -291.6:   6%|▌         | 61/1000 [00:00<00:03, 277.31it/s]
Converged to LP: -307.8:   6%|▌         | 57/1000 [00:00<00:04, 197.55it/s]
Converged to LP: -298.5:   6%|▌         | 60/1000 [00:00<00:04, 206.79it/s]


Fitting 3 states...


Converged to LP: -343.4:   2%|▏         | 18/1000 [00:00<00:08, 120.40it/s]
Converged to LP: -340.8:   3%|▎         | 30/1000 [00:00<00:06, 147.74it/s]
Converged to LP: -341.2:   4%|▍         | 39/1000 [00:00<00:05, 184.19it/s]
Converged to LP: -342.5:   6%|▌         | 57/1000 [00:00<00:05, 182.85it/s]
Converged to LP: -355.6:   8%|▊         | 81/1000 [00:00<00:05, 178.18it/s]
LP: -391.2:   1%|▏         | 14/1000 [00:00<00:07, 136.46it/s]

Fitting 4 states...


Converged to LP: -390.9:   6%|▌         | 61/1000 [00:00<00:04, 188.97it/s]
Converged to LP: -378.5:   6%|▌         | 56/1000 [00:00<00:05, 168.69it/s]
Converged to LP: -374.7:   8%|▊         | 83/1000 [00:00<00:04, 196.96it/s]
Converged to LP: -378.9:  14%|█▎        | 137/1000 [00:00<00:04, 187.71it/s]
Converged to LP: -375.7:  18%|█▊        | 182/1000 [00:00<00:03, 213.91it/s]
LP: -436.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -429.9:   9%|▉         | 91/1000 [00:00<00:06, 135.64it/s]
Converged to LP: -420.3:  16%|█▌        | 160/1000 [00:00<00:04, 179.92it/s]
Converged to LP: -416.3:  21%|██▏       | 213/1000 [00:01<00:04, 170.09it/s]
Converged to LP: -423.0:  16%|█▌        | 162/1000 [00:01<00:06, 123.59it/s]
Converged to LP: -425.3:  29%|██▉       | 293/1000 [00:01<00:04, 172.47it/s]


In [None]:
off_medication_results = {
    'global':{
		'inputs': off_med_inputs_aggregate,
		'choices': off_med_choices_aggregate,
		'masks': off_med_masks_aggregate,
		'models': models_glm_hmm_off_med,
		'fit_lls': fit_lls_glm_hmm_off_med,
		'best_params': init_params
	},
	'session':{
		'session_ids': off_med_sessions,
        'unnormalized_inputs': unnormalized_off_med_inputs,
		'inputs': inputs_session_wise,
		'choices': choices_session_wise,
		'masks': masks_session_wise,
		'reaction_time': reaction_time_session_wise,
		'models': models_session_state_fold_off_med,
		'train_lls': train_ll_session_off_med,
		'test_lls': test_ll_session_off_med
	}
}


with open(Path(processed_dir, f'glm_hmm_off_meds_result.pkl'), 'wb') as f:
    pickle.dump(off_medication_results, f)

### On medication sessions

In [None]:
inputs_session_wise = []
choices_session_wise = []
invalid_idx_session_wise = []
masks_session_wise = []
reaction_time_session_wise = []

# on medication sessions
for session in on_med_sessions:
    session_data = data[data['session_id'] == session].reset_index(drop=True)


    invalid_idx = np.where(session_data.outcome < 0)[0]
    valid_idx = np.where(session_data.outcome >= 0)[0]

    inputs = prepare_input_data(session_data, input_dim, invalid_idx)
    choices = session_data.choice.values.reshape(-1,1).astype('int')

    # for training, replace -1 with random sample from 0,1
    choices[choices == -1] = npr.choice([0,1],invalid_idx.shape[0])
    mask = np.ones_like(choices, dtype=bool)
    mask[invalid_idx] = 0
    reaction_time = np.array(session_data.reaction_time)

    masks_session_wise.append(mask)
    inputs_session_wise += inputs
    choices_session_wise.append(choices)
    reaction_time_session_wise.append(reaction_time)

on_med_inputs_aggregate, on_med_choices_aggregate, on_med_masks_aggregate = [], [], []
on_med_inputs_aggregate.append(np.vstack(inputs_session_wise))
on_med_choices_aggregate.append(np.vstack(choices_session_wise))
on_med_masks_aggregate.append(np.vstack(masks_session_wise))

unnormalized_on_med_inputs = copy.deepcopy(inputs_session_wise)

# scaling signed coherence
on_med_inputs_aggregate[0][on_med_masks_aggregate[0][:,0],0] = preprocessing.scale(on_med_inputs_aggregate[0][on_med_masks_aggregate[0][:,0],0], axis=0)
for idx in range(len(on_med_sessions)):
    inputs_session_wise[idx][masks_session_wise[idx][:,0]] = preprocessing.scale(inputs_session_wise[idx][masks_session_wise[idx][:,0]], axis=0)


In [None]:
models_glm_hmm_on_med, fit_lls_glm_hmm_on_med = global_fit(on_med_choices_aggregate, on_med_inputs_aggregate, masks=on_med_masks_aggregate, n_iters= 1000, n_initializations=20)

Fitting GLM globally...


  0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 2 states...


Converged to LP: -5361.9:   0%|          | 2/1000 [00:09<1:21:19,  4.89s/it]
Converged to LP: -5132.0:   7%|▋         | 66/1000 [00:11<02:45,  5.65it/s]
Converged to LP: -5132.0:   7%|▋         | 74/1000 [00:11<02:29,  6.19it/s]
Converged to LP: -5132.0:   9%|▉         | 89/1000 [00:12<02:03,  7.38it/s]
Converged to LP: -5132.0:   7%|▋         | 70/1000 [00:12<02:41,  5.75it/s]
Converged to LP: -5132.0:  10%|▉         | 95/1000 [00:12<01:57,  7.70it/s]
Converged to LP: -5132.0:  10%|▉         | 97/1000 [00:12<01:56,  7.77it/s]
Converged to LP: -5132.0:  11%|█         | 108/1000 [00:12<01:45,  8.49it/s]
Converged to LP: -5132.0:  11%|█         | 111/1000 [00:12<01:41,  8.73it/s]
Converged to LP: -5132.0:  12%|█▏        | 118/1000 [00:12<01:35,  9.24it/s]
Converged to LP: -5132.0:  10%|█         | 100/1000 [00:13<01:57,  7.69it/s]
Converged to LP: -5132.0:  10%|█         | 104/1000 [00:13<01:53,  7.89it/s]
Converged to LP: -5132.0:  11%|█         | 110/1000 [00:13<01:46,  8.32it/s]
Conve

Fitting 3 states...


Converged to LP: -5075.2:   6%|▌         | 62/1000 [00:02<00:44, 21.02it/s]
Converged to LP: -5075.4:   9%|▉         | 89/1000 [00:04<00:44, 20.28it/s]
Converged to LP: -5075.4:   7%|▋         | 70/1000 [00:03<00:49, 18.80it/s]
Converged to LP: -5075.4:   9%|▉         | 92/1000 [00:04<00:44, 20.22it/s]
Converged to LP: -5075.4:  10%|█         | 100/1000 [00:04<00:43, 20.51it/s]
Converged to LP: -5075.4:  11%|█         | 107/1000 [00:05<00:45, 19.73it/s]
Converged to LP: -5075.4:  13%|█▎        | 134/1000 [00:06<00:42, 20.50it/s]
Converged to LP: -5075.4:  12%|█▏        | 124/1000 [00:06<00:42, 20.41it/s]
Converged to LP: -5075.4:  14%|█▎        | 135/1000 [00:06<00:42, 20.56it/s]
Converged to LP: -5075.2:  17%|█▋        | 173/1000 [00:06<00:33, 24.77it/s]
Converged to LP: -5075.4:  14%|█▍        | 139/1000 [00:06<00:42, 20.28it/s]
Converged to LP: -5075.4:  19%|█▉        | 189/1000 [00:08<00:36, 22.46it/s]
Converged to LP: -5075.1:  18%|█▊        | 177/1000 [00:08<00:38, 21.51it/s]
Con

Fitting 4 states...


Converged to LP: -5455.8:   2%|▏         | 18/1000 [00:00<00:45, 21.65it/s]
Converged to LP: -5107.4:  18%|█▊        | 185/1000 [00:10<00:45, 17.83it/s]
Converged to LP: -5031.9:  16%|█▌        | 157/1000 [00:10<00:56, 14.94it/s]
Converged to LP: -5031.9:  20%|██        | 203/1000 [00:13<00:52, 15.31it/s]
Converged to LP: -5111.0:  29%|██▊       | 287/1000 [00:13<00:34, 20.62it/s]
Converged to LP: -5111.5:  29%|██▉       | 294/1000 [00:14<00:33, 20.80it/s]
Converged to LP: -5112.2:  31%|███       | 309/1000 [00:15<00:34, 20.25it/s]
Converged to LP: -5111.0:  33%|███▎      | 333/1000 [00:16<00:32, 20.25it/s]
Converged to LP: -5031.9:  30%|██▉       | 295/1000 [00:16<00:39, 17.66it/s]
Converged to LP: -5108.5:  29%|██▉       | 294/1000 [00:17<00:41, 17.15it/s]
Converged to LP: -5115.4:  39%|███▉      | 393/1000 [00:20<00:31, 19.47it/s]
Converged to LP: -5115.2:  42%|████▏     | 423/1000 [00:20<00:27, 20.62it/s]
Converged to LP: -5034.4:  45%|████▍     | 449/1000 [00:22<00:27, 20.37it/s]


Fitting 5 states...


Converged to LP: -5147.0:  29%|██▊       | 287/1000 [00:17<00:43, 16.47it/s]
Converged to LP: -5144.7:  39%|███▉      | 394/1000 [00:25<00:38, 15.75it/s]
Converged to LP: -5069.8:  36%|███▋      | 365/1000 [00:26<00:45, 13.99it/s]
Converged to LP: -5144.5:  52%|█████▏    | 515/1000 [00:32<00:30, 15.77it/s]
Converged to LP: -5029.6:  49%|████▊     | 486/1000 [00:34<00:36, 14.24it/s]
Converged to LP: -5071.2:  60%|█████▉    | 597/1000 [00:37<00:25, 15.97it/s]
Converged to LP: -5162.9:  60%|██████    | 602/1000 [00:37<00:24, 15.93it/s]
Converged to LP: -5064.0:  60%|██████    | 603/1000 [00:42<00:27, 14.18it/s]
Converged to LP: -5069.8:  72%|███████▏  | 721/1000 [00:43<00:16, 16.76it/s]
Converged to LP: -5065.1:  70%|███████   | 701/1000 [00:44<00:18, 15.87it/s]
Converged to LP: -5059.1:  67%|██████▋   | 673/1000 [00:44<00:21, 15.17it/s]
Converged to LP: -5064.4:  71%|███████▏  | 714/1000 [00:45<00:18, 15.86it/s]
Converged to LP: -5059.1:  75%|███████▍  | 748/1000 [00:46<00:15, 15.95it/s]

In [None]:
# get best model of 20 initializations for each state
init_params = {
    'glm_weights': {},
    'transition_matrices': {}
}
for n_states in np.arange(2,6):
    best_idx = fit_lls_glm_hmm_on_med[n_states].index(max(fit_lls_glm_hmm_on_med[n_states]))
    init_params['glm_weights'][n_states] = models_glm_hmm_on_med[n_states][best_idx].observations.params
    init_params['transition_matrices'][n_states] = models_glm_hmm_on_med[n_states][best_idx].transitions.params

In [None]:
# session-wise fitting with 5 fold cross-validation
models_session_state_fold_on_med, train_ll_session_on_med, test_ll_session_on_med = session_wise_fit_cv(choices_session_wise, inputs_session_wise, masks=masks_session_wise,
                                                                                    n_sessions=len(on_med_sessions), init_params=init_params, n_iters= 1000)

Fitting session 0...
Fitting 2 states...


Converged to LP: -318.9:   2%|▏         | 20/1000 [00:00<00:05, 185.92it/s]
Converged to LP: -319.2:   2%|▎         | 25/1000 [00:00<00:05, 168.07it/s]
Converged to LP: -310.4:   2%|▎         | 25/1000 [00:00<00:05, 187.86it/s]
Converged to LP: -315.1:   4%|▍         | 41/1000 [00:00<00:03, 240.75it/s]
Converged to LP: -316.8:   9%|▉         | 88/1000 [00:00<00:04, 212.52it/s]


Fitting 3 states...


Converged to LP: -356.6:   2%|▏         | 24/1000 [00:00<00:06, 142.74it/s]
Converged to LP: -361.6:   4%|▎         | 36/1000 [00:00<00:05, 171.42it/s]
Converged to LP: -353.5:   7%|▋         | 74/1000 [00:00<00:04, 223.96it/s]
Converged to LP: -346.6:  11%|█         | 110/1000 [00:00<00:04, 200.32it/s]
Converged to LP: -353.5:  11%|█         | 106/1000 [00:00<00:04, 188.61it/s]


Fitting 4 states...


Converged to LP: -397.2:   8%|▊         | 75/1000 [00:00<00:05, 178.03it/s]
Converged to LP: -396.7:   8%|▊         | 80/1000 [00:00<00:05, 159.70it/s]
Converged to LP: -397.8:   9%|▉         | 94/1000 [00:00<00:04, 182.88it/s]
Converged to LP: -398.5:  16%|█▋        | 163/1000 [00:00<00:04, 202.11it/s]
Converged to LP: -388.2:  16%|█▌        | 157/1000 [00:00<00:04, 180.34it/s]


Fitting 5 states...


Converged to LP: -458.8:   4%|▍         | 38/1000 [00:00<00:06, 137.44it/s]
Converged to LP: -437.0:  12%|█▏        | 122/1000 [00:00<00:05, 158.63it/s]
Converged to LP: -455.1:  17%|█▋        | 170/1000 [00:00<00:04, 176.91it/s]
Converged to LP: -439.5:  19%|█▉        | 192/1000 [00:01<00:05, 160.50it/s]
Converged to LP: -429.2:  33%|███▎      | 328/1000 [00:01<00:03, 186.56it/s]


Fitting session 1...
Fitting 2 states...


Converged to LP: -301.8:   1%|          | 6/1000 [00:00<00:11, 84.22it/s]
Converged to LP: -308.5:   1%|          | 9/1000 [00:00<00:09, 104.00it/s]
Converged to LP: -314.0:   2%|▏         | 22/1000 [00:00<00:05, 177.74it/s]
Converged to LP: -316.7:   3%|▎         | 32/1000 [00:00<00:05, 192.69it/s]
Converged to LP: -314.0:   1%|          | 8/1000 [00:00<00:09, 109.60it/s]


Fitting 3 states...


Converged to LP: -357.5:   1%|▏         | 13/1000 [00:00<00:09, 109.50it/s]
Converged to LP: -343.4:   1%|          | 10/1000 [00:00<00:09, 101.22it/s]
Converged to LP: -351.3:   1%|          | 11/1000 [00:00<00:08, 114.64it/s]
Converged to LP: -348.3:   2%|▏         | 16/1000 [00:00<00:07, 132.71it/s]
Converged to LP: -354.8:   7%|▋         | 73/1000 [00:00<00:05, 169.09it/s]
LP: -396.3:   2%|▏         | 15/1000 [00:00<00:06, 148.02it/s]

Fitting 4 states...


Converged to LP: -392.3:   4%|▍         | 38/1000 [00:00<00:06, 141.63it/s]
Converged to LP: -393.3:  15%|█▌        | 151/1000 [00:00<00:04, 192.94it/s]
Converged to LP: -387.6:  15%|█▌        | 151/1000 [00:00<00:04, 189.07it/s]
Converged to LP: -385.0:   2%|▏         | 24/1000 [00:00<00:07, 132.15it/s]
Converged to LP: -399.0:   5%|▌         | 50/1000 [00:00<00:05, 165.88it/s]
LP: -435.4:   1%|          | 7/1000 [00:00<00:15, 65.26it/s]s]

Fitting 5 states...


Converged to LP: -431.8:   6%|▌         | 59/1000 [00:00<00:06, 149.07it/s]
Converged to LP: -442.2:   8%|▊         | 80/1000 [00:00<00:06, 144.20it/s]
Converged to LP: -437.1:   9%|▉         | 92/1000 [00:00<00:05, 160.65it/s]
Converged to LP: -416.5:  13%|█▎        | 130/1000 [00:00<00:04, 196.37it/s]
Converged to LP: -432.6:  15%|█▍        | 149/1000 [00:00<00:04, 177.34it/s]
Converged to LP: -332.9:   1%|          | 12/1000 [00:00<00:06, 153.57it/s]
Converged to LP: -343.9:   2%|▏         | 20/1000 [00:00<00:04, 218.84it/s]
LP: -356.9:   0%|          | 0/1000 [00:00<?, ?it/s]71.80it/s]

Fitting session 2...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -350.1:   2%|▏         | 18/1000 [00:00<00:05, 178.39it/s]
Converged to LP: -355.6:   3%|▎         | 27/1000 [00:00<00:04, 212.21it/s]
Converged to LP: -353.5:   4%|▎         | 37/1000 [00:00<00:04, 205.12it/s]
Converged to LP: -370.7:   2%|▏         | 21/1000 [00:00<00:05, 191.06it/s]
Converged to LP: -402.0:   2%|▎         | 25/1000 [00:00<00:04, 195.32it/s]
Converged to LP: -397.8:   3%|▎         | 31/1000 [00:00<00:04, 217.04it/s]
Converged to LP: -384.8:   6%|▋         | 65/1000 [00:00<00:04, 213.59it/s]
Converged to LP: -378.1:  12%|█▏        | 118/1000 [00:00<00:03, 227.71it/s]


Fitting 4 states...


Converged to LP: -417.9:   5%|▌         | 50/1000 [00:00<00:05, 186.26it/s]
Converged to LP: -436.0:   7%|▋         | 68/1000 [00:00<00:05, 169.47it/s]
Converged to LP: -440.9:   8%|▊         | 81/1000 [00:00<00:05, 181.41it/s]
Converged to LP: -431.3:  10%|▉         | 98/1000 [00:00<00:04, 183.00it/s]
Converged to LP: -417.0:   8%|▊         | 82/1000 [00:00<00:06, 132.61it/s]
LP: -485.2:   1%|          | 11/1000 [00:00<00:09, 105.49it/s]

Fitting 5 states...


Converged to LP: -475.8:   8%|▊         | 84/1000 [00:00<00:06, 145.91it/s]
Converged to LP: -459.3:   8%|▊         | 82/1000 [00:00<00:07, 125.34it/s]
Converged to LP: -460.7:  13%|█▎        | 128/1000 [00:01<00:06, 127.55it/s]
Converged to LP: -483.2:  20%|█▉        | 195/1000 [00:01<00:04, 173.41it/s]
Converged to LP: -470.8:  30%|██▉       | 296/1000 [00:01<00:04, 172.41it/s]
LP: -293.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 3...
Fitting 2 states...


Converged to LP: -297.3:   4%|▍         | 38/1000 [00:00<00:03, 254.22it/s]
Converged to LP: -275.2:   4%|▎         | 35/1000 [00:00<00:04, 235.53it/s]
Converged to LP: -287.6:   9%|▉         | 90/1000 [00:00<00:02, 314.98it/s]
Converged to LP: -294.5:  12%|█▏        | 117/1000 [00:00<00:02, 333.99it/s]
Converged to LP: -286.3:  12%|█▎        | 125/1000 [00:00<00:03, 230.98it/s]
LP: -338.7:   1%|▏         | 13/1000 [00:00<00:07, 125.76it/s]

Fitting 3 states...


Converged to LP: -335.3:   3%|▎         | 29/1000 [00:00<00:05, 190.22it/s]
Converged to LP: -335.8:   4%|▍         | 45/1000 [00:00<00:06, 149.88it/s]
Converged to LP: -330.8:   6%|▌         | 62/1000 [00:00<00:05, 173.93it/s]
Converged to LP: -333.4:  10%|█         | 105/1000 [00:00<00:04, 218.42it/s]
Converged to LP: -310.7:  24%|██▍       | 239/1000 [00:01<00:03, 222.85it/s]
LP: -392.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 4 states...


Converged to LP: -374.8:   5%|▍         | 46/1000 [00:00<00:06, 139.15it/s]
Converged to LP: -375.1:   8%|▊         | 79/1000 [00:00<00:04, 185.45it/s]
Converged to LP: -374.5:  28%|██▊       | 280/1000 [00:01<00:03, 232.14it/s]
Converged to LP: -362.3:  29%|██▊       | 287/1000 [00:01<00:03, 212.81it/s]
Converged to LP: -365.1:  61%|██████    | 606/1000 [00:02<00:01, 232.93it/s]
LP: -430.2:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -412.1:   6%|▌         | 58/1000 [00:00<00:06, 155.74it/s]
Converged to LP: -410.3:  13%|█▎        | 132/1000 [00:00<00:05, 161.34it/s]
Converged to LP: -413.7:  15%|█▌        | 153/1000 [00:00<00:04, 179.07it/s]
Converged to LP: -406.5:  20%|██        | 201/1000 [00:01<00:05, 153.07it/s]
Converged to LP: -400.8:  26%|██▋       | 263/1000 [00:01<00:04, 180.85it/s]
LP: -409.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 4...
Fitting 2 states...


Converged to LP: -378.7:   1%|▏         | 13/1000 [00:00<00:03, 257.81it/s]
Converged to LP: -382.0:   2%|▏         | 18/1000 [00:00<00:03, 303.93it/s]
Converged to LP: -376.4:   5%|▌         | 50/1000 [00:00<00:03, 241.27it/s]
Converged to LP: -384.4:   7%|▋         | 68/1000 [00:00<00:03, 253.07it/s]
Converged to LP: -388.6:   9%|▉         | 90/1000 [00:00<00:03, 301.27it/s]
Converged to LP: -416.8:   1%|▏         | 13/1000 [00:00<00:05, 185.30it/s]
LP: -429.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...
Fitting 4 states...


Converged to LP: -421.3:   2%|▏         | 16/1000 [00:00<00:06, 153.59it/s]
Converged to LP: -426.7:   2%|▎         | 25/1000 [00:00<00:05, 183.62it/s]
Converged to LP: -421.6:   4%|▎         | 37/1000 [00:00<00:04, 229.17it/s]
Converged to LP: -428.7:   4%|▍         | 39/1000 [00:00<00:04, 209.02it/s]
Converged to LP: -460.6:   4%|▍         | 43/1000 [00:00<00:05, 185.32it/s]
Converged to LP: -455.8:   4%|▍         | 44/1000 [00:00<00:05, 173.80it/s]
Converged to LP: -463.3:  13%|█▎        | 133/1000 [00:00<00:04, 185.93it/s]
Converged to LP: -453.1:  21%|██        | 212/1000 [00:00<00:03, 220.84it/s]
Converged to LP: -469.9:  21%|██        | 210/1000 [00:00<00:03, 212.43it/s]
LP: -517.5:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -499.7:   5%|▌         | 52/1000 [00:00<00:07, 133.86it/s]
Converged to LP: -494.2:   7%|▋         | 68/1000 [00:00<00:06, 155.15it/s]
Converged to LP: -491.9:  10%|▉         | 97/1000 [00:00<00:05, 170.48it/s]
Converged to LP: -496.5:   9%|▉         | 89/1000 [00:00<00:05, 153.90it/s]
Converged to LP: -508.3:  15%|█▌        | 150/1000 [00:00<00:05, 160.25it/s]
Converged to LP: -392.2:   1%|          | 7/1000 [00:00<00:05, 185.83it/s]
Converged to LP: -387.3:   1%|          | 11/1000 [00:00<00:04, 233.71it/s]
Converged to LP: -392.8:   1%|          | 11/1000 [00:00<00:04, 205.18it/s]
Converged to LP: -398.1:   1%|          | 11/1000 [00:00<00:04, 205.80it/s]
Converged to LP: -383.3:   2%|▏         | 17/1000 [00:00<00:04, 230.85it/s]
Converged to LP: -418.3:   1%|          | 8/1000 [00:00<00:06, 162.22it/s]
Converged to LP: -429.6:   1%|          | 9/1000 [00:00<00:05, 167.24it/s]
Converged to LP: -430.2:   1%|          | 10/1000 [00:00<00:05, 183.06it/s]
Converged to L

Fitting session 5...
Fitting 2 states...
Fitting 3 states...
Fitting 4 states...


Converged to LP: -423.9:   1%|          | 12/1000 [00:00<00:06, 157.54it/s]
Converged to LP: -434.7:   4%|▍         | 44/1000 [00:00<00:05, 167.87it/s]
Converged to LP: -430.9:   4%|▍         | 39/1000 [00:00<00:06, 143.53it/s]
Converged to LP: -441.8:   6%|▌         | 59/1000 [00:00<00:04, 195.13it/s]
Converged to LP: -437.1:   7%|▋         | 66/1000 [00:00<00:04, 192.25it/s]
Converged to LP: -450.2:   8%|▊         | 75/1000 [00:00<00:05, 178.09it/s]
LP: -498.0:   1%|          | 9/1000 [00:00<00:12, 80.88it/s]s]

Fitting 5 states...


Converged to LP: -492.0:   4%|▍         | 41/1000 [00:00<00:06, 143.66it/s]
Converged to LP: -498.4:   5%|▌         | 52/1000 [00:00<00:05, 161.64it/s]
Converged to LP: -483.2:   5%|▍         | 47/1000 [00:00<00:06, 140.45it/s]
Converged to LP: -511.3:   5%|▌         | 51/1000 [00:00<00:07, 125.99it/s]
Converged to LP: -490.5:   7%|▋         | 68/1000 [00:00<00:07, 126.63it/s]
LP: -277.2:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 6...
Fitting 2 states...


Converged to LP: -256.5:   1%|          | 11/1000 [00:00<00:03, 273.62it/s]
Converged to LP: -269.5:   1%|          | 11/1000 [00:00<00:03, 262.28it/s]
Converged to LP: -270.3:   2%|▎         | 25/1000 [00:00<00:04, 243.65it/s]
Converged to LP: -270.0:   6%|▋         | 65/1000 [00:00<00:03, 262.44it/s]
Converged to LP: -266.6:  13%|█▎        | 133/1000 [00:00<00:02, 319.80it/s]
Converged to LP: -314.6:   2%|▏         | 19/1000 [00:00<00:05, 192.12it/s]
LP: -313.8:   1%|          | 12/1000 [00:00<00:08, 114.72it/s]

Fitting 3 states...


Converged to LP: -297.4:   4%|▍         | 42/1000 [00:00<00:04, 213.48it/s]
Converged to LP: -310.1:   4%|▎         | 36/1000 [00:00<00:06, 138.15it/s]
Converged to LP: -312.5:   4%|▍         | 44/1000 [00:00<00:06, 151.67it/s]
Converged to LP: -296.9:  19%|█▊        | 186/1000 [00:00<00:03, 225.25it/s]
LP: -357.7:   1%|          | 10/1000 [00:00<00:10, 98.22it/s]]

Fitting 4 states...


Converged to LP: -350.4:   6%|▌         | 57/1000 [00:00<00:05, 172.57it/s]
Converged to LP: -356.0:   7%|▋         | 71/1000 [00:00<00:05, 156.29it/s]
Converged to LP: -341.1:   7%|▋         | 68/1000 [00:00<00:06, 142.54it/s]
Converged to LP: -350.9:  17%|█▋        | 171/1000 [00:00<00:04, 189.62it/s]
Converged to LP: -354.7:  24%|██▍       | 244/1000 [00:01<00:03, 202.66it/s]
LP: -399.0:   1%|▏         | 14/1000 [00:00<00:07, 134.80it/s]

Fitting 5 states...


Converged to LP: -400.2:   5%|▌         | 51/1000 [00:00<00:05, 164.93it/s]
Converged to LP: -402.7:   7%|▋         | 66/1000 [00:00<00:05, 184.83it/s]
Converged to LP: -403.8:   5%|▌         | 54/1000 [00:00<00:06, 136.23it/s]
Converged to LP: -385.0:  23%|██▎       | 229/1000 [00:01<00:04, 183.73it/s]
Converged to LP: -385.5:  28%|██▊       | 284/1000 [00:01<00:04, 156.62it/s]
Converged to LP: -284.4:   2%|▏         | 17/1000 [00:00<00:03, 286.24it/s]
Converged to LP: -289.2:   2%|▎         | 25/1000 [00:00<00:03, 268.31it/s]
Converged to LP: -282.7:   2%|▏         | 17/1000 [00:00<00:05, 177.44it/s]
Converged to LP: -276.0:   3%|▎         | 31/1000 [00:00<00:04, 223.29it/s]
LP: -284.2:   2%|▏         | 17/1000 [00:00<00:05, 168.39it/s]

Fitting session 7...
Fitting 2 states...


Converged to LP: -281.6:  10%|▉         | 95/1000 [00:00<00:03, 271.58it/s]


Fitting 3 states...


Converged to LP: -328.8:   2%|▏         | 22/1000 [00:00<00:07, 139.30it/s]
Converged to LP: -326.7:   4%|▍         | 42/1000 [00:00<00:04, 232.65it/s]
Converged to LP: -326.5:   4%|▍         | 43/1000 [00:00<00:04, 233.46it/s]
Converged to LP: -322.3:   4%|▍         | 39/1000 [00:00<00:05, 187.75it/s]
Converged to LP: -332.3:   5%|▍         | 49/1000 [00:00<00:04, 226.94it/s]
LP: -376.6:   1%|▏         | 14/1000 [00:00<00:07, 139.85it/s]

Fitting 4 states...


Converged to LP: -364.9:   9%|▉         | 94/1000 [00:00<00:04, 181.35it/s]
Converged to LP: -360.6:  20%|█▉        | 196/1000 [00:01<00:04, 171.56it/s]
Converged to LP: -363.2:  23%|██▎       | 230/1000 [00:01<00:03, 199.11it/s]
Converged to LP: -351.5:  34%|███▍      | 342/1000 [00:01<00:03, 194.56it/s]
Converged to LP: -369.1:  62%|██████▏   | 622/1000 [00:02<00:01, 212.03it/s]
LP: -431.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -408.8:   8%|▊         | 83/1000 [00:00<00:05, 154.38it/s]
Converged to LP: -418.8:  12%|█▏        | 120/1000 [00:00<00:06, 130.86it/s]
Converged to LP: -408.4:  27%|██▋       | 267/1000 [00:01<00:03, 192.59it/s]
Converged to LP: -417.9:  36%|███▋      | 365/1000 [00:01<00:03, 191.55it/s]
Converged to LP: -403.0:  36%|███▋      | 365/1000 [00:02<00:03, 159.31it/s]
Converged to LP: -373.1:   1%|          | 10/1000 [00:00<00:03, 273.24it/s]
Converged to LP: -360.0:   1%|          | 10/1000 [00:00<00:03, 247.88it/s]
Converged to LP: -363.9:   1%|          | 12/1000 [00:00<00:03, 266.08it/s]
Converged to LP: -373.7:   1%|▏         | 13/1000 [00:00<00:03, 260.19it/s]
Converged to LP: -367.7:   2%|▏         | 19/1000 [00:00<00:05, 170.50it/s]
LP: -418.1:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 8...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -418.1:   2%|▏         | 15/1000 [00:00<00:04, 218.80it/s]
Converged to LP: -403.7:   1%|          | 12/1000 [00:00<00:05, 189.79it/s]
Converged to LP: -416.0:   2%|▏         | 19/1000 [00:00<00:04, 230.96it/s]
Converged to LP: -409.4:   2%|▏         | 15/1000 [00:00<00:06, 143.23it/s]
Converged to LP: -407.7:   6%|▌         | 57/1000 [00:00<00:04, 219.63it/s]


Fitting 4 states...


Converged to LP: -447.4:   5%|▍         | 47/1000 [00:00<00:05, 177.08it/s]
Converged to LP: -451.3:   8%|▊         | 85/1000 [00:00<00:06, 144.01it/s]
Converged to LP: -454.9:  11%|█         | 112/1000 [00:00<00:05, 172.10it/s]
Converged to LP: -443.4:  13%|█▎        | 134/1000 [00:00<00:04, 187.44it/s]
Converged to LP: -447.7:  25%|██▌       | 251/1000 [00:01<00:03, 228.92it/s]
LP: -502.9:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -480.3:  10%|▉         | 95/1000 [00:00<00:06, 144.96it/s]
Converged to LP: -481.8:  11%|█         | 112/1000 [00:00<00:05, 156.45it/s]
Converged to LP: -494.1:  12%|█▏        | 122/1000 [00:00<00:06, 144.88it/s]
Converged to LP: -499.1:  13%|█▎        | 129/1000 [00:00<00:05, 149.97it/s]
Converged to LP: -478.8:  21%|██        | 210/1000 [00:01<00:05, 154.56it/s]
Converged to LP: -346.5:   1%|          | 6/1000 [00:00<00:03, 264.47it/s]
Converged to LP: -348.6:   1%|          | 7/1000 [00:00<00:03, 286.68it/s]
Converged to LP: -346.4:   1%|          | 8/1000 [00:00<00:03, 269.52it/s]
Converged to LP: -338.2:   1%|          | 6/1000 [00:00<00:05, 173.64it/s]
Converged to LP: -339.2:   1%|          | 11/1000 [00:00<00:04, 211.68it/s]
LP: -392.5:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 9...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -391.4:   2%|▏         | 21/1000 [00:00<00:03, 249.77it/s]
Converged to LP: -382.0:   2%|▏         | 18/1000 [00:00<00:04, 226.30it/s]
Converged to LP: -392.5:   3%|▎         | 30/1000 [00:00<00:03, 253.96it/s]
Converged to LP: -381.7:   4%|▍         | 39/1000 [00:00<00:04, 229.47it/s]
Converged to LP: -392.0:   6%|▋         | 65/1000 [00:00<00:04, 233.58it/s]


Fitting 4 states...


Converged to LP: -436.2:   3%|▎         | 33/1000 [00:00<00:05, 189.28it/s]
Converged to LP: -423.9:   4%|▍         | 42/1000 [00:00<00:06, 149.06it/s]
Converged to LP: -419.5:   8%|▊         | 81/1000 [00:00<00:04, 205.78it/s]
Converged to LP: -402.2:  15%|█▌        | 151/1000 [00:00<00:04, 197.51it/s]
Converged to LP: -431.2:  15%|█▍        | 146/1000 [00:00<00:04, 189.90it/s]
LP: -486.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -454.2:   4%|▎         | 37/1000 [00:00<00:09, 96.32it/s]
Converged to LP: -475.4:   8%|▊         | 83/1000 [00:00<00:05, 175.39it/s]
Converged to LP: -460.6:  10%|█         | 105/1000 [00:00<00:05, 165.25it/s]
Converged to LP: -459.7:   9%|▉         | 90/1000 [00:00<00:06, 135.34it/s]
Converged to LP: -461.5:  12%|█▏        | 117/1000 [00:00<00:05, 163.43it/s]
Converged to LP: -337.3:   1%|▏         | 14/1000 [00:00<00:03, 247.28it/s]
LP: -330.4:   3%|▎         | 27/1000 [00:00<00:03, 263.91it/s]

Fitting session 10...
Fitting 2 states...


Converged to LP: -330.4:   4%|▍         | 40/1000 [00:00<00:03, 290.45it/s]
Converged to LP: -335.8:  10%|▉         | 95/1000 [00:00<00:02, 307.29it/s]
Converged to LP: -336.7:  11%|█         | 108/1000 [00:00<00:03, 255.25it/s]
Converged to LP: -346.6:  30%|███       | 300/1000 [00:01<00:02, 285.56it/s]
LP: -380.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...


Converged to LP: -385.8:   8%|▊         | 77/1000 [00:00<00:05, 184.35it/s]
Converged to LP: -373.5:  12%|█▎        | 125/1000 [00:00<00:04, 218.31it/s]
Converged to LP: -377.8:  12%|█▏        | 121/1000 [00:00<00:04, 191.05it/s]
Converged to LP: -376.2:  15%|█▍        | 149/1000 [00:00<00:03, 228.31it/s]
Converged to LP: -375.7:  16%|█▋        | 164/1000 [00:00<00:03, 247.07it/s]


Fitting 4 states...


Converged to LP: -413.9:   7%|▋         | 69/1000 [00:00<00:05, 178.64it/s]
Converged to LP: -439.2:   9%|▉         | 91/1000 [00:00<00:04, 199.06it/s]
Converged to LP: -418.3:   9%|▉         | 93/1000 [00:00<00:05, 172.73it/s]
Converged to LP: -419.3:  19%|█▉        | 192/1000 [00:00<00:03, 207.94it/s]
Converged to LP: -413.3:  32%|███▏      | 319/1000 [00:01<00:03, 199.15it/s]
LP: -486.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -451.5:  11%|█▏        | 113/1000 [00:00<00:05, 165.88it/s]
Converged to LP: -449.9:  14%|█▍        | 140/1000 [00:00<00:05, 153.69it/s]
Converged to LP: -471.6:  19%|█▉        | 193/1000 [00:01<00:04, 175.57it/s]
Converged to LP: -471.9:  25%|██▍       | 247/1000 [00:01<00:04, 177.51it/s]
Converged to LP: -453.4:  30%|██▉       | 299/1000 [00:01<00:03, 175.49it/s]
Converged to LP: -387.5:   1%|          | 12/1000 [00:00<00:05, 182.25it/s]
LP: -374.7:   2%|▏         | 24/1000 [00:00<00:04, 235.42it/s]

Fitting session 11...
Fitting 2 states...


Converged to LP: -374.7:   3%|▎         | 26/1000 [00:00<00:04, 237.90it/s]
Converged to LP: -377.5:   3%|▎         | 33/1000 [00:00<00:03, 275.19it/s]
Converged to LP: -372.8:   4%|▎         | 37/1000 [00:00<00:03, 273.81it/s]
Converged to LP: -376.7:  10%|▉         | 97/1000 [00:00<00:03, 286.06it/s]
LP: -419.9:   2%|▏         | 20/1000 [00:00<00:05, 192.09it/s], 207.87it/s]
Converged to LP: -419.8:   3%|▎         | 31/1000 [00:00<00:04, 208.15it/s]
Converged to LP: -411.8:   2%|▎         | 25/1000 [00:00<00:05, 163.33it/s]
LP: -407.0:   2%|▏         | 19/1000 [00:00<00:05, 183.10it/s]

Fitting 3 states...


Converged to LP: -404.2:   6%|▋         | 64/1000 [00:00<00:04, 207.26it/s]
Converged to LP: -431.1:  10%|▉         | 97/1000 [00:00<00:03, 235.36it/s]
LP: -464.5:   1%|▏         | 13/1000 [00:00<00:07, 127.15it/s]

Fitting 4 states...


Converged to LP: -464.9:   7%|▋         | 72/1000 [00:00<00:05, 177.51it/s]
Converged to LP: -437.7:  10%|▉         | 97/1000 [00:00<00:05, 176.66it/s]
Converged to LP: -449.5:   9%|▉         | 94/1000 [00:00<00:05, 151.06it/s]
Converged to LP: -451.5:  13%|█▎        | 127/1000 [00:00<00:05, 145.79it/s]
Converged to LP: -459.6:  18%|█▊        | 183/1000 [00:01<00:04, 178.24it/s]
LP: -504.2:   0%|          | 0/1000 [00:00<?, ?it/s]5.36it/s]]

Fitting 5 states...


Converged to LP: -503.3:  10%|█         | 102/1000 [00:00<00:05, 151.53it/s]
Converged to LP: -490.0:  12%|█▏        | 122/1000 [00:00<00:05, 156.86it/s]
Converged to LP: -492.2:  13%|█▎        | 131/1000 [00:00<00:05, 153.17it/s]
Converged to LP: -502.5:  17%|█▋        | 166/1000 [00:01<00:05, 163.76it/s]
Converged to LP: -497.3:  27%|██▋       | 272/1000 [00:01<00:04, 173.78it/s]
Converged to LP: -354.0:   1%|▏         | 14/1000 [00:00<00:04, 213.30it/s]
Converged to LP: -358.6:   3%|▎         | 26/1000 [00:00<00:03, 261.83it/s]
Converged to LP: -359.1:   2%|▏         | 18/1000 [00:00<00:05, 169.97it/s]
Converged to LP: -357.8:   3%|▎         | 31/1000 [00:00<00:04, 210.42it/s]
Converged to LP: -368.4:   4%|▍         | 39/1000 [00:00<00:03, 257.76it/s]


Fitting session 12...
Fitting 2 states...
Fitting 3 states...


Converged to LP: -404.8:   2%|▏         | 20/1000 [00:00<00:04, 220.59it/s]
Converged to LP: -406.1:   2%|▎         | 25/1000 [00:00<00:04, 242.38it/s]
Converged to LP: -400.1:   2%|▏         | 24/1000 [00:00<00:04, 196.82it/s]
Converged to LP: -409.1:   4%|▍         | 38/1000 [00:00<00:05, 178.15it/s]
Converged to LP: -388.8:   8%|▊         | 79/1000 [00:00<00:04, 185.33it/s]
LP: -442.8:   2%|▏         | 15/1000 [00:00<00:06, 141.47it/s]

Fitting 4 states...


Converged to LP: -443.4:   8%|▊         | 79/1000 [00:00<00:04, 198.67it/s]
Converged to LP: -422.3:   7%|▋         | 67/1000 [00:00<00:06, 146.20it/s]
Converged to LP: -429.3:   9%|▉         | 93/1000 [00:00<00:05, 172.19it/s]
Converged to LP: -434.1:  11%|█         | 108/1000 [00:00<00:04, 196.02it/s]
Converged to LP: -432.9:  22%|██▏       | 224/1000 [00:01<00:03, 207.52it/s]
LP: -496.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -474.7:   6%|▌         | 61/1000 [00:00<00:07, 131.24it/s]
Converged to LP: -482.5:   5%|▌         | 50/1000 [00:00<00:09, 104.99it/s]
Converged to LP: -476.6:   6%|▋         | 64/1000 [00:00<00:08, 108.87it/s]
Converged to LP: -477.8:   9%|▉         | 92/1000 [00:00<00:06, 132.06it/s]
Converged to LP: -477.3:  20%|█▉        | 197/1000 [00:01<00:04, 162.96it/s]
Converged to LP: -381.2:   2%|▏         | 15/1000 [00:00<00:03, 261.61it/s]
Converged to LP: -378.1:   1%|▏         | 14/1000 [00:00<00:04, 200.94it/s]
Converged to LP: -377.5:   2%|▏         | 15/1000 [00:00<00:05, 165.68it/s]
Converged to LP: -373.0:   2%|▏         | 17/1000 [00:00<00:05, 183.14it/s]


Fitting session 13...
Fitting 2 states...


Converged to LP: -368.6:  20%|█▉        | 198/1000 [00:00<00:02, 292.52it/s]
LP: -448.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...


Converged to LP: -426.0:   2%|▏         | 18/1000 [00:00<00:06, 153.58it/s]
Converged to LP: -411.8:   2%|▎         | 25/1000 [00:00<00:05, 163.07it/s]
Converged to LP: -414.3:  10%|▉         | 95/1000 [00:00<00:04, 223.43it/s]
Converged to LP: -412.1:  20%|██        | 200/1000 [00:00<00:03, 219.32it/s]
Converged to LP: -412.1:  24%|██▍       | 241/1000 [00:01<00:03, 217.58it/s]
LP: -452.5:   1%|▏         | 14/1000 [00:00<00:07, 137.25it/s]

Fitting 4 states...


Converged to LP: -440.9:   7%|▋         | 67/1000 [00:00<00:05, 158.81it/s]
Converged to LP: -455.8:   8%|▊         | 80/1000 [00:00<00:07, 127.89it/s]
Converged to LP: -446.1:  17%|█▋        | 168/1000 [00:00<00:03, 209.99it/s]
Converged to LP: -443.8:  16%|█▌        | 156/1000 [00:00<00:04, 171.09it/s]
Converged to LP: -451.9:  21%|██        | 211/1000 [00:01<00:04, 179.01it/s]
LP: -515.2:   1%|          | 8/1000 [00:00<00:13, 73.04it/s]s]

Fitting 5 states...


Converged to LP: -500.0:   6%|▌         | 57/1000 [00:00<00:06, 157.11it/s]
Converged to LP: -502.4:   9%|▊         | 87/1000 [00:00<00:08, 113.39it/s]
Converged to LP: -476.3:  13%|█▎        | 130/1000 [00:00<00:06, 135.34it/s]
Converged to LP: -492.8:  24%|██▍       | 241/1000 [00:01<00:04, 155.85it/s]
Converged to LP: -497.6:  26%|██▋       | 265/1000 [00:01<00:04, 167.05it/s]
Converged to LP: -343.4:   2%|▏         | 17/1000 [00:00<00:03, 265.59it/s]
Converged to LP: -341.9:   2%|▏         | 23/1000 [00:00<00:03, 285.98it/s]
LP: -335.3:   2%|▏         | 22/1000 [00:00<00:04, 210.84it/s]

Fitting session 14...
Fitting 2 states...


Converged to LP: -340.1:   3%|▎         | 30/1000 [00:00<00:05, 192.35it/s]
Converged to LP: -343.5:   4%|▍         | 42/1000 [00:00<00:03, 259.15it/s]
Converged to LP: -329.5:  13%|█▎        | 127/1000 [00:00<00:03, 260.90it/s]
LP: -385.5:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 3 states...


Converged to LP: -385.6:   3%|▎         | 26/1000 [00:00<00:04, 200.01it/s]
Converged to LP: -386.1:   5%|▌         | 51/1000 [00:00<00:04, 225.47it/s]
Converged to LP: -379.7:   5%|▌         | 52/1000 [00:00<00:04, 201.03it/s]
Converged to LP: -369.9:   8%|▊         | 75/1000 [00:00<00:03, 238.00it/s]
Converged to LP: -385.0:   6%|▌         | 62/1000 [00:00<00:05, 184.34it/s]
LP: -429.9:   1%|▏         | 13/1000 [00:00<00:07, 129.36it/s]

Fitting 4 states...


Converged to LP: -425.0:   8%|▊         | 81/1000 [00:00<00:04, 185.27it/s]
Converged to LP: -426.9:   7%|▋         | 68/1000 [00:00<00:06, 133.91it/s]
Converged to LP: -419.3:  13%|█▎        | 126/1000 [00:00<00:04, 186.72it/s]
Converged to LP: -434.3:  18%|█▊        | 185/1000 [00:00<00:03, 225.17it/s]
Converged to LP: -428.8:  19%|█▊        | 186/1000 [00:00<00:03, 216.80it/s]
LP: -489.0:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -457.0:   6%|▋         | 64/1000 [00:00<00:08, 116.31it/s]
Converged to LP: -453.0:  14%|█▎        | 135/1000 [00:00<00:05, 168.88it/s]
Converged to LP: -465.1:  16%|█▌        | 157/1000 [00:01<00:05, 142.85it/s]
Converged to LP: -465.8:  21%|██        | 208/1000 [00:01<00:04, 184.95it/s]
Converged to LP: -463.8:  28%|██▊       | 275/1000 [00:01<00:04, 177.02it/s]
Converged to LP: -311.2:   2%|▏         | 15/1000 [00:00<00:03, 277.34it/s]
LP: -298.8:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting session 15...
Fitting 2 states...


Converged to LP: -302.7:   3%|▎         | 26/1000 [00:00<00:03, 257.15it/s]
Converged to LP: -291.6:   6%|▌         | 61/1000 [00:00<00:03, 277.31it/s]
Converged to LP: -307.8:   6%|▌         | 57/1000 [00:00<00:04, 197.55it/s]
Converged to LP: -298.5:   6%|▌         | 60/1000 [00:00<00:04, 206.79it/s]


Fitting 3 states...


Converged to LP: -343.4:   2%|▏         | 18/1000 [00:00<00:08, 120.40it/s]
Converged to LP: -340.8:   3%|▎         | 30/1000 [00:00<00:06, 147.74it/s]
Converged to LP: -341.2:   4%|▍         | 39/1000 [00:00<00:05, 184.19it/s]
Converged to LP: -342.5:   6%|▌         | 57/1000 [00:00<00:05, 182.85it/s]
Converged to LP: -355.6:   8%|▊         | 81/1000 [00:00<00:05, 178.18it/s]
LP: -391.2:   1%|▏         | 14/1000 [00:00<00:07, 136.46it/s]

Fitting 4 states...


Converged to LP: -390.9:   6%|▌         | 61/1000 [00:00<00:04, 188.97it/s]
Converged to LP: -378.5:   6%|▌         | 56/1000 [00:00<00:05, 168.69it/s]
Converged to LP: -374.7:   8%|▊         | 83/1000 [00:00<00:04, 196.96it/s]
Converged to LP: -378.9:  14%|█▎        | 137/1000 [00:00<00:04, 187.71it/s]
Converged to LP: -375.7:  18%|█▊        | 182/1000 [00:00<00:03, 213.91it/s]
LP: -436.3:   0%|          | 0/1000 [00:00<?, ?it/s]

Fitting 5 states...


Converged to LP: -429.9:   9%|▉         | 91/1000 [00:00<00:06, 135.64it/s]
Converged to LP: -420.3:  16%|█▌        | 160/1000 [00:00<00:04, 179.92it/s]
Converged to LP: -416.3:  21%|██▏       | 213/1000 [00:01<00:04, 170.09it/s]
Converged to LP: -423.0:  16%|█▌        | 162/1000 [00:01<00:06, 123.59it/s]
Converged to LP: -425.3:  29%|██▉       | 293/1000 [00:01<00:04, 172.47it/s]


In [None]:
on_medication_results = {
    'global':{
		'inputs': on_med_inputs_aggregate,
		'choices': on_med_choices_aggregate,
		'masks': on_med_masks_aggregate,
		'models': models_glm_hmm_on_med,
		'fit_lls': fit_lls_glm_hmm_on_med,
		'best_params': init_params
	},
	'session':{
		'session_ids': on_med_sessions,
        'unnormalized_inputs': unnormalized_on_med_inputs,
		'inputs': inputs_session_wise,
		'choices': choices_session_wise,
		'masks': masks_session_wise,
		'reaction_time': reaction_time_session_wise,
		'models': models_session_state_fold_on_med,
		'train_lls': train_ll_session_on_med,
		'test_lls': test_ll_session_on_med
	}
}


with open(Path(processed_dir, f'glm_hmm_on_meds_result.pkl'), 'wb') as f:
    pickle.dump(on_medication_results, f)