# Wald-RL Model Fitting

## Section 1: Motivating the model
### Reaction time effects

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import Categorical, DataFrame, read_csv
sns.set_style('white')
sns.set_context('notebook', font_scale=1.5)
%matplotlib inline

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load behavior data and keep blocks 1-3. The explicit copy makes the
## result an independent DataFrame, so the column assignments below do
## not write into a slice of the loaded table (SettingWithCopyWarning).
data = read_csv('data/moodRL_data.csv')
data = data[data.Block < 4].copy()

## Define trial types (e.g. 20% vs. 40%). Machine ids are reduced modulo 3,
## with a remainder of 0 recoded as 3, so every machine maps onto 1, 2 or 3.
## The sum over the presented pair (3, 4 or 5) then identifies the pairing.
machines = data[['M1','M2']] % 3
cond = np.where(machines, machines, 3).sum(axis=1)
cond = Categorical(cond, categories=[3,5,4], ordered=True)
cond = cond.rename_categories(['20% vs. 40%', '40% vs. 60%', '20% vs. 60%'])
data['Cond'] = cond

## Bin trials into consecutive runs of 14 trials (bin 0 = trials 1-14, etc.).
data['Bins'] = (data.Trial - 1) // 14

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize canvas: RT timeseries on the top row, binned RTs on the
## bottom row, one column per block.
fig, axes = plt.subplots(2, 3, figsize=(12,6))

## Plot.
for i in np.arange(3):

    ## Restrict to the current block (blocks are numbered from 1).
    block = data[data.Block==i+1]

    ## Plot RT timeseries. Variables are passed by keyword because recent
    ## seaborn releases no longer accept x / y / hue positionally.
    sns.lineplot(x='Trial', y='RT', data=block, color='k', zorder=0, ax=axes[0,i])
    axes[0,i].set(xlim=(1,42), xticks=[1,10,20,30,40], ylim=(0.8,1.6),
                  yticks=[], ylabel='', title='Block %s' %(i+1))
    ## Only the left-most panel carries y-ticks and a y-label.
    if not i: axes[0,i].set(yticks=np.arange(0.8,1.7,0.2), ylabel='Reaction Time')

    ## Plot RT bins, one line per trial type.
    sns.pointplot(x='Bins', y='RT', hue='Cond', data=block, ax=axes[1,i])
    axes[1,i].set(xticklabels=['1-14','15-28','29-42'], xlabel='Trials', ylim=(0.8,1.4),
                  yticks=[], ylabel='')
    axes[1,i].legend_.set_visible(False)
    if not i: axes[1,i].set(yticks=np.arange(0.8,1.5,0.2), ylabel='Reaction Time')

## The per-panel legends were hidden above; draw a single legend to the
## right of the last bottom panel instead.
axes[1,2].legend(loc=7,bbox_to_anchor=(1.5,0.5),handletextpad=0)

sns.despine()
plt.tight_layout()

### Sources of speed-up

In [None]:
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Set True to exclude the trials listed in query_trials from the
## percentile computation. Does not make a major difference.
mask = False
query_trials = [8,15,22,29,36]

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Prepare percentile data (10%, 90%).
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

if mask: perc = data[~data.Trial.isin(query_trials)].copy()
else: perc = data.copy()

## Compute the 10th and 90th RT percentiles within every Block x Datetime
## cell, ignoring missing RTs. Each cell initially holds a 2-element array,
## which is then unpacked into one column per percentile.
perc = perc.groupby(['Block','Datetime']).RT.apply(lambda arr: np.nanpercentile(arr, [10,90])).reset_index()
perc['10%'] = perc.RT.apply(lambda arr: arr[0])
perc['90%'] = perc.RT.apply(lambda arr: arr[1])

## Melt and reshape DataFrame for plotting. The array column is dropped by
## keyword because pandas 2.0 removed the positional `axis` argument.
perc = perc.drop(columns='RT').melt(id_vars=['Block','Datetime'], var_name='Percentile', value_name='RT')

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plot.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize canvas.
fig, axes = plt.subplots(1,2,figsize=(12,4))

## Plot KDE of the observed RTs, one curve per block.
for i in np.arange(3):
    sns.kdeplot(data.loc[data.Block==i+1,'RT'].dropna(), lw=3, cut=1,
                label='Block %s' %(i+1), ax=axes[0])
axes[0].set(xlabel='Reaction Time (s)', ylabel='Density')

## Plot stripplot. Variables are passed by keyword because recent seaborn
## releases no longer accept x / y / hue positionally.
sns.stripplot(x='Percentile', y='RT', hue='Block', data=perc, dodge=True, jitter=True, alpha=0.7, ax=axes[1])

## Overlay the mean of every Percentile x Block cell as a diamond. The
## groupby result is ordered by percentile, then block, so the x-positions
## are built in the same order: category centre (0 or 1) plus a symmetric
## per-block offset matching the dodged strips.
offsets = [-0.27, 0, 0.27]
centers = [c + o for c in (0, 1) for o in offsets]
axes[1].scatter(centers, perc.groupby(['Percentile','Block']).RT.mean().values,
                s=150, marker='d', color='k', zorder=100)
axes[1].set(xlabel='RT Percentile', ylabel='Reaction Time (s)')
axes[1].legend_.set_visible(False)

sns.despine()
plt.tight_layout()

In [None]:
from spm1d.stats import anova1rm

## Test for an effect of block on each RT percentile separately.
for p in perc.Percentile.unique():

    ## Prepare data. Rows are selected with a boolean mask so the selection
    ## does not depend on `perc` carrying a default 0..n-1 index.
    subset = perc[perc.Percentile==p]
    Y = subset.RT.values

    ## Recode block and participant (Datetime) labels as 0-based integer codes.
    _, A = np.unique(subset.Block, return_inverse=True)
    _, SUBJ = np.unique(subset.Datetime, return_inverse=True)
    
    ## Fit 1-way repeated measures ANOVA.
    fit = anova1rm(Y, A, SUBJ).inference(alpha=0.05)
    print('%s RT: F = %0.3f, p = %0.3f' %(p, fit.z, fit.p))

### Task-switching

In [None]:
query_trials = np.array([8,15,22,29,36])

## Reshape RT data to (N, B, T).
## NOTE(review): the reshape assumes rows are ordered by participant, block
## and trial, and that every participant has all B x T rows -- confirm.
N = data.Datetime.unique().size
B = data.Block.max()
T = data.Trial.max()
RT = data.RT.values.reshape(N,B,T)

## Initialize canvas.
fig, axes = plt.subplots(1,2,figsize=(12,4),sharex=True,sharey=True)

for i, ax, color, title in zip(range(2), axes, ['#774576','k'], ['post-query trials', 'regular trials']):

    ## The first panel uses the query trials themselves; the second uses the
    ## trials three positions later as a comparison. The offset is applied to
    ## a new array so that query_trials is left unmodified.
    trials = query_trials + i * 3
    
    ## Preallocate space.
    penalty = np.zeros((N,B,len(trials)))

    ## Iteratively compute RT penalty: RT on the (1-based) trial minus the
    ## mean RT of the trials immediately before and after it.
    for j, qt in enumerate(trials):
        penalty[...,j] = RT[...,qt-1] - RT[...,[qt-2,qt]].mean(axis=-1)

    ## Convert to DataFrame: one column per block/trial combination
    ## (block-major), melted to long format for plotting.
    penalty = DataFrame(penalty.reshape(N,B*len(trials)))
    penalty = penalty.melt(var_name='Query', value_name='RT')
    
    ## Plot. Variables are passed by keyword because recent seaborn
    ## releases no longer accept x / y positionally.
    sns.pointplot(x='Query', y='RT', data=penalty, join=False, color=color, ax=ax)
    ax.hlines(0,0,15,zorder=0,lw=0.75)
    ## Shade the middle group of five columns (block 2).
    ax.fill_between([4.5,9.5],-0.25,0.55,color='k',alpha=0.05)
    ax.set(xticks=np.arange(2,15,5), xticklabels=['Block 1', 'Block 2', 'Block 3'],
           xlabel='', ylim=(-0.25,0.55), ylabel='', title=title)
    if not i: ax.set_ylabel(r'$\Delta$ RT')
        
sns.despine()
plt.tight_layout()

### Likelihood of drift rates

In [None]:
from scripts.simulations import slot_machine_game, softmax

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Define simulation parameters.
n_sim = 500
n_trials = 42

## Define agent parameters: beta scales the Q-values passed to the softmax,
## eta_v is the step size of the Q-value update.
beta = 9
eta_v = 0.09

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
np.random.seed(47404)

## Preallocate space.
drift = np.zeros((n_sim, n_trials))

for i in np.arange(n_sim):
    
    ## Simulate block.
    ## NOTE(review): x[j] is used as the pair of (0-based) machine indices
    ## shown on trial j and r[j] as the matching outcomes -- confirm against
    ## slot_machine_game.
    x, r = slot_machine_game(n_trials=n_trials, reward=1)

    ## Initialize Q-values.
    Q = np.zeros(3)
    
    ## Simulate behavior.
    for j in np.arange(n_trials):
        
        ## Store drift: absolute difference between the Q-values of the two
        ## presented machines. Computed as a scalar, because assigning a
        ## size-1 array to a single element is deprecated in recent NumPy.
        drift[i,j] = np.abs(Q[x[j,1]] - Q[x[j,0]])
        
        ## Simulate choice.
        p = softmax(beta * Q[x[j]])
        y = np.argmax(np.random.multinomial(1, p))
                                
        ## Compute reward prediction error.
        delta = r[j,y] - Q[x[j,y]]

        ## Update expectations.
        Q[x[j,y]] += eta_v * delta

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plotting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize canvas.
fig, ax = plt.subplots(1,1,figsize=(8,4))

## Plot cumulative histogram of drift over all simulations and trials.
ax.hist(drift.flatten(), cumulative=True)
ax.set(xlim=(-0.01,0.8), xlabel='drift = abs($Q_1 - Q_2$)', ylabel='Count')

sns.despine()
plt.tight_layout()

### Priors on shifted Wald distribution

In [None]:
from scripts.wald import shifted_wald_pdf

## Parameter grid: one panel per alpha, one curve per gamma, theta fixed.
gamma = np.arange(0,0.8,0.1)
alpha = [0.4, 0.6, 0.8]
theta = 0.5

## Canvas and one colour per gamma value.
fig, axes = plt.subplots(1,3,figsize=(12,4))
palette = sns.color_palette('GnBu_d', n_colors=gamma.size)

## Evaluation grid and colour order (palette is traversed back to front).
x = np.linspace(0,3,1000)
curve_colors = palette[::-1]

for k, a in enumerate(alpha):

    ax = axes[k]

    ## Draw one PDF per gamma value.
    for n, g in enumerate(gamma):
        density = shifted_wald_pdf(x, g, a, theta)
        ax.plot(x, density, label=r'$\gamma = %0.1f$' %g,
                lw=2, color=curve_colors[n], alpha=0.8)

    ## Label the panel.
    ax.set(xlabel='RT', ylabel='PDF', title=r'$\alpha = %s$' %a)

## The legend goes on the last panel only.
ax.legend(loc=1, borderpad=0, labelspacing=0, fontsize=12)

sns.despine()
plt.tight_layout()

In [None]:
## Evaluation grid for the PDFs.
x = np.linspace(0,3,1000)

## Scaling applied to each drift value before it is passed to the PDF.
beta = 8

## One curve per drift value, labelled by the unscaled drift.
drifts = np.arange(0,0.6,0.1)
for d in drifts:
    density = shifted_wald_pdf(x, beta*d, 1, 0.5)
    plt.plot(x, density, label='%0.1f' %d)

plt.legend()

## Section 2: Model Fitting

In [None]:
import os, pystan
import _pickle as cPickle
from pandas import DataFrame, read_csv
from scripts.diagnostics import *
from scripts.utilities import normalize
from scripts.plotting import plot_subject , plot_subject_rt
from scripts.wald import init_shifted_wald, wald_generate_quantities
%load_ext jupyternotify

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Stan program to fit (file name inside stan_models/).
model_name = 'moodRL_wald_full_mood_glob.stan'

## Sampler settings. `samples` is handed to pystan as `iter` and `warmup`
## as `warmup`; chains, thinning and parallel jobs are passed through as-is.
samples, warmup = 1250, 1000
chains, thin = 4, 1
n_jobs = 4

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Load and prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load and prepare behavior data. Missing values are recoded as -1 so
## that missing trials can be identified by the sign of RT below.
data = read_csv('data/moodRL_data.csv')
data = data[data.Block < 4]
data = data.fillna(-1)

## Load and prepare ratings data.
ratings = read_csv('data/moodRL_ratings.csv')
ratings = ratings[ratings.Variable=='Mood']

## Load and prepare metadata.
metadata = read_csv('data/moodRL_metadata.csv')

## Trial masks derived from RT, computed once and reused below.
rt_missing = data.RT.values < 0
rt_valid = data.RT.values > 0

## Extract and prepare stimulus presentation data.
## Stimulus presentation is sorted such that the
## more valuable machine occupies the right column.
X = data[['M1','M2']].values
X = np.sort(X, axis=-1) 

## Extract and prepare choice data. Choice data 
## recoded to [0, 1], where 0 = less valuable (left 
## column of X), 1 = more valuable (right column of X). 
Y = data.Choice.values
Y = np.equal(X[:,-1], Y).astype(int)

## Extract and prepare reward data. Rewards are 
## binarized: 1 for positive outcomes, 0 otherwise.
R = data.Outcome.values
R = np.where(R > 0, 1, 0)

## Extract and prepare RT data.
Z = data.RT.values

## Extract and prepare mood data. Ratings are divided by 4.
## NOTE(review): presumably this maps a [-4, 4] rating scale 
## onto [-1, 1] -- confirm against the ratings file.
M = ratings.loc[ratings.Trial>0, 'Rating'].values / 4

## Rating recorded at Trial 0 of block 2, with values of exactly
## -1 / 1 moved to -0.99 / 0.99.
m2 = ratings.loc[np.logical_and(ratings.Block==2, ratings.Trial==0),'Rating'].values / 4
m2 = np.where(m2==-1, -0.99, np.where(m2==1, 0.99, m2))

## Define subject index (1-based).
_, subj_ix = np.unique(data.Subject, return_inverse=True)
subj_ix += 1

## Define mood index. Trials 7, 21 and 35 are flagged; where a flagged
## trial has a missing RT (and is therefore removed below), the flag is
## moved to the preceding trial. Flags are then numbered 1..K in order.
## NOTE(review): if the preceding trial is missing as well, the flag is
## lost when missing trials are removed -- confirm this cannot occur.
mood_ix = np.isin(data.Trial, [7,21,35]).astype(int)
shift, = np.where(np.logical_and(mood_ix, rt_missing))
mood_ix[shift] = 0
mood_ix[shift-1] = 1
mood_ix[np.where(mood_ix)] = np.arange(mood_ix.sum()) + 1

## Define block index: marks the first trial of block 1 with 1 and the
## first trial of block 2 with 2; all other trials are 0.
block_ix = np.zeros_like(mood_ix)
block_ix[np.logical_and(data.Block==1, data.Trial==1)] = 1
block_ix[np.logical_and(data.Block==2, data.Trial==1)] = 2

## Remove trials with missing data.
X = X[rt_valid]
Y = Y[rt_valid]
R = R[rt_valid]
Z = Z[rt_valid]
subj_ix = subj_ix[rt_valid]
mood_ix = mood_ix[rt_valid]
block_ix = block_ix[rt_valid]

## Define metadata: number of subjects and number of retained trials.
N = data.Subject.max()
T = Y.size

## Organize data dictionary.
dd = dict(N=N, T=T, subj_ix=subj_ix, mood_ix=mood_ix, block_ix=block_ix,
          X=X, Y=Y, R=R, Z=Z, M=M, m2=m2)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Fit model with Stan.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
np.random.seed(47404)

## Fitting is skipped entirely when model_name is empty.
if model_name:

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Model fitting and diagnostics.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
        
    ## Define initial parameters. The second argument differs per model
    ## family. An unrecognised model name raises immediately rather than
    ## failing later on an undefined (or stale) `init`.
    if model_name == 'moodRL_wald_no_mood.stan':
        init = init_shifted_wald(N, 4, chains)
    elif model_name == 'moodRL_wald_no_mood_bias.stan':
        init = init_shifted_wald(N, 5, chains)
    elif 'full_mood' in model_name:
        init = init_shifted_wald(N, 6, chains)
    else:
        raise ValueError('No initial values defined for model "%s".' %model_name)
        
    ## Fit model, then run the sampler diagnostics.
    stan_file = 'stan_models/%s' %model_name
    fit = pystan.stan(file=stan_file, data=dd, iter=samples, warmup=warmup, thin=thin, chains=chains, 
                      init=init, control=dict(adapt_delta = 0.9), n_jobs=n_jobs, seed=47404)
    check_div(fit)
    check_treedepth(fit)
    check_energy(fit)
    check_n_eff(fit)
    check_rhat(fit)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Generated quantities.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

    ## Extract parameters.
    extract = fit.extract()

    ## Store metadata: number of participants, blocks and trials per block.
    extract['N'] = data.Datetime.unique().size
    extract['B'] = data.Block.max()
    extract['T'] = data.Trial.max()
    shape = (extract['N'], extract['B'], extract['T'])

    ## Store trial information, reshaped to participant x block x trial.
    ## NOTE(review): the reshapes assume rows are ordered by participant,
    ## block and trial with no rows absent -- confirm.
    extract['X'] = np.sort(data[['M1','M2']].values.reshape(*shape, 2), axis=-1)

    ## Choice is recoded to the 1-based column of X that matches it, or -1
    ## when the choice matches neither presented machine (e.g. missing).
    match = np.equal(extract['X'], data.Choice.values.reshape(*shape, 1))
    extract['Y'] = np.where(np.any(match, axis=-1), np.argmax(match, axis=-1) + 1, -1)
    extract['R'] = np.where(data.Outcome.values.reshape(shape) > 0, 1, 0)
    extract['Z'] = data.RT.values.reshape(shape)
    extract['M'] = ratings.loc[ratings.Trial>0, 'Rating'].values.reshape(extract['N'],extract['B'],3) / 4
    extract['m2'] = ratings.loc[np.logical_and(ratings.Block==2, ratings.Trial==0),'Rating'].values / 4
    extract['m2'] = np.where(extract['m2']==-1, -0.99, np.where(extract['m2']==1, 0.99, extract['m2']))

    ## Compute generated quantities.
    extract = wald_generate_quantities(extract)
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Save data.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    
    ## Create out-directory.
    out_dir = 'stan_fits/%s' %model_name.replace('.stan','')
    os.makedirs(out_dir, exist_ok=True)

    ## Save summary file.
    summary = fit.summary()
    summary = DataFrame(summary['summary'], columns=summary['summary_colnames'], index=summary['summary_rownames'])
    summary.to_csv(os.path.join(out_dir, 'summary.csv'))

    ## Save contents of StanFit (the extracted samples plus the trial
    ## information and generated quantities added above).
    with open(os.path.join(out_dir, 'StanFit.pickle'), 'wb') as fn: 
        cPickle.dump(extract, fn)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Plot subjects.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

    ## Make plots dir.
    plots_dir = os.path.join(out_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    ## Iteratively plot. Figures are closed after saving so they do not
    ## accumulate in memory.
    for i in np.arange(extract['N']):

        ## Plot behavior and save.
        fig, axes = plot_subject(extract, i)
        fig.savefig(os.path.join(plots_dir, 'subj_%s.png' %(i+1)), dpi=180)
        plt.close('all')
        
        ## Plot RTs and save.
        fig, axes = plot_subject_rt(extract, i, ds=2)
        fig.savefig(os.path.join(plots_dir, 'subj_%s_rt.png' %(i+1)), dpi=180)
        plt.close('all')
        
    print('Done.')

In [None]:
## Display the fit object returned by pystan.stan in the model-fitting cell
## (as the last expression of the cell, the notebook renders its repr).
fit

## Section 3: Model Comparison

In [None]:
from scripts.plotting import plot_group_behavior

## Models to compare, their display labels, and the colour palette.
models = ['moodRL_centered_no_mood', 'moodRL_wald_no_mood']
labels = ['RL', 'Wald']
colors = sns.color_palette('GnBu_d', n_colors=4)

## Draw every model on a shared axis; `observed` is switched on for the
## first model only.
fig, ax = plt.subplots(1,1,figsize=(12,4))
for k, model in enumerate(models):
    show_observed = 1 if k == 0 else 0
    plot_group_behavior(model, observed=show_observed, color=colors[k], label=labels[k], ax=ax)
ax.legend(loc=4, borderpad=0, fontsize=14, labelspacing=0)

sns.despine()
plt.tight_layout()

In [None]:
from scripts.utilities import model_comparison

## One panel per comparison target, each with its own y-axis limits.
fig, axes = plt.subplots(1,3,figsize=(12,4))
xlabels = ['Choice', 'Mood', 'Both']
ylimits = np.array([(6220,6230), (0,20), (6200,6250)])
targets = ['y','m','both']

for ax, var, xlabel, ylim in zip(axes, targets, xlabels, ylimits):

    ## Compare each model with its predecessor in `models`.
    for i in range(1, len(models)):

        ## Perform model comparison.
        w1, w2, se = model_comparison(models[i-1], models[i], on=var, verbose=False)

        ## Draw the pair of bars side by side.
        ax.bar([i-1,i], [w1,w2], color=colors[i-1:i+1])

        ## Mark the pair with a line at the top of the panel when the
        ## difference exceeds the standard error.
        difference = np.abs(w1 - w2)
        if difference > se:
            ax.hlines(ylim[1], i-0.85, i-0.15)

    ## Draw a zero line when exactly one of the two limits is positive.
    lower_positive, upper_positive = ylim > 0
    if np.logical_xor(lower_positive, upper_positive):
        left, right = ax.get_xlim()
        ax.hlines(0, left, right, lw=0.5)
    ax.set(xticks=[], xlabel=xlabel, ylim=ylim, ylabel='WAIC')

## Add legend to the last panel via zero-height proxy bars.
for color, label in zip(colors, labels):
    ax.bar(0, 0, color=color, label=label)
ax.legend(loc=7, bbox_to_anchor=(1.75,0.5), fontsize=14, borderpad=0)

sns.despine()
plt.tight_layout()