# Reproducing Ashwood NatNeuro

## Download the data

In [None]:
import numpy as np
import numpy.random as npr
from scipy.stats import bernoulli
import json
import os
from oneibl.onelight import ONE

one = ONE()

def get_animal_name(eid):
    # get session id:
    raw_session_id = eid.split('Subjects/')[1]
    # Get animal:
    animal = raw_session_id.split('/')[0]
    return animal


def get_raw_data(eid):
    print(eid)
    # get session id:
    raw_session_id = eid.split('Subjects/')[1]
    # Get animal:
    animal = raw_session_id.split('/')[0]
    # replace '/' with dash in session ID
    session_id = raw_session_id.replace('/', '-')
    # hack to work with ONE:
    current_dir = os.getcwd()
    os.chdir("../../data/ibl/")
    # Get choice data, stim data and rewarded/not rewarded:
    choice = one.load_dataset(eid, '_ibl_trials.choice')
    stim_left = one.load_dataset(eid, '_ibl_trials.contrastLeft')
    stim_right = one.load_dataset(eid, '_ibl_trials.contrastRight')
    rewarded = one.load_dataset(eid, '_ibl_trials.feedbackType')
    bias_probs = one.load_dataset(eid, '_ibl_trials.probabilityLeft')
    os.chdir(current_dir)
    return animal, session_id, stim_left, stim_right, rewarded, choice, \
           bias_probs


def create_stim_vector(stim_left, stim_right):
    # want stim_right - stim_left
    # Replace NaNs with 0:
    stim_left = np.nan_to_num(stim_left, nan=0)
    stim_right = np.nan_to_num(stim_right, nan=0)
    # now get 1D stim
    signed_contrast = stim_right - stim_left
    return signed_contrast


def create_previous_choice_vector(choice):
    ''' choice: choice vector of size T
        previous_choice : vector of size T with previous choice made by
        animal - output is in {0, 1}, where 0 corresponds to a previous left
        choice; 1 corresponds to right.
        If the previous choice was a violation, replace this with the choice
        on the previous trial that was not a violation.
        locs_mapping: array of size (~num_viols)x2, where the entry in
        column 1 is the location in the previous choice vector that was a
        remapping due to a violation and the
        entry in column 2 is the location in the previous choice vector that
        this location was remapped to
    '''
    previous_choice = np.hstack([np.array(choice[0]), choice])[:-1]
    locs_to_update = np.where(previous_choice == -1)[0]
    locs_with_choice = np.where(previous_choice != -1)[0]
    loc_first_choice = locs_with_choice[0]
    locs_mapping = np.zeros((len(locs_to_update) - loc_first_choice, 2),
                            dtype='int')

    for i, loc in enumerate(locs_to_update):
        if loc < loc_first_choice:
            # since no previous choice, bernoulli sample: (not output of
            # bernoulli rvs is in {1, 2})
            previous_choice[loc] = bernoulli.rvs(0.5, 1) - 1
        else:
            # find nearest loc that has a previous choice value that is not
            # -1, and that is earlier than current trial
            potential_matches = locs_with_choice[
                np.where(locs_with_choice < loc)]
            absolute_val_diffs = np.abs(loc - potential_matches)
            absolute_val_diffs_ind = absolute_val_diffs.argmin()
            nearest_loc = potential_matches[absolute_val_diffs_ind]
            locs_mapping[i - loc_first_choice, 0] = int(loc)
            locs_mapping[i - loc_first_choice, 1] = int(nearest_loc)
            previous_choice[loc] = previous_choice[nearest_loc]
    assert len(np.unique(
        previous_choice)) <= 2, "previous choice should be in {0, 1}; " + str(
        np.unique(previous_choice))
    return previous_choice, locs_mapping


def create_wsls_covariate(previous_choice, success, locs_mapping):
    '''
    inputs:
    success: vector of size T, entries are in {-1, 1} and 0 corresponds to
    failure, 1 corresponds to success
    previous_choice: vector of size T, entries are in {0, 1} and 0
    corresponds to left choice, 1 corresponds to right choice
    locs_mapping: location remapping dictionary due to violations
    output:
    wsls: vector of size T, entries are in {-1, 1}.  1 corresponds to
    previous choice = right and success OR previous choice = left and
    failure; -1 corresponds to
    previous choice = left and success OR previous choice = right and failure
    '''
    # remap previous choice vals to {-1, 1}
    remapped_previous_choice = 2 * previous_choice - 1
    previous_reward = np.hstack([np.array(success[0]), success])[:-1]
    # Now need to go through and update previous reward to correspond to
    # same trial as previous choice:
    for i, loc in enumerate(locs_mapping[:, 0]):
        nearest_loc = locs_mapping[i, 1]
        previous_reward[loc] = previous_reward[nearest_loc]
    wsls = previous_reward * remapped_previous_choice
    assert len(np.unique(wsls)) == 2, "wsls should be in {-1, 1}"
    return wsls


def remap_choice_vals(choice):
    # raw choice vector has CW = 1 (correct response for stim on left),
    # CCW = -1 (correct response for stim on right) and viol = 0.  Let's
    # remap so that CW = 0, CCw = 1, and viol = -1
    choice_mapping = {1: 0, -1: 1, 0: -1}
    new_choice_vector = [choice_mapping[old_choice] for old_choice in choice]
    return new_choice_vector


def create_design_mat(choice, stim_left, stim_right, rewarded):
    # Create unnormalized_inpt: with first column = stim_right - stim_left,
    # second column as past choice, third column as WSLS
    stim = create_stim_vector(stim_left, stim_right)
    T = len(stim)
    design_mat = np.zeros((T, 3))
    design_mat[:, 0] = stim
    # make choice vector so that correct response for stim>0 is choice =1
    # and is 0 for stim <0 (viol is mapped to -1)
    choice = remap_choice_vals(choice)
    # create past choice vector:
    previous_choice, locs_mapping = create_previous_choice_vector(choice)
    # create wsls vector:
    wsls = create_wsls_covariate(previous_choice, rewarded, locs_mapping)
    # map previous choice to {-1,1}
    design_mat[:, 1] = 2 * previous_choice - 1
    design_mat[:, 2] = wsls
    return design_mat


def get_all_unnormalized_data_this_session(eid):
    # Load raw data
    animal, session_id, stim_left, stim_right, rewarded, choice, bias_probs \
        = get_raw_data(eid)
    # Subset choice and design_mat to 50-50 entries:
    trials_to_study = np.where(bias_probs == 0.5)[0]
    num_viols_50 = len(np.where(choice[trials_to_study] == 0)[0])
    if num_viols_50 < 10:
        # Create design mat = matrix of size T x 3, with entries for
        # stim/past choice/wsls
        unnormalized_inpt = create_design_mat(choice[trials_to_study],
                                              stim_left[trials_to_study],
                                              stim_right[trials_to_study],
                                              rewarded[trials_to_study])
        y = np.expand_dims(remap_choice_vals(choice[trials_to_study]), axis=1)
        session = [session_id for i in range(y.shape[0])]
        rewarded = np.expand_dims(rewarded[trials_to_study], axis=1)
    else:
        unnormalized_inpt = np.zeros((90, 3))
        y = np.zeros((90, 1))
        session = []
        rewarded = np.zeros((90, 1))
    return animal, unnormalized_inpt, y, session, num_viols_50, rewarded


def load_animal_list(file):
    container = np.load(file, allow_pickle=True)
    data = [container[key] for key in container]
    animal_list = data[0]
    return animal_list


def load_animal_eid_dict(file):
    with open(file, 'r') as f:
        animal_eid_dict = json.load(f)
    return animal_eid_dict


def load_data(animal_file):
    container = np.load(animal_file, allow_pickle=True)
    data = [container[key] for key in container]
    inpt = data[0]
    y = data[1]
    y = y.astype('int')
    session = data[2]
    return inpt, y, session


def create_train_test_sessions(session, num_folds=5):
    # create a session-fold lookup table
    num_sessions = len(np.unique(session))
    # Map sessions to folds:
    unshuffled_folds = np.repeat(np.arange(num_folds),
                                 np.ceil(num_sessions / num_folds))
    shuffled_folds = npr.permutation(unshuffled_folds)[:num_sessions]
    assert len(np.unique(
        shuffled_folds)) == 5, "require at least one session per fold for " \
                               "each animal!"
    # Look up table of shuffle-folds:
    sess_id = np.array(np.unique(session), dtype='str')
    shuffled_folds = np.array(shuffled_folds, dtype='O')
    session_fold_lookup_table = np.transpose(
        np.vstack([sess_id, shuffled_folds]))
    return session_fold_lookup_table

## Preprocess

In [None]:
# Download IBL dataset and begin processing it: identify unique animals in
# IBL dataset that enter biased blocks.  Save a dictionary with each animal
# and a list of their eids in the biased blocks

import numpy as np
from oneibl.onelight import ONE
import numpy.random as npr
import json
from collections import defaultdict
import wget
from zipfile import ZipFile
import os
#from preprocessing_utils import get_animal_name
npr.seed(65)

DOWNLOAD_DATA = True # change to True to download raw data (WARNING: this
# can take a while)

if __name__ == '__main__':
    ibl_data_path = "../../int-brain-lab/glm-hmm/data/ibl/"
    if DOWNLOAD_DATA: # Warning: this step takes a while
        if not os.path.exists(ibl_data_path):
            os.makedirs(ibl_data_path)
        # download IBL data
        url = 'https://ndownloader.figshare.com/files/21623715'
        wget.download(url, ibl_data_path)
        # now unzip downloaded data:
        with ZipFile(ibl_data_path + "ibl-behavior-data-Dec2019.zip",
                     'r') as zipObj:
            # extract all the contents of zip file in ibl_data_path
            zipObj.extractall(ibl_data_path)

    # create directory for saving data:
    if not os.path.exists(ibl_data_path + "partially_processed/"):
        os.makedirs(ibl_data_path + "partially_processed/")

    # change directory so that ONE searches in correct directory:
    os.chdir(ibl_data_path)
    one = ONE()
    eids = one.search(['_ibl_trials.*'])
    assert len(eids) > 0, "ONE search is in incorrect directory"
    animal_list = []
    animal_eid_dict = defaultdict(list)

    for eid in eids:
        bias_probs = one.load_dataset(eid, '_ibl_trials.probabilityLeft')
        comparison = np.unique(bias_probs) == np.array([0.2, 0.5, 0.8])
        # sessions with bias blocks
        if isinstance(comparison, np.ndarray):
            # update def of comparison to single True/False
            comparison = comparison.all()
        if comparison == True:
            animal = get_animal_name(eid)
            if animal not in animal_list:
                animal_list.append(animal)
            animal_eid_dict[animal].append(eid)

    json = json.dumps(animal_eid_dict)
    f = open("partially_processed/animal_eid_dict.json",  "w")
    f.write(json)
    f.close()

    np.savez('partially_processed/animal_list.npz', animal_list)

In [None]:
# Continue preprocessing of IBL dataset and create design matrix for GLM-HMM
import numpy as np
from sklearn import preprocessing
import numpy.random as npr
import os
import json
from collections import defaultdict
#from preprocessing_utils import load_animal_list, load_animal_eid_dict, \
#    get_all_unnormalized_data_this_session, create_train_test_sessions

npr.seed(65)

if __name__ == '__main__':
    data_dir = '../../data/ibl/'
    # Create directories for saving data:
    processed_ibl_data_path = data_dir + "data_for_cluster/"
    if not os.path.exists(processed_ibl_data_path):
        os.makedirs(processed_ibl_data_path)
    # Also create a subdirectory for storing each individual animal's data:
    if not os.path.exists(processed_ibl_data_path + "data_by_animal/"):
        os.makedirs(processed_ibl_data_path + "data_by_animal/")

    # Load animal list/results of partial processing:
    animal_list = load_animal_list(
        data_dir + 'partially_processed/animal_list.npz')
    animal_eid_dict = load_animal_eid_dict(
        data_dir + 'partially_processed/animal_eid_dict.json')

    # Require that each animal has at least 30 sessions (=2700 trials) of data:
    req_num_sessions = 30  # 30*90 = 2700
    for animal in animal_list:
        num_sessions = len(animal_eid_dict[animal])
        if num_sessions < req_num_sessions:
            animal_list = np.delete(animal_list,
                                    np.where(animal_list == animal))
    # Identify idx in master array where each animal's data starts and ends:
    animal_start_idx = {}
    animal_end_idx = {}

    final_animal_eid_dict = defaultdict(list)
    # WORKHORSE: iterate through each animal and each animal's set of eids;
    # obtain unnormalized data.  Write out each animal's data and then also
    # write to master array
    for z, animal in enumerate(animal_list):
        sess_counter = 0
        for eid in animal_eid_dict[animal]:
            animal, unnormalized_inpt, y, session, num_viols_50, rewarded = \
                get_all_unnormalized_data_this_session(
                    eid)
            if num_viols_50 < 10:  # only include session if number of viols
                # in 50-50 block is less than 10
                if sess_counter == 0:
                    animal_unnormalized_inpt = np.copy(unnormalized_inpt)
                    animal_y = np.copy(y)
                    animal_session = session
                    animal_rewarded = np.copy(rewarded)
                else:
                    animal_unnormalized_inpt = np.vstack(
                        (animal_unnormalized_inpt, unnormalized_inpt))
                    animal_y = np.vstack((animal_y, y))
                    animal_session = np.concatenate((animal_session, session))
                    animal_rewarded = np.vstack((animal_rewarded, rewarded))
                sess_counter += 1
                final_animal_eid_dict[animal].append(eid)
        # Write out animal's unnormalized data matrix:
        np.savez(
            processed_ibl_data_path + 'data_by_animal/' + animal +
            '_unnormalized.npz',
            animal_unnormalized_inpt, animal_y,
            animal_session)
        animal_session_fold_lookup = create_train_test_sessions(animal_session,
                                                                5)
        np.savez(
            processed_ibl_data_path + 'data_by_animal/' + animal +
            "_session_fold_lookup" +
            ".npz",
            animal_session_fold_lookup)
        np.savez(
            processed_ibl_data_path + 'data_by_animal/' + animal +
            '_rewarded.npz',
            animal_rewarded)
        assert animal_rewarded.shape[0] == animal_y.shape[0]
        # Now create or append data to master array across all animals:
        if z == 0:
            master_inpt = np.copy(animal_unnormalized_inpt)
            animal_start_idx[animal] = 0
            animal_end_idx[animal] = master_inpt.shape[0] - 1
            master_y = np.copy(animal_y)
            master_session = animal_session
            master_session_fold_lookup_table = animal_session_fold_lookup
            master_rewarded = np.copy(animal_rewarded)
        else:
            animal_start_idx[animal] = master_inpt.shape[0]
            master_inpt = np.vstack((master_inpt, animal_unnormalized_inpt))
            animal_end_idx[animal] = master_inpt.shape[0] - 1
            master_y = np.vstack((master_y, animal_y))
            master_session = np.concatenate((master_session, animal_session))
            master_session_fold_lookup_table = np.vstack(
                (master_session_fold_lookup_table, animal_session_fold_lookup))
            master_rewarded = np.vstack((master_rewarded, animal_rewarded))
    # Write out data from across animals
    assert np.shape(master_inpt)[0] == np.shape(master_y)[
        0], "inpt and y not same length"
    assert np.shape(master_rewarded)[0] == np.shape(master_y)[
        0], "rewarded and y not same length"
    assert len(np.unique(master_session)) == \
           np.shape(master_session_fold_lookup_table)[
               0], "number of unique sessions and session fold lookup don't " \
                   "match"
    assert len(master_inpt) == 181530, "design matrix for all IBL animals " \
                                       "should have shape (181530, 3)"
    assert len(animal_list) == 37, "37 animals were studied in Ashwood et " \
                                   "al. (2020)"
    normalized_inpt = np.copy(master_inpt)
    normalized_inpt[:, 0] = preprocessing.scale(normalized_inpt[:, 0])
    np.savez(processed_ibl_data_path + 'all_animals_concat' + '.npz',
             normalized_inpt,
             master_y, master_session)
    np.savez(
        processed_ibl_data_path + 'all_animals_concat_unnormalized' + '.npz',
        master_inpt, master_y, master_session)
    np.savez(
        processed_ibl_data_path + 'all_animals_concat_session_fold_lookup' +
        '.npz',
        master_session_fold_lookup_table)
    np.savez(processed_ibl_data_path + 'all_animals_concat_rewarded' + '.npz',
             master_rewarded)
    np.savez(processed_ibl_data_path + 'data_by_animal/' + 'animal_list.npz',
             animal_list)

    json = json.dumps(final_animal_eid_dict)
    f = open(processed_ibl_data_path + "final_animal_eid_dict.json", "w")
    f.write(json)
    f.close()

    # Now write out normalized data (when normalized across all animals) for
    # each animal:
    counter = 0
    for animal in animal_start_idx.keys():
        start_idx = animal_start_idx[animal]
        end_idx = animal_end_idx[animal]
        inpt = normalized_inpt[range(start_idx, end_idx + 1)]
        y = master_y[range(start_idx, end_idx + 1)]
        session = master_session[range(start_idx, end_idx + 1)]
        counter += inpt.shape[0]
        np.savez(processed_ibl_data_path + 'data_by_animal/' + animal + '_processed.npz',
                 inpt, y,
                 session)

    assert counter == master_inpt.shape[0]

In [None]:
# Obtain IBL response time data for producing Figure 6
# Write out the response times and the corresponding sessions
import os

import numpy as np
import numpy.random as npr
from oneibl.onelight import ONE

#from preprocessing_utils import load_animal_eid_dict, load_data

npr.seed(65)

if __name__ == '__main__':
    ibl_data_path = "../../data/ibl/"
    animal_eid_dict = load_animal_eid_dict(
        ibl_data_path + 'data_for_cluster/final_animal_eid_dict.json')
    # must change directory for working with ONE
    os.chdir(ibl_data_path)
    one = ONE()

    data_dir = 'response_times/data_by_animal/'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    for animal in animal_eid_dict.keys():
        print(animal)
        animal_inpt, animal_y, animal_session = load_data(
            'data_for_cluster/data_by_animal/' + animal + '_processed.npz')
        for z, eid in enumerate(animal_eid_dict[animal]):
            raw_session_id = eid.split('Subjects/')[1]
            session_id = raw_session_id.replace('/', '-')
            full_sess_len = len(one.load_dataset(eid, '_ibl_trials.choice'))

            file_names = [
                '_ibl_trials.feedback_times', '_ibl_trials.response_times',
                '_ibl_trials.goCue_times', '_ibl_trials.stimOn_times'
            ]

            save_vars = [
                'feedback_times', 'response_times', 'go_cues', 'stim_on_times'
            ]

            for i, file in enumerate(file_names):
                full_path = 'ibl-behavioral-data-Dec2019/' + eid + \
                                 '/alf/' + file + '.npy'
                if os.path.exists(full_path):
                    globals()[save_vars[i]] = one.load_dataset(eid, file)
                else:
                    globals()[save_vars[i]] = np.empty((full_sess_len, ))
                    globals()[save_vars[i]][:] = np.nan

            start = np.nanmin(np.c_[stim_on_times, go_cues], axis=1)

            if (len(feedback_times) == len(response_times)): # some response
                # times/feedback times are missing, so fill these as best as
                # possible
                end = np.nanmin(np.c_[feedback_times, response_times], axis=1)
            elif len(feedback_times) == full_sess_len:
                end = feedback_times
            elif len(response_times) == full_sess_len:
                end = response_times

            # check timestamps increasing:
            idx_to_change = np.where(start > end)[0]

            if len(idx_to_change) > 0:
                start[idx_to_change[0]] = np.nan
                end[idx_to_change[0]] = np.nan

            # Check we have times for at least some trials
            nan_trial = np.isnan(np.c_[start, end]).any(axis=1)

            is_increasing = (((start < end) | nan_trial).all() and
                    ((np.diff(start) > 0) | np.isnan(
                        np.diff(start))).all())

            if is_increasing and ~nan_trial.all() and len(start) == \
                    full_sess_len and len(end) == full_sess_len: #
                # check that times are increasing and that len(start) ==
                # full_sess_len etc
                prob_left_dta = one.load_dataset(
                    eid, '_ibl_trials.probabilityLeft')
                assert start.shape[0] == prob_left_dta.shape[0],\
                    "different lengths for prob left and raw response dta: " + \
                    str(start.shape[0]) + " vs " + str(
                        prob_left_dta.shape[0])

                # subset to trials corresponding to prob_left == 0.5:
                unbiased_idx = np.where(prob_left_dta == 0.5)
                response_dta = end[unbiased_idx] - start[unbiased_idx]

                if ((np.nanmedian(response_dta) >= 10) | (np.nanmedian(
                        response_dta) == np.nan)): # check that median
                    # response time for session is less than 10 seconds
                    response_dta = np.array([np.nan for i in range(len(
                        unbiased_idx[0]))])

                rt_sess = [session_id for i in range(response_dta.shape[0])]
                # before saving, confirm that there are as many trials as in
                # some of the other data:
                assert len(rt_sess) == animal_inpt[np.where(animal_session ==
                                                            session_id),
                                       :].shape[1], "response dta is different " \
                                                    "shape compared to inpt"
            else: # if any of the conditions above fail, fill the session's
                # data with nans
                len_prob_50 = animal_inpt[np.where(animal_session ==
                                                            session_id),
                              :].shape[1]
                response_dta = np.array([np.nan for i in range(len_prob_50)])
                rt_sess = [session_id for i in range(response_dta.shape[0])]

            if z == 0:
                rt_session_dta_this_animal = rt_sess
                response_dta_this_animal = response_dta
            else:
                rt_session_dta_this_animal = np.concatenate(
                    (rt_session_dta_this_animal, rt_sess))
                response_dta_this_animal = np.concatenate(
                    (response_dta_this_animal, response_dta))

        assert len(response_dta_this_animal) == len(animal_inpt), "different size for response times and inpt"
        np.savez(data_dir + animal + '.npz', response_dta_this_animal,
                 rt_session_dta_this_animal)

## GLM

In [None]:
# GLM class
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.scipy.special import logsumexp
# Import useful functions from ssm package
from ssm.util import ensure_args_are_lists
from ssm.optimizers import adam, bfgs, rmsprop, sgd
import ssm.stats as stats


class glm(object):
    def __init__(self, M, C):
        """
        @param C:  number of classes in the categorical observations
        """
        self.M = M
        self.C = C
        # Parameters linking input to state distribution
        self.Wk = npr.randn(1, C - 1, M + 1)

    @property
    def params(self):
        return self.Wk

    @params.setter
    def params(self, value):
        self.Wk = value

    def log_prior(self):
        return 0

    # Calculate time dependent logits - output is matrix of size Tx1xC
    # Input is size TxM
    def calculate_logits(self, input):
        # Update input to include offset term:
        input = np.append(input, np.ones((input.shape[0], 1)), axis=1)
        # Add additional row (of zeros) to second dimension of self.Wk
        Wk_tranpose = np.transpose(self.Wk, (1, 0, 2))
        Wk = np.transpose(
            np.vstack([
                Wk_tranpose,
                np.zeros((1, Wk_tranpose.shape[1], Wk_tranpose.shape[2]))
            ]), (1, 0, 2))
        # Input effect; transpose so that output has dims TxKxC
        time_dependent_logits = np.transpose(np.dot(Wk, input.T), (2, 0, 1))
        time_dependent_logits = time_dependent_logits - logsumexp(
            time_dependent_logits, axis=2, keepdims=True)
        return time_dependent_logits

    # Calculate log-likelihood of observed data
    def log_likelihoods(self, data, input, mask, tag):
        time_dependent_logits = self.calculate_logits(input)
        mask = np.ones_like(data, dtype=bool) if mask is None else mask
        return stats.categorical_logpdf(data[:, None, :],
                                        time_dependent_logits[:, :, None, :],
                                        mask=mask[:, None, :])

    # log marginal likelihood of data
    @ensure_args_are_lists
    def log_marginal(self, datas, inputs, masks, tags):
        elbo = self.log_prior()
        for data, input, mask, tag in zip(datas, inputs, masks, tags):
            lls = self.log_likelihoods(data, input, mask, tag)
            elbo += np.sum(lls)
        return elbo

    @ensure_args_are_lists
    def fit_glm(self,
                datas,
                inputs,
                masks,
                tags,
                num_iters=1000,
                optimizer="bfgs",
                **kwargs):
        optimizer = dict(adam=adam, bfgs=bfgs, rmsprop=rmsprop,
                         sgd=sgd)[optimizer]

        def _objective(params, itr):
            self.params = params
            obj = self.log_marginal(datas, inputs, masks, tags)
            return -obj

        self.params = optimizer(_objective,
                                self.params,
                                num_iters=num_iters,
                                **kwargs)


In [None]:
import autograd.numpy as np
import autograd.numpy.random as npr
import matplotlib.pyplot as plt
#from GLM import glm

npr.seed(65)


def load_data(animal_file):
    container = np.load(animal_file, allow_pickle=True)
    data = [container[key] for key in container]
    inpt = data[0]
    y = data[1]
    session = data[2]
    return inpt, y, session


def fit_glm(inputs, datas, M, C):
    new_glm = glm(M, C)
    new_glm.fit_glm(datas, inputs, masks=None, tags=None)
    # Get loglikelihood of training data:
    loglikelihood_train = new_glm.log_marginal(datas, inputs, None, None)
    recovered_weights = new_glm.Wk
    return loglikelihood_train, recovered_weights


# Append column of zeros to weights matrix in appropriate location
def append_zeros(weights):
    weights_tranpose = np.transpose(weights, (1, 0, 2))
    weights = np.transpose(
        np.vstack([
            weights_tranpose,
            np.zeros((1, weights_tranpose.shape[1], weights_tranpose.shape[2]))
        ]), (1, 0, 2))
    return weights


def load_session_fold_lookup(file_path):
    container = np.load(file_path, allow_pickle=True)
    data = [container[key] for key in container]
    session_fold_lookup_table = data[0]
    return session_fold_lookup_table


def load_animal_list(list_file):
    container = np.load(list_file, allow_pickle=True)
    data = [container[key] for key in container]
    animal_list = data[0]
    return animal_list


def plot_input_vectors(Ws,
                       figure_directory,
                       title='true',
                       save_title="true",
                       labels_for_plot=[]):
    K = Ws.shape[0]
    K_prime = Ws.shape[1]
    M = Ws.shape[2] - 1
    fig = plt.figure(figsize=(7, 9), dpi=80, facecolor='w', edgecolor='k')
    plt.subplots_adjust(left=0.15,
                        bottom=0.27,
                        right=0.95,
                        top=0.95,
                        wspace=0.3,
                        hspace=0.3)

    for j in range(K):
        for k in range(K_prime - 1):
            # plt.subplot(K, K_prime, 1+j*K_prime+k)
            plt.plot(range(M + 1), -Ws[j][k], marker='o')
            plt.plot(range(-1, M + 2), np.repeat(0, M + 3), 'k', alpha=0.2)
            plt.axhline(y=0, color="k", alpha=0.5, ls="--")
            if len(labels_for_plot) > 0:
                plt.xticks(list(range(0, len(labels_for_plot))),
                           labels_for_plot,
                           rotation='90',
                           fontsize=12)
            else:
                plt.xticks(list(range(0, 3)),
                           ['Stimulus', 'Past Choice', 'Bias'],
                           rotation='90',
                           fontsize=12)
            plt.ylim((-3, 6))

    fig.text(0.04,
             0.5,
             "Weight",
             ha="center",
             va="center",
             rotation=90,
             fontsize=15)
    fig.suptitle("GLM Weights: " + title, y=0.99, fontsize=14)
    fig.savefig(figure_directory + 'glm_weights_' + save_title + '.png')


### GLM: all animals together

In [None]:
#  Fit GLM to all IBL data together

import autograd.numpy as np
import autograd.numpy.random as npr
import os
#from glm_utils import load_session_fold_lookup, load_data, fit_glm, \
#    plot_input_vectors, append_zeros

C = 2  # number of output types/categories
N_initializations = 10
npr.seed(65)  # set seed in case of randomization

if __name__ == '__main__':
    data_dir = '../../int-brain-lab/glm-hmm/data/ibl/data_for_cluster/'
    num_folds = 5

    # Create directory for results:
    results_dir = '../../int-brain-lab/glm-hmm/results/ibl_global_fit/'
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Fit GLM to all data
    animal_file = data_dir + 'all_animals_concat.npz'
    inpt, y, session = load_data(animal_file)
    session_fold_lookup_table = load_session_fold_lookup(
        data_dir + 'all_animals_concat_session_fold_lookup.npz')

    for fold in range(num_folds):
        # Subset to relevant covariates for covar set of interest:
        labels_for_plot = ['stim', 'P_C', 'WSLS', 'bias']
        y = y.astype('int')
        figure_directory = results_dir + "GLM/fold_" + str(fold) + '/'
        if not os.path.exists(figure_directory):
            os.makedirs(figure_directory)

        # Subset to sessions of interest for fold
        sessions_to_keep = session_fold_lookup_table[np.where(
            session_fold_lookup_table[:, 1] != fold), 0]
        idx_this_fold = [
            str(sess) in sessions_to_keep and y[id, 0] != -1
            for id, sess in enumerate(session)
        ]
        this_inpt, this_y, this_session = inpt[idx_this_fold, :], y[
            idx_this_fold, :], session[idx_this_fold]
        assert len(
            np.unique(this_y)
        ) == 2, "choice vector should only include 2 possible values"
        train_size = this_inpt.shape[0]

        M = this_inpt.shape[1]
        loglikelihood_train_vector = []

        for iter in range(N_initializations):  # GLM fitting should be
            # independent of initialization, so fitting multiple
            # initializations is a good way to check that everything is
            # working correctly
            loglikelihood_train, recovered_weights = fit_glm([this_inpt],
                                                             [this_y], M, C)
            weights_for_plotting = append_zeros(recovered_weights)
            plot_input_vectors(weights_for_plotting,
                               figure_directory,
                               title="GLM fit; Final LL = " +
                               str(loglikelihood_train),
                               save_title='init' + str(iter),
                               labels_for_plot=labels_for_plot)
            loglikelihood_train_vector.append(loglikelihood_train)
            np.savez(
                figure_directory + 'variables_of_interest_iter_' + str(iter) +
                '.npz', loglikelihood_train, recovered_weights)

### GLM: single animals

In [None]:
# Fit GLM to each IBL animal separately
import autograd.numpy as np
import autograd.numpy.random as npr
import os
#from glm_utils import load_session_fold_lookup, load_data, load_animal_list, \
#    fit_glm, plot_input_vectors, append_zeros

npr.seed(65)

C = 2  # number of output types/categories
N_initializations = 10

if __name__ == '__main__':
    data_dir = '../../int-brain-lab/glm-hmm/data/ibl/data_for_cluster/data_by_animal/'
    num_folds = 5
    animal_list = load_animal_list(data_dir + 'animal_list.npz')

    results_dir = '../../int-brain-lab/glm-hmm/results/ibl_individual_fit/'
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    for animal in animal_list:
        # Fit GLM to data from single animal:
        animal_file = data_dir + animal + '_processed.npz'
        session_fold_lookup_table = load_session_fold_lookup(
            data_dir + animal + '_session_fold_lookup.npz')

        for fold in range(num_folds):
            this_results_dir = results_dir + animal + '/'

            # Load data
            inpt, y, session = load_data(animal_file)
            labels_for_plot = ['stim', 'pc', 'wsls', 'bias']
            y = y.astype('int')

            figure_directory = this_results_dir + "GLM/fold_" + str(fold) + '/'
            if not os.path.exists(figure_directory):
                os.makedirs(figure_directory)

            # Subset to sessions of interest for fold
            sessions_to_keep = session_fold_lookup_table[np.where(
                session_fold_lookup_table[:, 1] != fold), 0]
            idx_this_fold = [
                str(sess) in sessions_to_keep and y[id, 0] != -1
                for id, sess in enumerate(session)
            ]
            this_inpt, this_y, this_session = inpt[idx_this_fold, :], \
                                              y[idx_this_fold, :], \
                                              session[idx_this_fold]
            assert len(
                np.unique(this_y)
            ) == 2, "choice vector should only include 2 possible values"
            train_size = this_inpt.shape[0]

            M = this_inpt.shape[1]
            loglikelihood_train_vector = []

            for iter in range(N_initializations):
                loglikelihood_train, recovered_weights = fit_glm([this_inpt],
                                                                 [this_y], M,
                                                                 C)
                weights_for_plotting = append_zeros(recovered_weights)
                plot_input_vectors(weights_for_plotting,
                                   figure_directory,
                                   title="GLM fit; Final LL = " +
                                   str(loglikelihood_train),
                                   save_title='init' + str(iter),
                                   labels_for_plot=labels_for_plot)
                loglikelihood_train_vector.append(loglikelihood_train)
                np.savez(
                    figure_directory + 'variables_of_interest_iter_' +
                    str(iter) + '.npz', loglikelihood_train, recovered_weights)


## GLM-HMM

### Global GLM-HMM

In [2]:
# Functions to assist with GLM-HMM model fitting
import sys
import ssm
import autograd.numpy as np
import autograd.numpy.random as npr


def load_data(animal_file):
    container = np.load(animal_file, allow_pickle=True)
    data = [container[key] for key in container]
    inpt = data[0]
    y = data[1]
    session = data[2]
    return inpt, y, session


def load_cluster_arr(cluster_arr_file):
    container = np.load(cluster_arr_file, allow_pickle=True)
    data = [container[key] for key in container]
    cluster_arr = data[0]
    return cluster_arr


def load_glm_vectors(glm_vectors_file):
    container = np.load(glm_vectors_file)
    data = [container[key] for key in container]
    loglikelihood_train = data[0]
    recovered_weights = data[1]
    return loglikelihood_train, recovered_weights


def load_global_params(global_params_file):
    container = np.load(global_params_file, allow_pickle=True)
    data = [container[key] for key in container]
    global_params = data[0]
    return global_params


def partition_data_by_session(inpt, y, mask, session):
    '''
    Partition inpt, y, mask by session
    :param inpt: arr of size TxM
    :param y:  arr of size T x D
    :param mask: Boolean arr of size T indicating if element is violation or
    not
    :param session: list of size T containing session ids
    :return: list of inpt arrays, data arrays and mask arrays, where the
    number of elements in list = number of sessions and each array size is
    number of trials in session
    '''
    inputs = []
    datas = []
    indexes = np.unique(session, return_index=True)[1]
    unique_sessions = [session[index] for index in sorted(indexes)]
    counter = 0
    masks = []
    for sess in unique_sessions:
        idx = np.where(session == sess)[0]
        counter += len(idx)
        inputs.append(inpt[idx, :])
        datas.append(y[idx, :])
        masks.append(mask[idx, :])
    assert counter == inpt.shape[0], "not all trials assigned to session!"
    return inputs, datas, masks


def load_session_fold_lookup(file_path):
    container = np.load(file_path, allow_pickle=True)
    data = [container[key] for key in container]
    session_fold_lookup_table = data[0]
    return session_fold_lookup_table


def load_animal_list(file):
    container = np.load(file, allow_pickle=True)
    data = [container[key] for key in container]
    animal_list = data[0]
    return animal_list


def launch_glm_hmm_job(inpt, y, session, mask, session_fold_lookup_table, K, D,
                       C, N_em_iters, transition_alpha, prior_sigma, fold,
                       iter, global_fit, init_param_file, save_directory):
    print("Starting inference with K = " + str(K) + "; Fold = " + str(fold) +
          "; Iter = " + str(iter))
    sys.stdout.flush()
    sessions_to_keep = session_fold_lookup_table[np.where(
        session_fold_lookup_table[:, 1] != fold), 0]
    idx_this_fold = [str(sess) in sessions_to_keep for sess in session]
    this_inpt, this_y, this_session, this_mask = inpt[idx_this_fold, :], \
                                                 y[idx_this_fold, :], \
                                                 session[idx_this_fold], \
                                                 mask[idx_this_fold]
    # Only do this so that errors are avoided - these y values will not
    # actually be used for anything (due to violation mask)
    this_y[np.where(this_y == -1), :] = 1
    inputs, datas, masks = partition_data_by_session(
        this_inpt, this_y, this_mask, this_session)
    # Read in GLM fit if global_fit = True:
    if global_fit == True:
        _, params_for_initialization = load_glm_vectors(init_param_file)
    else:
        params_for_initialization = load_global_params(init_param_file)
    M = this_inpt.shape[1]
    npr.seed(iter)
    fit_glm_hmm(datas,
                inputs,
                masks,
                K,
                D,
                M,
                C,
                N_em_iters,
                transition_alpha,
                prior_sigma,
                global_fit,
                params_for_initialization,
                save_title=save_directory + 'glm_hmm_raw_parameters_itr_' +
                           str(iter) + '.npz')


def fit_glm_hmm(datas, inputs, masks, K, D, M, C, N_em_iters,
                transition_alpha, prior_sigma, global_fit,
                params_for_initialization, save_title):
    '''
    Instantiate and fit GLM-HMM model
    :param datas:
    :param inputs:
    :param masks:
    :param K:
    :param D:
    :param M:
    :param C:
    :param N_em_iters:
    :param global_fit:
    :param glm_vectors:
    :param save_title:
    :return:
    '''
    if global_fit == True:
        # Prior variables
        # Choice of prior
        this_hmm = ssm.HMM(K,
                           D,
                           M,
                           observations="input_driven_obs",
                           observation_kwargs=dict(C=C,
                                                   prior_sigma=prior_sigma),
                           transitions="sticky",
                           transition_kwargs=dict(alpha=transition_alpha,
                                                  kappa=0))
        # Initialize observation weights as GLM weights with some noise:
        glm_vectors_repeated = np.tile(params_for_initialization, (K, 1, 1))
        glm_vectors_with_noise = glm_vectors_repeated + np.random.normal(
            0, 0.2, glm_vectors_repeated.shape)
        this_hmm.observations.params = glm_vectors_with_noise
    else:
        # Choice of prior
        this_hmm = ssm.HMM(K,
                           D,
                           M,
                           observations="input_driven_obs",
                           observation_kwargs=dict(C=C,
                                                   prior_sigma=prior_sigma),
                           transitions="sticky",
                           transition_kwargs=dict(alpha=transition_alpha,
                                                  kappa=0))
        # Initialize HMM-GLM with global parameters:
        this_hmm.params = params_for_initialization
        # Get log_prior of transitions:
    print("=== fitting GLM-HMM ========")
    sys.stdout.flush()
    # Fit this HMM and calculate marginal likelihood
    lls = this_hmm.fit(datas,
                       inputs=inputs,
                       masks=masks,
                       method="em",
                       num_iters=N_em_iters,
                       initialize=False,
                       tolerance=10 ** -4)
    # Save raw parameters of HMM, as well as loglikelihood during training
    np.savez(save_title, this_hmm.params, lls)
    return None


def create_violation_mask(violation_idx, T):
    """
    Return indices of nonviolations and also a Boolean mask for inclusion (1
    = nonviolation; 0 = violation)
    :param test_idx:
    :param T:
    :return:
    """
    mask = np.array([i not in violation_idx for i in range(T)])
    nonviolation_idx = np.arange(T)[mask]
    mask = mask + 0
    assert len(nonviolation_idx) + len(
        violation_idx
    ) == T, "violation and non-violation idx do not include all dta!"
    return nonviolation_idx, np.expand_dims(mask, axis=1)

In [3]:
#  Functions to assist with post-processing of GLM-HMM fits
import glob
import re
import sys

import numpy as np
import pandas as pd
import ssm

sys.path.insert(0, '../fit_glm/')
sys.path.insert(0, '../fit_lapse_model/')
#from GLM import glm
#from LapseModel import lapse_model


def load_data(animal_file):
    container = np.load(animal_file, allow_pickle=True)
    data = [container[key] for key in container]
    inpt = data[0]
    y = data[1]
    y = y.astype('int')
    session = data[2]
    return inpt, y, session


def load_session_fold_lookup(file_path):
    container = np.load(file_path, allow_pickle=True)
    data = [container[key] for key in container]
    session_fold_lookup_table = data[0]
    return session_fold_lookup_table


def load_glm_vectors(glm_vectors_file):
    container = np.load(glm_vectors_file)
    data = [container[key] for key in container]
    loglikelihood_train = data[0]
    recovered_weights = data[1]
    return loglikelihood_train, recovered_weights


def load_lapse_params(lapse_file):
    container = np.load(lapse_file, allow_pickle=True)
    data = [container[key] for key in container]
    lapse_loglikelihood = data[0]
    lapse_glm_weights = data[1]
    lapse_glm_weights_std = data[2],
    lapse_p = data[3]
    lapse_p_std = data[4]
    return lapse_loglikelihood, lapse_glm_weights, lapse_glm_weights_std, \
           lapse_p, lapse_p_std


def load_glmhmm_data(data_file):
    container = np.load(data_file, allow_pickle=True)
    data = [container[key] for key in container]
    this_hmm_params = data[0]
    lls = data[1]
    return [this_hmm_params, lls]


def load_cv_arr(file):
    container = np.load(file, allow_pickle=True)
    data = [container[key] for key in container]
    cvbt_folds_model = data[0]
    return cvbt_folds_model


def partition_data_by_session(inpt, y, mask, session):
    '''
    Partition inpt, y, mask by session
    :param inpt: arr of size TxM
    :param y:  arr of size T x D
    :param mask: Boolean arr of size T indicating if element is violation or
    not
    :param session: list of size T containing session ids
    :return: list of inpt arrays, data arrays and mask arrays, where the
    number of elements in list = number of sessions and each array size is
    number of trials in session
    '''
    inputs = []
    datas = []
    indexes = np.unique(session, return_index=True)[1]
    unique_sessions = [
        session[index] for index in sorted(indexes)
    ]  # ensure that unique sessions are ordered as they are in
    # session (so we can map inputs back to inpt)
    counter = 0
    masks = []
    for sess in unique_sessions:
        idx = np.where(session == sess)[0]
        counter += len(idx)
        inputs.append(inpt[idx, :])
        datas.append(y[idx, :])
        masks.append(mask[idx])
    assert counter == inpt.shape[0], "not all trials assigned to session!"
    return inputs, datas, masks


def get_train_test_dta(inpt, y, mask, session, session_fold_lookup_table,
                       fold):
    '''
    Split inpt, y, mask, session arrays into train and test arrays
    :param inpt:
    :param y:
    :param mask:
    :param session:
    :param session_fold_lookup_table:
    :param fold:
    :return:
    '''
    test_sessions = session_fold_lookup_table[np.where(
        session_fold_lookup_table[:, 1] == fold), 0]
    train_sessions = session_fold_lookup_table[np.where(
        session_fold_lookup_table[:, 1] != fold), 0]
    idx_test = [str(sess) in test_sessions for sess in session]
    idx_train = [str(sess) in train_sessions for sess in session]
    test_inpt, test_y, test_mask, this_test_session = inpt[idx_test, :], y[
                                                                         idx_test,
                                                                         :], \
                                                      mask[idx_test], session[
                                                          idx_test]
    train_inpt, train_y, train_mask, this_train_session = inpt[idx_train,
                                                          :], y[idx_train,
                                                              :], \
                                                          mask[idx_train], \
                                                          session[idx_train]
    return test_inpt, test_y, test_mask, this_test_session, train_inpt, \
           train_y, train_mask, this_train_session


def create_violation_mask(violation_idx, T):
    """
    Return indices of nonviolations and also a Boolean mask for inclusion (1
    = nonviolation; 0 = violation)
    :param test_idx:
    :param T:
    :return:
    """
    mask = np.array([i not in violation_idx for i in range(T)])
    nonviolation_idx = np.arange(T)[mask]
    mask = mask + 0
    assert len(nonviolation_idx) + len(
        violation_idx) == T, "violation and non-violation idx do not include " \
                             "" \
                             "" \
                             "" \
                             "" \
                             "all dta!"
    return nonviolation_idx, mask


def prepare_data_for_cv(inpt, y, session, session_fold_lookup_table, fold):
    '''
    :return:
    '''

    violation_idx = np.where(y == -1)[0]
    nonviolation_idx, nonviolation_mask = create_violation_mask(
        violation_idx, inpt.shape[0])
    # Load train and test data for session
    test_inpt, test_y, test_nonviolation_mask, this_test_session, \
    train_inpt, train_y, train_nonviolation_mask, this_train_session = \
        get_train_test_dta(
            inpt, y, nonviolation_mask, session, session_fold_lookup_table,
            fold)
    M = train_inpt.shape[1]
    n_test = np.sum(test_nonviolation_mask == 1)
    n_train = np.sum(train_nonviolation_mask == 1)
    return test_inpt, test_y, test_nonviolation_mask, this_test_session, \
           train_inpt, train_y, train_nonviolation_mask, this_train_session, \
           M, n_test, n_train


def calculate_baseline_test_ll(train_y, test_y, C):
    """
    Calculate baseline loglikelihood for CV bit/trial calculation.  This is
    log(p(y|p0)) = n_right(log(p0)) + (n_total-n_right)log(1-p0), where p0
    is the proportion of trials
    in which the animal went right in the training set and n_right is the
    number of trials in which the animal went right in the test set
    :param train_y
    :param test_y
    :return: baseline loglikelihood for CV bit/trial calculation
    """
    _, train_class_totals = np.unique(train_y, return_counts=True)
    train_class_probs = train_class_totals / train_y.shape[0]
    _, test_class_totals = np.unique(test_y, return_counts=True)
    ll0 = 0
    for c in range(C):
        ll0 += test_class_totals[c] * np.log(train_class_probs[c])
    return ll0


def calculate_glm_test_loglikelihood(glm_weights_file, test_y, test_inpt, M,
                                     C):
    loglikelihood_train, glm_vectors = load_glm_vectors(glm_weights_file)
    # Calculate test loglikelihood
    new_glm = glm(M, C)
    # Set parameters to fit parameters:
    new_glm.params = glm_vectors
    # Get loglikelihood of training data:
    loglikelihood_test = new_glm.log_marginal([test_y], [test_inpt], None,
                                              None)
    return loglikelihood_test


def calculate_lapse_test_loglikelihood(lapse_file, test_y, test_inpt, M,
                                       num_lapse_params):
    lapse_loglikelihood, lapse_glm_weights, _, lapse_p, _ = load_lapse_params(
        lapse_file)
    # Instantiate a model with these parameters
    new_lapse_model = lapse_model(M, num_lapse_params)
    if num_lapse_params == 1:
        new_lapse_model.params = [lapse_glm_weights, np.array([lapse_p])]
    else:
        new_lapse_model.params = [lapse_glm_weights, lapse_p]
    # Now calculate test loglikelihood
    loglikelihood_test = new_lapse_model.log_marginal(datas=[test_y],
                                                      inputs=[test_inpt],
                                                      masks=None,
                                                      tags=None)
    return loglikelihood_test


def return_lapse_nll(inpt, y, session, session_fold_lookup_table, fold,
                     num_lapse_params, results_dir_glm_lapse, C):
    test_inpt, test_y, test_nonviolation_mask, this_test_session, \
    train_inpt, train_y, train_nonviolation_mask, this_train_session, M, \
    n_test, n_train = prepare_data_for_cv(
        inpt, y, session, session_fold_lookup_table, fold)
    ll0 = calculate_baseline_test_ll(train_y[train_nonviolation_mask == 1, :],
                                     test_y[test_nonviolation_mask == 1, :], C)
    ll0_train = calculate_baseline_test_ll(
        train_y[train_nonviolation_mask == 1, :],
        train_y[train_nonviolation_mask == 1, :], C)
    if num_lapse_params == 1:
        lapse_file = results_dir_glm_lapse + '/Lapse_Model/fold_' + str(
            fold) + '/lapse_model_params_one_param.npz'
    elif num_lapse_params == 2:
        lapse_file = results_dir_glm_lapse + '/Lapse_Model/fold_' + str(
            fold) + '/lapse_model_params_two_param.npz'
    ll_lapse = calculate_lapse_test_loglikelihood(
        lapse_file,
        test_y[test_nonviolation_mask == 1, :],
        test_inpt[test_nonviolation_mask == 1, :],
        M,
        num_lapse_params=num_lapse_params)
    ll_train_lapse = calculate_lapse_test_loglikelihood(
        lapse_file,
        train_y[train_nonviolation_mask == 1, :],
        train_inpt[train_nonviolation_mask == 1, :],
        M,
        num_lapse_params=num_lapse_params)
    nll_lapse = calculate_cv_bit_trial(ll_lapse, ll0, n_test)
    nll_lapse_train = calculate_cv_bit_trial(ll_train_lapse, ll0_train,
                                             n_train)
    return nll_lapse, nll_lapse_train, ll_lapse, ll_train_lapse


def calculate_glm_hmm_test_loglikelihood(glm_hmm_dir, test_datas, test_inputs,
                                         test_nonviolation_masks, K, D, M, C):
    """
    calculate test loglikelihood for GLM-HMM model.  Loop through all
    initializations for fold of interest, and check that final train LL is
    same for top initializations
    :return:
    """
    this_file_name = glm_hmm_dir + '/iter_*/glm_hmm_raw_parameters_*.npz'
    raw_files = glob.glob(this_file_name, recursive=True)
    train_ll_vals_across_iters = []
    test_ll_vals_across_iters = []
    for file in raw_files:
        # Loop through initializations and calculate BIC:
        this_hmm_params, lls = load_glmhmm_data(file)
        train_ll_vals_across_iters.append(lls[-1])
        # Instantiate a new HMM and calculate test loglikelihood:
        this_hmm = ssm.HMM(K,
                           D,
                           M,
                           observations="input_driven_obs",
                           observation_kwargs=dict(C=C),
                           transitions="standard")
        this_hmm.params = this_hmm_params
        test_ll = this_hmm.log_likelihood(test_datas,
                                          inputs=test_inputs,
                                          masks=test_nonviolation_masks)
        test_ll_vals_across_iters.append(test_ll)
    # Order initializations by train LL (don't train on test data!):
    train_ll_vals_across_iters = np.array(train_ll_vals_across_iters)
    test_ll_vals_across_iters = np.array(test_ll_vals_across_iters)
    # Order raw files by train LL
    file_ordering_by_train = np.argsort(-train_ll_vals_across_iters)
    raw_file_ordering_by_train = np.array(raw_files)[file_ordering_by_train]
    # Get initialization number from raw_file ordering
    init_ordering_by_train = [
        int(re.findall(r'\d+', file)[-1])
        for file in raw_file_ordering_by_train
    ]
    return test_ll_vals_across_iters, init_ordering_by_train, \
           file_ordering_by_train


def return_glmhmm_nll(inpt, y, session, session_fold_lookup_table, fold, K, D,
                      C, results_dir_glm_hmm):
    '''
    For a given fold, return NLL for both train and test datasets for
    GLM-HMM model with K, D, C.  Requires reading in best
    parameters over all initializations for GLM-HMM (hence why
    results_dir_glm_hmm is required as an input)
    :param inpt:
    :param y:
    :param session:
    :param session_fold_lookup_table:
    :param fold:
    :param K:
    :param D:
    :param C:
    :param results_dir_glm_hmm:
    :return:
    '''
    test_inpt, test_y, test_nonviolation_mask, this_test_session, \
    train_inpt, train_y, train_nonviolation_mask, this_train_session, M, \
    n_test, n_train = prepare_data_for_cv(
        inpt, y, session, session_fold_lookup_table, fold)
    ll0 = calculate_baseline_test_ll(train_y[train_nonviolation_mask == 1, :],
                                     test_y[test_nonviolation_mask == 1, :], C)
    ll0_train = calculate_baseline_test_ll(
        train_y[train_nonviolation_mask == 1, :],
        train_y[train_nonviolation_mask == 1, :], C)
    # For GLM-HMM set values of y for violations to 1.  This value doesn't
    # matter (as mask will ensure that these y values do not contribute to
    # loglikelihood calculation
    test_y[test_nonviolation_mask == 0, :] = 1
    train_y[train_nonviolation_mask == 0, :] = 1
    # For GLM-HMM, need to partition data by session
    test_inputs, test_datas, test_nonviolation_masks = \
        partition_data_by_session(
            test_inpt, test_y,
            np.expand_dims(test_nonviolation_mask, axis=1),
            this_test_session)
    train_inputs, train_datas, train_nonviolation_masks = \
        partition_data_by_session(
            train_inpt, train_y,
            np.expand_dims(train_nonviolation_mask, axis=1),
            this_train_session)
    dir_to_check = results_dir_glm_hmm + '/GLM_HMM_K_' + str(
        K) + '/fold_' + str(fold) + '/'
    test_ll_vals_across_iters, init_ordering_by_train, \
    file_ordering_by_train = calculate_glm_hmm_test_loglikelihood(
        dir_to_check, test_datas, test_inputs, test_nonviolation_masks, K, D,
        M, C)
    train_ll_vals_across_iters, _, _ = calculate_glm_hmm_test_loglikelihood(
        dir_to_check, train_datas, train_inputs, train_nonviolation_masks, K,
        D, M, C)
    test_ll_vals_across_iters = test_ll_vals_across_iters[
        file_ordering_by_train]
    train_ll_vals_across_iters = train_ll_vals_across_iters[
        file_ordering_by_train]
    ll_glm_hmm_this_K = test_ll_vals_across_iters[0]
    cvbt_thismodel_thisfold = calculate_cv_bit_trial(ll_glm_hmm_this_K, ll0,
                                                     n_test)
    train_cvbt_thismodel_thisfold = calculate_cv_bit_trial(
        train_ll_vals_across_iters[0], ll0_train, n_train)
    return cvbt_thismodel_thisfold, train_cvbt_thismodel_thisfold, \
           ll_glm_hmm_this_K, \
           train_ll_vals_across_iters[0], init_ordering_by_train


def calculate_cv_bit_trial(ll_model, ll_0, n_trials):
    cv_bit_trial = ((ll_model - ll_0) / n_trials) / np.log(2)
    return cv_bit_trial


def create_cv_frame_for_plotting(cv_file):
    cvbt_folds_model = load_cv_arr(cv_file)
    glm_lapse_model = cvbt_folds_model[:3, ]
    idx = np.array([0, 3, 4, 5, 6])
    cvbt_folds_model = cvbt_folds_model[idx, :]
    # Identify best cvbt:
    mean_cvbt = np.mean(cvbt_folds_model, axis=1)
    loc_best = np.where(mean_cvbt == max(mean_cvbt))[0]
    best_val = max(mean_cvbt)
    # Create dataframe for plotting
    num_models = cvbt_folds_model.shape[0]
    num_folds = cvbt_folds_model.shape[1]
    # Create pandas dataframe:
    data_for_plotting_df = pd.DataFrame({
        'model':
            np.repeat(np.arange(num_models), num_folds),
        'cv_bit_trial':
            cvbt_folds_model.flatten()
    })
    return data_for_plotting_df, loc_best, best_val, glm_lapse_model


def get_file_name_for_best_model_fold(cvbt_folds_model, K, overall_dir,
                                      best_init_cvbt_dict):
    '''
    Get the file name for the best initialization for the K value specified
    :param cvbt_folds_model:
    :param K:
    :param models:
    :param overall_dir:
    :param best_init_cvbt_dict:
    :return:
    '''
    # Identify best fold for best model:
    # loc_best = K - 1
    loc_best = 0
    best_fold = np.where(cvbt_folds_model[loc_best, :] == max(cvbt_folds_model[
                                                              loc_best, :]))[
        0][0]
    base_path = overall_dir + '/GLM_HMM_K_' + str(K) + '/fold_' + str(
        best_fold)
    key_for_dict = '/GLM_HMM_K_' + str(K) + '/fold_' + str(best_fold)
    best_iter = best_init_cvbt_dict[key_for_dict]
    raw_file = base_path + '/iter_' + str(
        best_iter) + '/glm_hmm_raw_parameters_itr_' + str(best_iter) + '.npz'
    return raw_file


def permute_transition_matrix(transition_matrix, permutation):
    transition_matrix = transition_matrix[np.ix_(permutation, permutation)]
    return transition_matrix


def calculate_state_permutation(hmm_params):
    '''
    If K = 3, calculate the permutation that results in states being ordered
    as engaged/bias left/bias right
    Else: order states so that they are ordered by engagement
    :param hmm_params:
    :return: permutation
    '''
    # GLM weights (note: we have to take negative, because we are interested
    # in weights corresponding to p(y = 1) = 1/(1+e^(-w.x)), but returned
    # weights from
    # code are w such that p(y = 1) = e(w.x)/1+e(w.x))
    glm_weights = -hmm_params[2]
    K = glm_weights.shape[0]
    if K == 3:
        # want states ordered as engaged/bias left/bias right
        M = glm_weights.shape[2] - 1
        # bias coefficient is last entry in dimension 2
        engaged_loc = \
            np.where((glm_weights[:, 0, 0] == max(glm_weights[:, 0, 0])))[0][0]
        reduced_weights = np.copy(glm_weights)
        # set row in reduced weights corresponding to engaged to have a bias
        # that will not cause it to have largest bias
        reduced_weights[engaged_loc, 0, M] = max(glm_weights[:, 0, M]) - 0.001
        bias_left_loc = \
            np.where(
                (reduced_weights[:, 0, M] == min(reduced_weights[:, 0, M])))[
                0][0]
        state_order = [engaged_loc, bias_left_loc]
        bias_right_loc = np.arange(3)[np.where(
            [range(3)[i] not in state_order for i in range(3)])][0]
        permutation = np.array([engaged_loc, bias_left_loc, bias_right_loc])
    elif K == 4:
        # want states ordered as engaged/bias left/bias right
        M = glm_weights.shape[2] - 1
        # bias coefficient is last entry in dimension 2
        engaged_loc = \
            np.where((glm_weights[:, 0, 0] == max(glm_weights[:, 0, 0])))[0][0]
        reduced_weights = np.copy(glm_weights)
        # set row in reduced weights corresponding to engaged to have a bias
        # that will not
        reduced_weights[engaged_loc, 0, M] = max(glm_weights[:, 0, M]) - 0.001
        bias_right_loc = \
            np.where(
                (reduced_weights[:, 0, M] == max(reduced_weights[:, 0, M])))[
                0][0]
        bias_left_loc = \
            np.where(
                (reduced_weights[:, 0, M] == min(reduced_weights[:, 0, M])))[
                0][0]
        state_order = [engaged_loc, bias_left_loc, bias_right_loc]
        other_loc = np.arange(4)[np.where(
            [range(4)[i] not in state_order for i in range(4)])][0]
        permutation = np.array(
            [engaged_loc, bias_left_loc, bias_right_loc, other_loc])
    else:
        # order states by engagement: with the most engaged being first.
        # Note: argsort sorts inputs from smallest to largest (hence why we
        # convert to -ve glm_weights)
        permutation = np.argsort(-glm_weights[:, 0, 0])
    # assert that all indices are present in permutation exactly once:
    assert len(permutation) == K, "permutation is incorrect size"
    assert check_all_indices_present(permutation, K), "not all indices " \
                                                      "present in " \
                                                      "permutation: " \
                                                      "permutation = " + \
                                                      str(permutation)
    return permutation


def check_all_indices_present(permutation, K):
    for i in range(K):
        if i not in permutation:
            return False
    return True


def get_marginal_posterior(inputs, datas, masks, hmm_params, K, permutation):
    # Run forward algorithm on hmm with these parameters and collect gammas:
    M = inputs[0].shape[1]
    D = datas[0].shape[1]
    this_hmm = ssm.HMM(K, D, M,
                       observations="input_driven_obs",
                       observation_kwargs=dict(C=2),
                       transitions="standard")
    this_hmm.params = hmm_params
    # Get expected states:
    expectations = [this_hmm.expected_states(data=data, input=input,
                                             mask=np.expand_dims(mask,
                                                                 axis=1))[0]
                    for data, input, mask
                    in zip(datas, inputs, masks)]
    # Convert this now to one array:
    posterior_probs = np.concatenate(expectations, axis=0)
    posterior_probs = posterior_probs[:, permutation]
    return posterior_probs

In [9]:
import numpy as np

K_vals = [2, 3, 4, 5]
num_folds = 5
N_initializations = 20

if __name__ == '__main__':
    cluster_job_arr = []
    for K in K_vals:
        for i in range(num_folds):
            for j in range(N_initializations):
                cluster_job_arr.append([K, i, j])
    np.savez('../../int-brain-lab/glm-hmm/data/ibl/data_for_cluster/cluster_job_arr.npz',
             cluster_job_arr)

In [19]:
import sys
import os
import autograd.numpy as np
#from glm_hmm_utils import load_cluster_arr, load_session_fold_lookup, \
#    load_data, create_violation_mask, launch_glm_hmm_job

D = 1  # data (observations) dimension
C = 2  # number of output types/categories
N_em_iters = 1 # 300  # number of EM iterations

USE_CLUSTER = False

if __name__ == '__main__':
    data_dir = '../../int-brain-lab/glm-hmm/data/ibl/data_for_cluster/'
    results_dir = '../../int-brain-lab/glm-hmm/results/ibl_global_fit/'

    if USE_CLUSTER:
        z = int(sys.argv[1])
    else:
        z = 0

    num_folds = 5
    global_fit = True
    # perform mle => set transition_alpha to 1
    transition_alpha = 1
    prior_sigma = 100

    # Load external files:
    cluster_arr_file = data_dir + 'cluster_job_arr.npz'
    # Load cluster array job parameters:
    cluster_arr = load_cluster_arr(cluster_arr_file)
    [K, fold, iter] = cluster_arr[z]

    #  read in data and train/test split
    animal_file = data_dir + 'all_animals_concat.npz'
    session_fold_lookup_table = load_session_fold_lookup(
        data_dir + 'all_animals_concat_session_fold_lookup.npz')

    inpt, y, session = load_data(animal_file)
    #  append a column of ones to inpt to represent the bias covariate:
    inpt = np.hstack((inpt, np.ones((len(inpt),1))))
    y = y.astype('int')
    # Identify violations for exclusion:
    violation_idx = np.where(y == -1)[0]
    nonviolation_idx, mask = create_violation_mask(violation_idx,
                                                   inpt.shape[0])

    #  GLM weights to use to initialize GLM-HMM
    init_param_file = results_dir + '/GLM/fold_' + str(
        fold) + '/variables_of_interest_iter_0.npz'

    # create save directory for this initialization/fold combination:
    save_directory = results_dir + '/GLM_HMM_K_' + str(
        K) + '/' + 'fold_' + str(fold) + '/' + '/iter_' + str(iter) + '/'
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    launch_glm_hmm_job(inpt,
                       y,
                       session,
                       maks,
                       session_fold_lookup_table,
                       K,
                       D,
                       C,
                       N_em_iters,
                       transition_alpha,
                       prior_sigma,
                       fold,
                       iter,
                       global_fit,
                       init_param_file,
                       save_directory)

Starting inference with K = 2; Fold = 0; Iter = 0


TypeError: 'NoneType' object is not subscriptable

In [18]:
def fit_glm_hmm(datas, inputs, masks, K, D, M, C, N_em_iters,
                transition_alpha, prior_sigma, global_fit,
                params_for_initialization, save_title):
    '''
    Instantiate and fit GLM-HMM model
    :param datas:
    :param inputs:
    :param masks:
    :param K:
    :param D:
    :param M:
    :param C:
    :param N_em_iters:
    :param global_fit:
    :param glm_vectors:
    :param save_title:
    :return:
    '''
    if global_fit == True:
        # Prior variables
        # Choice of prior
        this_hmm = ssm.HMM(K,
                           D,
                           M,
                           observations="input_driven_obs",
                           observation_kwargs=dict(C=C,
                                                   prior_sigma=prior_sigma),
                           transitions="sticky",
                           transition_kwargs=dict(alpha=transition_alpha,
                                                  kappa=0))
        # Initialize observation weights as GLM weights with some noise:
        glm_vectors_repeated = np.tile(params_for_initialization, (K, 1, 1))
        glm_vectors_with_noise = glm_vectors_repeated + np.random.normal(
            0, 0.2, glm_vectors_repeated.shape)
        this_hmm.observations.params = glm_vectors_with_noise
    else:
        # Choice of prior
        this_hmm = ssm.HMM(K,
                           D,
                           M,
                           observations="input_driven_obs",
                           observation_kwargs=dict(C=C,
                                                   prior_sigma=prior_sigma),
                           transitions="sticky",
                           transition_kwargs=dict(alpha=transition_alpha,
                                                  kappa=0))
        # Initialize HMM-GLM with global parameters:
        this_hmm.params = params_for_initialization
        # Get log_prior of transitions:
    print("=== fitting GLM-HMM ========")
    sys.stdout.flush()
    # Fit this HMM and calculate marginal likelihood
    lls = this_hmm.fit(datas,
                       inputs=inputs,
                       masks=None,
                       method="em",
                       num_iters=N_em_iters,
                       initialize=False,
                       tolerance=10 ** -4)
    # Save raw parameters of HMM, as well as loglikelihood during training
    np.savez(save_title, this_hmm.params, lls)
    return None

In [None]:
# Create a matrix of size num_models x num_folds containing
# normalized loglikelihood for both train and test splits
import json

import numpy as np
#from post_processing_utils import load_data, load_session_fold_lookup, \
#    prepare_data_for_cv, calculate_baseline_test_ll, \
#    calculate_glm_test_loglikelihood, calculate_cv_bit_trial, \
#    return_glmhmm_nll, return_lapse_nll

if __name__ == '__main__':
    data_dir = '../../int-brain-lab/glm-hmm/data/ibl/data_for_cluster/'
    results_dir = '../../int-brain-lab/glm-hmm/results/ibl_global_fit/'

    # Load data
    inpt, y, session = load_data(data_dir + 'all_animals_concat.npz')
    session_fold_lookup_table = load_session_fold_lookup(
        data_dir + 'all_animals_concat_session_fold_lookup.npz')

    # Parameters
    C = 2  # number of output classes
    num_folds = 5  # number of folds
    D = 1  # number of output dimensions
    K_max = 5  # maximum number of latent states
    num_models = K_max + 2  # model for each latent + 2 lapse models

    animal_preferred_model_dict = {}
    models = ["GLM", "Lapse_Model", "GLM_HMM"]

    cvbt_folds_model = np.zeros((num_models, num_folds))
    cvbt_train_folds_model = np.zeros((num_models, num_folds))

    # Save best initialization for each model-fold combination
    best_init_cvbt_dict = {}
    for fold in range(num_folds):
        test_inpt, test_y, test_nonviolation_mask, this_test_session, \
        train_inpt, train_y, train_nonviolation_mask, this_train_session, M,\
        n_test, n_train = prepare_data_for_cv(
            inpt, y, session, session_fold_lookup_table, fold)
        ll0 = calculate_baseline_test_ll(
            train_y[train_nonviolation_mask == 1, :],
            test_y[test_nonviolation_mask == 1, :], C)
        ll0_train = calculate_baseline_test_ll(
            train_y[train_nonviolation_mask == 1, :],
            train_y[train_nonviolation_mask == 1, :], C)
        for model in models:
            print("model = " + str(model))
            if model == "GLM":
                # Load parameters and instantiate a new GLM object with
                # these parameters
                glm_weights_file = results_dir + '/GLM/fold_' + str(
                    fold) + '/variables_of_interest_iter_0.npz'
                ll_glm = calculate_glm_test_loglikelihood(
                    glm_weights_file, test_y[test_nonviolation_mask == 1, :],
                    test_inpt[test_nonviolation_mask == 1, :], M, C)
                ll_glm_train = calculate_glm_test_loglikelihood(
                    glm_weights_file, train_y[train_nonviolation_mask == 1, :],
                    train_inpt[train_nonviolation_mask == 1, :], M, C)
                cvbt_folds_model[0, fold] = calculate_cv_bit_trial(
                    ll_glm, ll0, n_test)
                cvbt_train_folds_model[0, fold] = calculate_cv_bit_trial(
                    ll_glm_train, ll0_train, n_train)
            elif model == "Lapse_Model":
                # One lapse parameter model:
                cvbt_folds_model[1, fold], cvbt_train_folds_model[
                    1,
                    fold], _, _ = return_lapse_nll(inpt, y, session,
                                                   session_fold_lookup_table,
                                                   fold, 1, results_dir, C)
                # Two lapse parameter model:
                cvbt_folds_model[2, fold], cvbt_train_folds_model[
                    2,
                    fold], _, _ = return_lapse_nll(inpt, y, session,
                                                   session_fold_lookup_table,
                                                   fold, 2, results_dir, C)
            elif model == "GLM_HMM":
                for K in range(2, K_max + 1):
                    print("K = " + str(K))
                    model_idx = 3 + (K - 2)
                    cvbt_folds_model[model_idx, fold], \
                    cvbt_train_folds_model[
                        model_idx, fold], _, _, init_ordering_by_train = \
                        return_glmhmm_nll(
                            np.hstack((inpt, np.ones((len(inpt), 1)))), y,
                            session, session_fold_lookup_table, fold,
                            K, D, C, results_dir)
                    # Save best initialization to dictionary for later:
                    key_for_dict = '/GLM_HMM_K_' + str(K) + '/fold_' + str(
                        fold)
                    best_init_cvbt_dict[key_for_dict] = int(
                        init_ordering_by_train[0])
    # Save best initialization directories across animals, folds and models
    # (only GLM-HMM):
    print(cvbt_folds_model)
    print(cvbt_train_folds_model)
    json_dump = json.dumps(best_init_cvbt_dict)
    f = open(results_dir + "/best_init_cvbt_dict.json", "w")
    f.write(json_dump)
    f.close()
    # Save cvbt_folds_model as numpy array for easy parsing across all
    # models and folds
    np.savez(results_dir + "/cvbt_folds_model.npz", cvbt_folds_model)
    np.savez(results_dir + "/cvbt_train_folds_model.npz",
             cvbt_train_folds_model)

In [None]:
# Save best parameters from IBL global fits (for K = 2 to 5) to initialize
# each animal's model
import json
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from post_processing_utils import load_glmhmm_data, load_cv_arr, \
    create_cv_frame_for_plotting, get_file_name_for_best_model_fold, \
    permute_transition_matrix, calculate_state_permutation


if __name__ == '__main__':

    data_dir = '../../data/ibl/data_for_cluster/'
    results_dir = '../../results/ibl_global_fit/'
    save_directory = data_dir + "best_global_params/"

    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    labels_for_plot = ['stim', 'pc', 'wsls', 'bias']

    cv_file = results_dir + "/cvbt_folds_model.npz"
    cvbt_folds_model = load_cv_arr(cv_file)

    for K in range(2, 6):
        print("K = " + str(K))
        with open(results_dir + "/best_init_cvbt_dict.json", 'r') as f:
            best_init_cvbt_dict = json.load(f)

        # Get the file name corresponding to the best initialization for
        # given K value
        raw_file = get_file_name_for_best_model_fold(
            cvbt_folds_model, K, results_dir, best_init_cvbt_dict)
        hmm_params, lls = load_glmhmm_data(raw_file)

        # Calculate permutation
        permutation = calculate_state_permutation(hmm_params)
        print(permutation)

        # Save parameters for initializing individual fits
        weight_vectors = hmm_params[2][permutation]
        log_transition_matrix = permute_transition_matrix(
            hmm_params[1][0], permutation)
        init_state_dist = hmm_params[0][0][permutation]
        params_for_individual_initialization = [[init_state_dist],
                                                [log_transition_matrix],
                                                weight_vectors]

        np.savez(
            save_directory + 'best_params_K_' + str(K) + '.npz',
            params_for_individual_initialization)

        # Plot these too:
        cols = ["#e74c3c", "#15b01a", "#7e1e9c", "#3498db", "#f97306"]
        fig = plt.figure(figsize=(4 * 8, 10),
                         dpi=80,
                         facecolor='w',
                         edgecolor='k')
        plt.subplots_adjust(left=0.1,
                            bottom=0.24,
                            right=0.95,
                            top=0.7,
                            wspace=0.8,
                            hspace=0.5)
        plt.subplot(1, 3, 1)
        M = weight_vectors.shape[2] - 1
        for k in range(K):
            plt.plot(range(M + 1),
                     -weight_vectors[k][0],
                     marker='o',
                     label='State ' + str(k + 1),
                     color=cols[k],
                     lw=4)
        plt.xticks(list(range(0, len(labels_for_plot))),
                   labels_for_plot,
                   rotation='20',
                   fontsize=24)
        plt.yticks(fontsize=30)
        plt.legend(fontsize=30)
        plt.axhline(y=0, color="k", alpha=0.5, ls="--")
        # plt.ylim((-3, 14))
        plt.ylabel("Weight", fontsize=30)
        plt.xlabel("Covariate", fontsize=30, labelpad=20)
        plt.title("GLM Weights: Choice = R", fontsize=40)

        plt.subplot(1, 3, 2)
        transition_matrix = np.exp(log_transition_matrix)
        plt.imshow(transition_matrix, vmin=0, vmax=1)
        for i in range(transition_matrix.shape[0]):
            for j in range(transition_matrix.shape[1]):
                text = plt.text(j,
                                i,
                                np.around(transition_matrix[i, j],
                                          decimals=3),
                                ha="center",
                                va="center",
                                color="k",
                                fontsize=30)
        plt.ylabel("Previous State", fontsize=30)
        plt.xlabel("Next State", fontsize=30)
        plt.xlim(-0.5, K - 0.5)
        plt.ylim(-0.5, K - 0.5)
        plt.xticks(range(0, K), ('1', '2', '3', '4', '4', '5', '6', '7',
                                 '8', '9', '10')[:K],
                   fontsize=30)
        plt.yticks(range(0, K), ('1', '2', '3', '4', '4', '5', '6', '7',
                                 '8', '9', '10')[:K],
                   fontsize=30)
        plt.title("Retrieved", fontsize=40)

        plt.subplot(1, 3, 3)
        cols = [
            "#7e1e9c", "#0343df", "#15b01a", "#bf77f6", "#95d0fc",
            "#96f97b"
        ]
        cv_file = results_dir + "/cvbt_folds_model.npz"
        data_for_plotting_df, loc_best, best_val, glm_lapse_model = \
            create_cv_frame_for_plotting(
            cv_file)
        cv_file_train = results_dir + "/cvbt_train_folds_model.npz"
        train_data_for_plotting_df, train_loc_best, train_best_val, \
        train_glm_lapse_model = create_cv_frame_for_plotting(
            cv_file_train)

        glm_lapse_model_cvbt_means = np.mean(glm_lapse_model, axis=1)
        train_glm_lapse_model_cvbt_means = np.mean(train_glm_lapse_model,
                                                   axis=1)
        g = sns.lineplot(
            data_for_plotting_df['model'],
            data_for_plotting_df['cv_bit_trial'],
            err_style="bars",
            mew=0,
            color=cols[0],
            marker='o',
            ci=68,
            label="test",
            alpha=1,
            lw=4)
        sns.lineplot(
            train_data_for_plotting_df['model'],
            train_data_for_plotting_df['cv_bit_trial'],
            err_style="bars",
            mew=0,
            color=cols[1],
            marker='o',
            ci=68,
            label="train",
            alpha=1,
            lw=4)
        plt.xlabel("Model", fontsize=30)
        plt.ylabel("Normalized LL", fontsize=30)
        plt.xticks([0, 1, 2, 3, 4],
                   ['1 State', '2 State', '3 State', '4 State', '5 State'],
                   rotation=45,
                   fontsize=24)
        plt.yticks(fontsize=15)
        plt.axhline(y=glm_lapse_model_cvbt_means[2],
                    color=cols[2],
                    label="Lapse (test)",
                    alpha=0.9,
                    lw=4)
        plt.legend(loc='upper right', fontsize=30)
        plt.tick_params(axis='y')
        plt.yticks([0.2, 0.3, 0.4, 0.5], fontsize=30)
        plt.ylim((0.2, 0.55))
        plt.title("Model Comparison", fontsize=40)
        fig.tight_layout()

        fig.savefig(results_dir + 'best_params_cross_validation_K_' +
                    str(K) + '.png')