<a href="https://colab.research.google.com/github/yjiong0228/ssm/blob/master/Q1_Jiong.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Install IBL pipeline package

In [1]:
# install IBL pipeline package to access and navigate the pipeline
!pip install --quiet nma-ibl

[K     |████████████████████████████████| 102kB 6.7MB/s 
[K     |████████████████████████████████| 51kB 5.5MB/s 
[K     |████████████████████████████████| 81kB 8.5MB/s 
[K     |████████████████████████████████| 3.2MB 56.8MB/s 
[31mERROR: otumat 0.2.0 has requirement cryptography<=3.3.2, but you'll have cryptography 3.4.7 which is incompatible.[0m
[?25h

In [2]:
import datajoint as dj

# Connection settings for the DataJoint database queried by the rest of this notebook.
# NOTE(review): host, user and password all contain "public", so these look like shared
# demo credentials — confirm, and never copy this pattern for private credentials
# (read those from the environment or getpass instead).
dj.config['database.host'] = 'datajoint-public.internationalbrainlab.org'
dj.config['database.user'] = 'nma-ibl-public'
dj.config['database.password'] = 'ibl.pipeline.public.demo'
dj.conn() # explicitly checks if database connection can be established

Connecting nma-ibl-public@datajoint-public.internationalbrainlab.org:3306


DataJoint connection (connected) nma-ibl-public@datajoint-public.internationalbrainlab.org:3306

In [3]:
from nma_ibl import reference, subject, acquisition, behavior, behavior_analyses
from nma_ibl import psychofit as psy

## Define general functions

 Adapted from [paper_behavior_functions.py](https://github.com/int-brain-lab/paper-behavior/blob/master/paper_behavior_functions.py)  by IBL et al., 2021

In [4]:
import warnings
import os
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

import seaborn as sns
import matplotlib
import numpy as np
import datajoint as dj
import pandas as pd
import matplotlib.pyplot as plt


# Suppress FutureWarnings (e.g. those triggered by seaborn)
warnings.simplefilter(action='ignore', category=FutureWarning)

# Some constants
URL = 'http://ibl.flatironinstitute.org/public/behavior_paper_data.zip'
QUERY = False  # Whether to query data through DataJoint (True) or use downloaded csv files (False)
EXAMPLE_MOUSE = 'KS014'  # Mouse nickname used as an example
# Dates are 'YYYY-MM-DD' strings; the query functions below interpolate them into
# restrictions on date(session_start_time) and date_trained.
CUTOFF_DATE = '2020-03-23'  # Date after which sessions are excluded, previously 30th Nov
STABLE_HW_DATE = '2019-06-10'  # Date after which hardware was deemed stable

# LAYOUT
FIGURE_HEIGHT = 2  # inch
FIGURE_WIDTH = 8  # inch

# EXCLUDED SESSIONS
# Sessions listed here are removed by the query functions via dj.Not on session_uuid.
EXCLUDED_SESSIONS = ['a9fb578a-9d7d-42b4-8dbc-3b419ce9f424']  # Session UUID


def group_colors():
    """
    Return seaborn's "Dark2" colour palette with seven colours.
    """
    palette_name, n_colors = "Dark2", 7
    return sns.color_palette(palette_name, n_colors)


def institution_map():
    """
    Return the anonymised lab labels.

    Returns
    -------
    lab_labels:   dict mapping each institution short name to its 'Lab N' label
    col_names:    list with the 'Lab N' labels in order, followed by 'All labs'
    """
    institutions = ('UCL', 'CCU', 'CSHL', 'NYU', 'Princeton', 'SWC', 'Berkeley')
    # Labs are numbered from 1 in the order listed above
    lab_labels = {name: 'Lab %d' % number
                  for number, name in enumerate(institutions, start=1)}
    col_names = list(lab_labels.values()) + ['All labs']

    return lab_labels, col_names


def seaborn_style():
    """
    Apply the plotting style used for the figures: seaborn "ticks" style in the
    "paper" context with Arial, small fonts and short ticks, plus editable
    (Type 42 / TrueType) fonts in PDF and PostScript output.
    """
    font_sizes = {"font.size": 9,
                  "axes.titlesize": 9,
                  "axes.labelsize": 9,
                  "xtick.labelsize": 7,
                  "ytick.labelsize": 7}
    tick_lengths = {"xtick.major.size": 2.5,
                    "ytick.major.size": 2.5,
                    "xtick.minor.size": 2,
                    "ytick.minor.size": 2}
    overrides = {**font_sizes,
                 "lines.linewidth": 1,
                 "savefig.transparent": True,
                 **tick_lengths}

    sns.set(style="ticks", context="paper", font="Arial", rc=overrides)

    # fonttype 42 embeds TrueType fonts so text stays editable in vector editors
    matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})


def figpath():
    """
    Return the directory in which exported figures are saved, creating it if needed.

    The directory is ``exported_figs`` next to this source file. Inside a notebook
    ``__file__`` is not defined, so the current working directory is used as the
    base directory instead.
    """
    try:
        # Retrieve absolute path of paper-behavior dir
        repo_dir = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        # Interactive session (Jupyter / Colab): there is no __file__ to anchor on
        repo_dir = os.getcwd()
    fig_dir = os.path.join(repo_dir, 'exported_figs')
    # Create the figure directory if it doesn't already exist
    os.makedirs(fig_dir, exist_ok=True)
    return fig_dir


def datapath():
    """
    Return the location of the data directory, creating it if needed.

    The directory is ``data`` next to this source file. Inside a notebook
    ``__file__`` is not defined, so the current working directory is used as the
    base directory instead.
    """
    try:
        # Retrieve absolute path of paper-behavior dir
        repo_dir = os.path.dirname(os.path.realpath(__file__))
    except NameError:
        # Interactive session (Jupyter / Colab): there is no __file__ to anchor on
        repo_dir = os.getcwd()
    data_dir = os.path.join(repo_dir, 'data')
    # Create the data directory if it doesn't already exist
    os.makedirs(data_dir, exist_ok=True)
    return data_dir


def query_subjects(as_dataframe=False, from_list=False, criterion='trained'):
    """
    Query the mice used for the analysis of behavioral data
    Parameters
    ----------
    as_dataframe:    boolean, if True the result is fetched and returned as a pandas
                     dataframe sorted by lab_name (default is False, returning the query)
    from_list:       boolean, if True only subjects whose uuid is stored in
                     'uuids_trained.npy' inside datapath() are kept (default is False)
    criterion:       criterion that must have been reached on or before CUTOFF_DATE:
                     'trained' (trained_1a or trained_1b), 'biased' (a session with a
                     biased task protocol) or 'ephys' (any training status starting
                     with "ready"). If None, every mouse with a session is considered
                     and date_trained is the date of its first session.
    Raises
    ------
    ValueError if criterion is not None, 'trained', 'biased' or 'ephys'
    """
    from nma_ibl import subject, acquisition, reference, behavior_analyses

    fields = ('subject_nickname', 'sex', 'subject_birth_date', 'institution_short')
    first_session_date = 'min(date(session_start_time))'

    # Subjects of project ibl_neuropixel_brainwide_01, joined with their lab information
    all_subjects = (subject.Subject * subject.SubjectLab * reference.Lab * subject.SubjectProject
                    & 'subject_project = "ibl_neuropixel_brainwide_01"')
    sessions = acquisition.Session * behavior_analyses.SessionTrainingStatus()

    # Pick the sessions over which the earliest date is taken per subject
    if criterion is None:
        # no criterion: date_trained = date of the very first session
        qualifying_sessions = sessions
    else:
        # date_trained = date of the first session at which the criterion was reached
        if criterion == 'trained':
            restriction = 'training_status="trained_1a" OR training_status="trained_1b"'
        elif criterion == 'biased':
            restriction = 'task_protocol LIKE "%biased%"'
        elif criterion == 'ephys':
            restriction = 'training_status LIKE "ready%"'
        else:
            raise ValueError('criterion must be "trained", "biased" or "ephys"')
        qualifying_sessions = sessions & restriction

    subj_query = all_subjects.aggr(
        qualifying_sessions, *fields, date_trained=first_session_date)

    # Optionally keep only the subjects listed in the stored uuid file
    if from_list is True:
        uuid_file = os.path.join(datapath(), 'uuids_trained.npy')
        stored_uuids = np.load(uuid_file, allow_pickle=True)
        subj_query = subj_query & [{'subject_uuid': uuid} for uuid in stored_uuids]

    # Keep the subjects that reached the criterion on or before the cutoff date
    subjects = (subj_query & 'date_trained <= "%s"' % CUTOFF_DATE)

    if as_dataframe is True:
        # NOTE(review): assumes lab_name is a column of the fetched frame — confirm
        frame = subjects.fetch(format='frame')
        subjects = frame.sort_values(by=['lab_name']).reset_index()

    return subjects


def query_sessions(task='all', stable=False, as_dataframe=False,
                   force_cutoff=False, criterion='biased'):
    """
    Query the sessions used for the analysis of behavioral data
    Parameters
    ----------
    task:            string selecting sessions by task protocol: 'all' (everything except
                     habituation), 'training', 'biased' or 'ephys' (default is 'all')
    stable:          boolean, if True only sessions recorded after STABLE_HW_DATE are
                     returned (default is False)
    as_dataframe:    boolean, if True returns a pandas dataframe (default is False)
    force_cutoff:    boolean, if True the subjects are selected with
                     query_subjects(criterion=criterion); otherwise with the default
                     criterion of query_subjects
    criterion:       criterion passed to query_subjects when force_cutoff is True
    Raises
    ------
    ValueError if task is not 'all', 'training', 'biased' or 'ephys'
    """

    from nma_ibl import acquisition

    # Subjects whose sessions are considered
    subject_kwargs = {'criterion': criterion} if force_cutoff is True else {}
    use_subjects = query_subjects(**subject_kwargs).proj('subject_uuid')

    # Translate the requested task into a restriction on the protocol name
    if task == 'all':
        protocol_filter = 'task_protocol NOT LIKE "%habituation%"'
    elif task == 'training':
        protocol_filter = 'task_protocol LIKE "%training%"'
    elif task == 'biased':
        protocol_filter = 'task_protocol LIKE "%biased%"'
    elif task == 'ephys':
        protocol_filter = 'task_protocol LIKE "%ephys%"'
    else:
        raise ValueError('task must be "all", "training", "biased" or "ephys"')
    sessions = acquisition.Session * use_subjects & protocol_filter

    # Only use sessions up to and including the cutoff date
    sessions = sessions & 'date(session_start_time) <= "%s"' % CUTOFF_DATE

    # Drop the explicitly excluded sessions
    excluded = [{'session_uuid': uuid} for uuid in EXCLUDED_SESSIONS]
    sessions = sessions & dj.Not(excluded)

    # If required, keep only sessions recorded on stable hardware
    if stable is True:
        sessions = sessions & 'date(session_start_time) > "%s"' % STABLE_HW_DATE

    if as_dataframe is True:
        # NOTE(review): the ordering refers to institution_short and subject_nickname;
        # confirm both are attributes of the joined query before relying on this path
        frame = sessions.fetch(
            order_by='institution_short, subject_nickname, session_start_time', format='frame')
        sessions = frame.reset_index()

    return sessions


def query_sessions_around_criterion(criterion='trained', days_from_criterion=(2, 0),
                                    as_dataframe=False, force_cutoff=False):
    """
    Query the sessions recorded around the day a mouse first reached a criterion
    Parameters
    ----------
    criterion:              string indicating which criterion to use: trained, biased or ephys
    days_from_criterion:    two-element sequence which indicates which training days around the
                            day the mouse reached criterion to return, e.g. [3, 2] returns three
                            days before criterion was reached up until 2 days after
                            (default: (2, 0))
    as_dataframe:           return sessions and days as pandas dataframes
    force_cutoff:           if True, subjects are selected with
                            query_subjects(criterion=criterion); otherwise with the default
                            criterion of query_subjects
    Returns
    ---------
    sessions:               The sessions around the criterion day, works in conjunction with
                            any table that has session_start_time as primary key (such as
                            behavior.TrialSet.Trial)
    days:                   The training days around the criterion day. Can be used in conjunction
                            with tables that have session_date as primary key (such as
                            behavior_analyses.BehavioralSummaryByDate)
    Raises
    ---------
    ValueError if criterion is not 'trained', 'biased' or 'ephys'
    """

    from nma_ibl import subject, acquisition, behavior_analyses

    # Query all included subjects
    if force_cutoff is True:
        use_subjects = query_subjects(criterion=criterion).proj('subject_uuid')
    else:
        use_subjects = query_subjects().proj('subject_uuid')

    # Query per subject the date at which the criterion is reached
    sessions = acquisition.Session * behavior_analyses.SessionTrainingStatus
    if criterion == 'trained':
        restriction = 'training_status="trained_1a" OR training_status="trained_1b"'
    elif criterion == 'biased':
        # Unlike query_subjects, 'biased' here additionally requires status trained_1b
        restriction = 'task_protocol LIKE "%biased%" AND training_status="trained_1b"'
    elif criterion == 'ephys':
        restriction = 'training_status LIKE "ready%"'
    else:
        raise ValueError('criterion must be "trained", "biased" or "ephys"')

    # date_criterion = date of the earliest session satisfying the restriction
    subj_crit = (subject.Subject * use_subjects).aggr(
        sessions & restriction, 'subject_nickname', date_criterion='min(date(session_start_time))')

    # Query the training day at which criterion is reached
    subj_crit_day = (dj.U('subject_uuid', 'day_of_crit')
                     & (behavior_analyses.BehavioralSummaryByDate * subj_crit
                        & 'session_date=date_criterion').proj(day_of_crit='training_day'))

    # Query days around the day at which criterion is reached:
    # from days_from_criterion[0] training days before up to days_from_criterion[1] after
    days = (behavior_analyses.BehavioralSummaryByDate * subject.Subject * subj_crit_day
            & ('training_day - day_of_crit between %d and %d'
               % (-days_from_criterion[0], days_from_criterion[1]))).proj(
                   'subject_uuid', 'subject_nickname', 'session_date')

    # Use dates to query sessions: first and last selected date per subject ...
    ses_query = acquisition.Session.aggr(
                            days, from_date='min(session_date)', to_date='max(session_date)')

    # ... then every session whose date falls inside that closed interval
    sessions = (acquisition.Session * ses_query & 'date(session_start_time) >= from_date'
                & 'date(session_start_time) <= to_date')

    # Exclude weird sessions
    sessions = sessions & dj.Not([{'session_uuid': u_id} for u_id in EXCLUDED_SESSIONS])

    # Transform to pandas dataframe if necessary
    if as_dataframe is True:
        sessions = sessions.fetch(format='frame').reset_index()
        days = days.fetch(format='frame').reset_index()

    return sessions, days


def query_session_around_performance(perform_thres=0.8, stage='training', n_sessions=3):
    '''
    Return the trials of the first run of consecutive sessions in which each mouse
    performed above threshold on easy trials.

    Parameters
    ----------
    perform_thres : float, optional
        Performance threshold on easy trials ('correct_easy') that must be exceeded
        in every one of the n_sessions consecutive sessions. The default is 0.8.
    stage:  string, optional.
        Substring of the task protocol selecting which sessions to pull from
        datajoint to calculate performance. The default is 'training'. Other
        options e.g. 'biased'.
    n_sessions : int, optional
        Number of consecutive sessions that must exceed the threshold.
        The default is 3.
    Returns
    -------
    selection : dataframe
        All trials of the n_sessions sessions at which each mouse first met the
        performance criterion. Empty dataframe if no mouse met it.
    '''
    from nma_ibl import behavior, subject, reference
    use_sessions = query_sessions(task='all', stable=False, as_dataframe=False,
                                  force_cutoff=True, criterion=None)
    behav = dj2pandas(
        ((use_sessions & 'task_protocol LIKE "%' + stage + '%"')  # only get sessions of this stage
         * subject.Subject * subject.SubjectLab * reference.Lab * behavior.TrialSet.Trial)

        # Query only the fields we require, reducing the size of the fetch
        .proj('institution_short', 'subject_nickname', 'task_protocol', 'session_uuid',
              'trial_stim_contrast_left', 'trial_stim_contrast_right', 'trial_response_choice',
              'task_protocol', 'trial_stim_prob_left', 'trial_feedback_type',
              'trial_response_time', 'trial_stim_on_time', 'session_end_time', 'time_zone')

        # Fetch as a pandas DataFrame, ordered by institute
        .fetch(order_by='institution_short, subject_nickname, session_start_time, trial_id',
               format='frame')
        .reset_index()
    )

    # Per-session performance on easy trials. Select the column before averaging so
    # that non-numeric columns are never passed to mean() (an error in recent pandas).
    behav_ses = (behav.groupby(['subject_nickname', 'session_start_time'])['correct_easy']
                 .mean().reset_index())
    behav_ses['above_criterion'] = behav_ses['correct_easy'] > perform_thres

    # Rolling count, per mouse, of above-threshold sessions; the criterion is met when
    # it equals n_sessions. behav_ses is sorted by mouse, so the grouped result lines
    # up positionally with its rows.
    behav_ses['met_session_criterion'] = \
        behav_ses.groupby(['subject_nickname']
                          )['above_criterion'].rolling(n_sessions).sum().to_numpy()

    # Select trials from the sessions where the criterion was first met
    selected = []
    for mouse in behav_ses['subject_nickname'].unique():
        mouse_ses = behav_ses[behav_ses['subject_nickname'] == mouse]
        met_at = np.flatnonzero((mouse_ses['met_session_criterion'] == n_sessions).to_numpy())
        if met_at.size == 0:
            continue
        last = met_at[0]  # position of the session that completes the first run
        mouse_ses_select = mouse_ses.iloc[last - (n_sessions - 1):last + 1, :]
        selected.append(behav.loc[(behav['subject_nickname'] == mouse) &
                                  (behav['session_start_time'].isin(
                                      mouse_ses_select['session_start_time']))])

    # Concatenate once instead of growing a dataframe inside the loop
    if not selected:
        return pd.DataFrame()
    return pd.concat(selected)


def dj2pandas(behav):
    """
    Add the derived columns used for the behavioral analyses to a trial dataframe.

    Parameters
    ----------
    behav:    pandas dataframe with one row per trial and the columns trial_id,
              trial_stim_contrast_left, trial_stim_contrast_right,
              trial_response_choice, trial_stim_prob_left and trial_feedback_type;
              trial_response_time and trial_stim_on_time are optional.
    Returns
    -------
    behav:    the same dataframe (it is modified in place), with
              trial_stim_prob_left renamed to probabilityLeft (in percent) and the
              added columns signed_contrast, trial, choice, correct, choice_right,
              choice2, correct_easy, rt (if response times are present),
              previous_choice, previous_outcome, previous_contrast,
              previous_choice_name, previous_outcome_name and repeat.
    """

    # make sure all contrasts are positive
    behav['trial_stim_contrast_right'] = np.abs(
        behav['trial_stim_contrast_right'])
    behav['trial_stim_contrast_left'] = np.abs(
        behav['trial_stim_contrast_left'])

    # positive = stimulus on the right, in percent
    behav['signed_contrast'] = (
        behav['trial_stim_contrast_right'] - behav['trial_stim_contrast_left']) * 100

    behav['trial'] = behav.trial_id  # for psychfuncfit
    val_map = {'CCW': 1, 'No Go': 0, 'CW': -1}
    behav['choice'] = behav['trial_response_choice'].map(val_map)
    # Built as float so that NaN can be stored below without upcasting an int column
    behav['correct'] = np.where(
        np.sign(behav['signed_contrast']) == behav['choice'], 1.0, 0.0)
    # correctness is undefined when there is no stimulus
    behav.loc[behav['signed_contrast'] == 0, 'correct'] = np.nan

    behav['choice_right'] = behav.choice.replace(
        [-1, 0, 1], [0, np.nan, 1])  # code as 0, 100 for percentages
    behav['choice2'] = behav.choice_right  # for psychfuncfit
    # easy trials are those with an absolute contrast of at least 50
    behav['correct_easy'] = behav.correct
    behav.loc[np.abs(behav['signed_contrast']) < 50, 'correct_easy'] = np.nan
    behav.rename(
        columns={'trial_stim_prob_left': 'probabilityLeft'}, inplace=True)
    behav['probabilityLeft'] = behav['probabilityLeft'] * 100
    behav['probabilityLeft'] = behav.probabilityLeft.astype(int)

    # compute rt
    if 'trial_response_time' in behav.columns:
        no_response = behav.choice == 0
        # don't count RT if there was no response
        behav['rt'] = (behav['trial_response_time'] -
                       behav['trial_stim_on_time']).mask(no_response)
        # don't count the feedback if there was no response
        # (mask() upcasts when needed instead of writing NaN into an integer column)
        behav['trial_feedback_type'] = behav['trial_feedback_type'].mask(no_response)

    # CODE FOR HISTORY
    # NOTE(review): shift(1) runs over the whole frame, so the first trial of a
    # session inherits the last trial of the preceding rows — confirm this is intended.
    behav['previous_choice'] = behav.choice.shift(1)
    behav.loc[behav.previous_choice == 0, 'previous_choice'] = np.nan
    behav['previous_outcome'] = behav.trial_feedback_type.shift(1)
    behav.loc[behav.previous_outcome == 0, 'previous_outcome'] = np.nan
    behav['previous_contrast'] = np.abs(behav.signed_contrast.shift(1))
    behav['previous_choice_name'] = behav['previous_choice'].map(
        {-1: 'left', 1: 'right'})
    behav['previous_outcome_name'] = behav['previous_outcome'].map(
        {-1: 'post_error', 1: 'post_correct'})
    behav['repeat'] = (behav.choice == behav.previous_choice)

    return behav


def num_star(pvalue):
    """
    Return the significance label for a p-value: stars plus the threshold that was
    passed ('* p < 0.05' up to '**** p < 0.0001'), or an empty string when the
    p-value is not below 0.05.
    """
    # Ordered from the strictest threshold to the most lenient one
    thresholds = (
        (0.0001, '**** p < 0.0001'),
        (0.001, '*** p < 0.001'),
        (0.01, '** p < 0.01'),
        (0.05, '* p < 0.05'),
    )
    for cutoff, label in thresholds:
        if pvalue < cutoff:
            return label
    return ''

## Import

In [16]:
import numpy as np
import datetime
import seaborn as sns 
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

# for basic logistic regression
import patsy # to build design matrix
import statsmodels.api as sm
from sklearn.model_selection import KFold

# for progress bar
from tqdm.auto import tqdm
tqdm.pandas(desc="model fitting")

# for GLM-HMM

!git clone git@github.com:slinderman/ssm.git
cd ssm
!pip install numpy cython
!pip install -e.
import ssm
from ssm.util import find_permutation
from ssm.plots import gradient_cmap, white_to_color_cmap

SyntaxError: ignored

In [None]:
# Mount Google Drive (Colab only) and make it the working directory
from google.colab import drive
drive.mount('/content/drive')
import os
%cd '/content/drive/My Drive' 
# NOTE(review): mkdir reports an error when GLM-HMM already exists (i.e. on every re-run)
!mkdir GLM-HMM
# NOTE(review): this enters 'researchHub', not the 'GLM-HMM' folder created above, and
# fails if 'researchHub' does not exist in the Drive — confirm which folder is intended
os.chdir('./researchHub')

# Load data

 Adapted from [figure5_GLM_modelfit.py](https://github.com/int-brain-lab/paper-behavior/blob/master/figure5_GLM_modelfit.py)  by IBL et al., 2021

In [6]:
# institution_map() returns (dict, list) and the dict is bound to the same name as the
# function because the code below uses `institution_map` as a dict. Only call it while
# the name still refers to the function, so that re-running this cell does not fail
# with "'dict' object is not callable".
if callable(institution_map):
    institution_map, col_names = institution_map()

# select sessions: from 2 training days before until 3 days after the 'biased' criterion
use_sessions, _ = query_sessions_around_criterion(criterion='biased',
                                                  days_from_criterion=[2, 3],
                                                  as_dataframe=False,
                                                  force_cutoff=True)

trial_fields = ('trial_stim_contrast_left', 'trial_stim_contrast_right',
                'trial_response_time', 'trial_stim_prob_left',
                'trial_feedback_type', 'trial_stim_on_time', 'trial_response_choice')

# query trial data for sessions and subject name and lab info
trials = use_sessions.proj('task_protocol') * behavior.TrialSet.Trial.proj(*trial_fields)
subject_info = subject.Subject.proj('subject_nickname') * \
                (subject.SubjectLab * reference.Lab).proj('institution_short')

# fetch, join and sort data as a pandas DataFrame
behav = dj2pandas(trials.fetch(format='frame')
                  .join(subject_info.fetch(format='frame'))
                  .sort_values(by=['institution_short', 'subject_nickname',
                                   'session_start_time', 'trial_id'])
                  .reset_index())
behav['institution_code'] = behav.institution_short.map(institution_map)
# split the two types of task protocols (remove the pybpod version number):
# names look like '_iblrig_tasks_biasedChoiceWorld4.1.0', so characters 14-19 are
# the six letters after the '_iblrig_tasks_' prefix ('biased' or 'traini')
behav['task'] = behav['task_protocol'].str[14:20].copy()

# RECODE SOME THINGS JUST FOR PATSY
behav['contrast'] = np.abs(behav.signed_contrast)
behav['stimulus_side'] = np.sign(behav.signed_contrast)
# probabilityLeft is in percent: 80 -> -1, 50 -> 0, 20 -> 1
behav['block_id'] = behav['probabilityLeft'].map({80:-1, 50:0, 20:1})

In [None]:
behav

Unnamed: 0,subject_uuid,trial_id,session_start_time,lab_name,task_protocol,trial_response_time,trial_response_choice,trial_stim_on_time,trial_stim_contrast_left,trial_stim_contrast_right,trial_feedback_type,probabilityLeft,subject_nickname,institution_short,signed_contrast,trial,choice,correct,choice_right,choice2,correct_easy,rt,previous_choice,previous_outcome,previous_contrast,previous_choice_name,previous_outcome_name,repeat,institution_code,task,contrast,stimulus_side,block_id
0,bc381af7-631d-4ed0-95f9-8231c830158a,1,2019-05-03 10:00:26,danlab,_iblrig_tasks_biasedChoiceWorld4.1.0,1.731200,CCW,0.634500,0.0000,1.0000,1,50,DY_001,Berkeley,100.00,1,1,1.0,1.0,1.0,1.0,1.0967,,,,,,False,Lab 7,biased,100.00,1.0,0.0
1,bc381af7-631d-4ed0-95f9-8231c830158a,2,2019-05-03 10:00:26,danlab,_iblrig_tasks_biasedChoiceWorld4.1.0,5.972800,CW,5.767800,0.2500,0.0000,1,50,DY_001,Berkeley,-25.00,2,-1,1.0,0.0,0.0,,0.2050,1.0,1,100.00,right,post_correct,False,Lab 7,biased,25.00,-1.0,0.0
2,bc381af7-631d-4ed0-95f9-8231c830158a,3,2019-05-03 10:00:26,danlab,_iblrig_tasks_biasedChoiceWorld4.1.0,24.150200,CW,8.019400,0.0000,0.0625,-1,50,DY_001,Berkeley,6.25,3,-1,0.0,0.0,0.0,,16.1308,-1.0,1,25.00,left,post_correct,True,Lab 7,biased,6.25,1.0,0.0
3,bc381af7-631d-4ed0-95f9-8231c830158a,4,2019-05-03 10:00:26,danlab,_iblrig_tasks_biasedChoiceWorld4.1.0,27.829800,CCW,27.468500,0.0000,1.0000,1,50,DY_001,Berkeley,100.00,4,1,1.0,1.0,1.0,1.0,0.3613,-1.0,-1,6.25,left,post_error,False,Lab 7,biased,100.00,1.0,0.0
4,bc381af7-631d-4ed0-95f9-8231c830158a,5,2019-05-03 10:00:26,danlab,_iblrig_tasks_biasedChoiceWorld4.1.0,33.860200,CCW,29.951600,0.0000,1.0000,1,50,DY_001,Berkeley,100.00,5,1,1.0,1.0,1.0,1.0,3.9086,1.0,1,100.00,right,post_correct,True,Lab 7,biased,100.00,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
493340,5ba2da46-8213-4dd1-ac4e-3dc1eafd7141,959,2019-10-31 10:48:08,cortexlab,_iblrig_tasks_biasedChoiceWorld6.0.5,4061.304398,CW,4060.899298,1.0000,0.0000,1,80,KS025,UCL,-100.00,959,-1,1.0,0.0,0.0,1.0,0.4051,-1.0,1,12.50,left,post_correct,True,Lab 1,biased,100.00,-1.0,-1.0
493341,5ba2da46-8213-4dd1-ac4e-3dc1eafd7141,960,2019-10-31 10:48:08,cortexlab,_iblrig_tasks_biasedChoiceWorld6.0.5,4088.302899,CCW,4065.285699,0.0000,0.1250,1,80,KS025,UCL,12.50,960,1,1.0,1.0,1.0,,23.0172,-1.0,1,100.00,left,post_correct,False,Lab 1,biased,12.50,1.0,-1.0
493342,5ba2da46-8213-4dd1-ac4e-3dc1eafd7141,961,2019-10-31 10:48:08,cortexlab,_iblrig_tasks_biasedChoiceWorld6.0.5,4096.576298,CW,4093.115298,0.0625,0.0000,1,80,KS025,UCL,-6.25,961,-1,1.0,0.0,0.0,,3.4610,1.0,1,12.50,right,post_correct,False,Lab 1,biased,6.25,-1.0,-1.0
493343,5ba2da46-8213-4dd1-ac4e-3dc1eafd7141,962,2019-10-31 10:48:08,cortexlab,_iblrig_tasks_biasedChoiceWorld6.0.5,4105.841298,CW,4099.868098,0.1250,0.0000,1,80,KS025,UCL,-12.50,962,-1,1.0,0.0,0.0,,5.9732,-1.0,1,6.25,left,post_correct,True,Lab 1,biased,12.50,-1.0,-1.0


# Basic logistic regression (Q1)

 Adapted from [figure5_GLM_modelfit.py](https://github.com/int-brain-lab/paper-behavior/blob/master/figure5_GLM_modelfit.py) and [figure5_GLM_plot.py](https://github.com/int-brain-lab/paper-behavior/blob/master/figure5_GLM_plot.py) by IBL et al., 2021

## Define the model

In [8]:
def fit_q1(behav, prior_blocks=False, folds=5, random_state=None):
    """
    Fit a logistic regression of choice on signed contrast (and optionally block)
    to the trials of one mouse, and estimate its predictive accuracy by
    cross-validation.

    Parameters
    ----------
    behav:          dataframe of trials with the columns signed_contrast and choice
                    (coded -1 / 0 / 1), plus block_id when prior_blocks is True
    prior_blocks:   boolean, if True block_id is added as a regressor (default is False)
    folds:          number of cross-validation folds (default is 5)
    random_state:   seed passed to KFold to make the shuffled folds, and therefore the
                    reported accuracy, reproducible (default is None: not seeded)
    Returns
    -------
    params:         one-row dataframe with the fitted weights (bias, contrast and, if
                    requested, block_id), pseudo_rsq, condition_number of the design
                    matrix and the cross-validated accuracy
    """

    # drop trials with contrast-level 50, only rarely present (should not be its own regressor)
    behav = behav[np.abs(behav.signed_contrast) != 50]
    # drop trials with No-go choice
    behav = behav[behav.choice != 0]

    # use patsy to easily build design matrix
    if not prior_blocks:
        formula = 'choice ~ 1 + signed_contrast'
        required = ['choice']
    else:
        formula = 'choice ~ 1 + signed_contrast + block_id'
        required = ['choice', 'block_id']
    endog, exog = patsy.dmatrices(formula,
                                  data=behav.dropna(subset=required).reset_index(),
                                  return_type='dataframe')

    # recode choices for logistic regression: -1 -> 0, 1 stays 1
    endog['choice'] = endog['choice'].map({-1:0, 1:1})

    # rename columns
    exog = exog.rename(columns={'Intercept': 'bias', 'signed_contrast': 'contrast'})

    # NOW FIT THIS WITH STATSMODELS
    logit_model = sm.Logit(endog, exog)
    res = logit_model.fit_regularized(disp=False) # run silently

    # what do we want to keep?
    params = pd.DataFrame(res.params).T
    params['pseudo_rsq'] = res.prsquared # https://www.statsmodels.org/stable/generated/statsmodels.discrete.discrete_model.LogitResults.prsquared.html?highlight=pseudo
    params['condition_number'] = np.linalg.cond(exog)

    # ===================================== #
    # ADD MODEL ACCURACY - cross-validate

    kf = KFold(n_splits=folds, shuffle=True, random_state=random_state)
    fold_accuracy = []
    for train, test in kf.split(endog):
        # KFold yields positions, so index with iloc: labels and positions differ
        # whenever patsy dropped rows with missing values
        X_train, X_test = exog.iloc[train], exog.iloc[test]
        y_train = endog.iloc[train]
        y_test = endog.iloc[test].copy()  # copy: a column is added below

        # fit again
        fold_res = sm.Logit(y_train, X_train).fit_regularized(disp=False)  # run silently

        # compute the accuracy on held-out data [from Luigi]:
        # the model predicts p = Pr(choice == 1); the % match is p if the actual
        # choice is 1, or 1-p if the actual choice is 0
        # if you were to simulate it, in the end you would get these numbers
        y_test['pred'] = fold_res.predict(X_test)
        y_test.loc[y_test['choice'] == 0, 'pred'] = 1 - y_test.loc[y_test['choice'] == 0, 'pred']
        fold_accuracy.append(y_test['pred'].mean())

    # average prediction accuracy over the K folds
    params['accuracy'] = np.mean(fold_accuracy)

    return params  # wide df

## Fit the model

In [9]:
def fit_per_mouse(task_code, with_blocks):
    """
    Fit fit_q1 separately to every mouse (grouped by lab and nickname) using only
    the trials whose `task` column equals task_code; returns one row per mouse.
    """
    task_trials = behav.loc[behav.task == task_code, :]
    per_mouse = task_trials.groupby(['institution_code', 'subject_nickname'])
    return per_mouse.progress_apply(fit_q1, prior_blocks=with_blocks).reset_index()


# 'traini' / 'biased' are the six-letter task codes stored in behav.task
print('fitting GLM to BASIC task...')
params_basic = fit_per_mouse('traini', False)
print('The mean condition number for the basic model is', params_basic['condition_number'].mean())

print('fitting GLM to FULL task...')
params_full = fit_per_mouse('biased', True)
print('The mean condition number for the full model is', params_full['condition_number'].mean())

fitting GLM to BASIC task...


HBox(children=(FloatProgress(value=0.0, description='model fitting', max=83.0, style=ProgressStyle(description…


The mean condition number for the basic model is 48.69541179343498
fitting GLM to FULL task...


HBox(children=(FloatProgress(value=0.0, description='model fitting', max=102.0, style=ProgressStyle(descriptio…


The mean condition number for the full model is 54.35945226287389


In [10]:
params_basic

Unnamed: 0,institution_code,subject_nickname,level_2,bias,contrast,pseudo_rsq,condition_number,accuracy
0,Lab 1,KS003,0,-0.768511,0.107334,0.443390,48.272847,0.765406
1,Lab 1,KS005,0,-0.427894,0.041561,0.235938,47.439400,0.649423
2,Lab 1,KS014,0,-0.009918,0.051446,0.280886,48.987346,0.670277
3,Lab 1,KS015,0,0.054683,0.048861,0.279347,49.698866,0.669409
4,Lab 1,KS016,0,0.379327,0.062838,0.327613,48.968198,0.697861
...,...,...,...,...,...,...,...,...
78,Lab 7,DY_009,0,-0.005298,0.123430,0.453440,47.429029,0.752729
79,Lab 7,DY_010,0,-0.638529,0.146620,0.517061,49.275500,0.808086
80,Lab 7,DY_011,0,-0.115479,0.130279,0.487495,47.948316,0.775089
81,Lab 7,DY_013,0,-0.247382,0.121104,0.474473,50.371781,0.775050


In [11]:
params_full

Unnamed: 0,institution_code,subject_nickname,level_2,bias,contrast,block_id,pseudo_rsq,condition_number,accuracy
0,Lab 1,KS002,0,0.549494,0.042048,0.768044,0.354734,56.948718,0.713532
1,Lab 1,KS003,0,-0.704172,0.031142,0.912381,0.336067,54.347196,0.710864
2,Lab 1,KS004,0,-0.062796,0.043136,0.757547,0.324462,54.620181,0.697368
3,Lab 1,KS005,0,-0.395967,0.072999,0.873631,0.473665,57.741461,0.774321
4,Lab 1,KS014,0,0.128963,0.058329,0.649794,0.369183,56.234239,0.718102
...,...,...,...,...,...,...,...,...,...
97,Lab 7,DY_009,0,0.097167,0.106754,0.555988,0.460132,53.987795,0.763789
98,Lab 7,DY_010,0,-0.763517,0.254063,0.421754,0.641836,53.724658,0.857144
99,Lab 7,DY_011,0,-0.024697,0.125729,0.519073,0.485328,54.968647,0.779207
100,Lab 7,DY_013,0,-0.808344,0.148816,0.539884,0.540344,54.783802,0.812022


In [25]:
params_basic.describe()

Unnamed: 0,level_2,bias,contrast,pseudo_rsq,condition_number,accuracy
count,83.0,83.0,83.0,83.0,83.0,83.0
mean,0.0,-0.126812,0.080356,0.358837,48.695412,0.719499
std,0.0,0.516077,0.037193,0.10744,2.389612,0.05371
min,0.0,-1.643345,0.004958,0.011409,44.70112,0.552813
25%,0.0,-0.531784,0.057492,0.303822,47.672252,0.693367
50%,0.0,-0.10603,0.075439,0.351486,48.540863,0.71704
75%,0.0,0.244536,0.096427,0.426429,49.187288,0.754355
max,0.0,1.449161,0.229594,0.648881,65.286924,0.848665


In [26]:
params_full.describe()

Unnamed: 0,level_2,bias,contrast,block_id,pseudo_rsq,condition_number,accuracy
count,102.0,102.0,102.0,102.0,102.0,102.0,102.0
mean,0.0,-0.066749,0.082014,0.644681,0.416885,54.359452,0.746466
std,0.0,0.444307,0.043789,0.177244,0.091936,1.959225,0.046309
min,0.0,-0.85865,0.020652,0.227658,0.135334,50.309617,0.595803
25%,0.0,-0.389735,0.051032,0.540521,0.350179,52.946381,0.712581
50%,0.0,-0.091944,0.073482,0.653551,0.412069,54.0774,0.743569
75%,0.0,0.148992,0.097708,0.760657,0.472337,55.73286,0.778473
max,0.0,1.146744,0.254063,0.995147,0.641836,60.538145,0.859184


## Plot

# GLM-HMM