# Choice of Priors

In [None]:
from scipy.stats import norm

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Prior parameters: location and spread of the normal prior, plus the
## multiplier applied to phi_approx for the inverse temperature panel.
mu, sd = 0, 1
scale = 50

## Sampling parameters: evenly spaced grid on which the prior is evaluated.
x = np.linspace(-5, 5, num=100)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define useful functions.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

def inv_logit(arr):
    """Apply the logistic (inverse-logit) function elementwise.

    Parameters
    ----------
    arr : scalar or ndarray
        Input value(s).

    Returns
    -------
    scalar or ndarray
        ``1 / (1 + exp(-arr))``, with the same shape as ``arr``.
    """
    denominator = 1 + np.exp(-arr)
    return 1 / denominator

def phi_approx(arr):
    """Squash ``arr`` into (0, 1) with a cubic-logistic transform.

    Computes ``inv_logit(0.07056 * arr**3 + 1.5976 * arr)``, the same form
    as Stan's ``Phi_approx`` function.

    Parameters
    ----------
    arr : scalar or ndarray
        Input value(s).

    Returns
    -------
    scalar or ndarray
        Transformed value(s), same shape as ``arr``.
    """
    cubic_term = 0.07056 * arr ** 3
    linear_term = 1.5976 * arr
    return inv_logit(cubic_term + linear_term)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Plot.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize canvas: a single row of three panels.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 4))

## Compute PDF of the normal prior on the grid x.
y = norm.pdf(x, mu, sd)

## Plot the transform itself.
axes[0].plot(x, phi_approx(x), linewidth=2.5)
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_yticks(np.linspace(0,1,3))
axes[0].set_title('phi_approx')

## Plot inverse temperature: the transformed grid is stretched by `scale`.
## NOTE(review): y is the normal density at x, drawn against the transformed
## value with no change-of-variables factor, so the curve is not the density
## of the transformed parameter itself -- confirm this is the intended figure.
axes[1].plot(scale * phi_approx(x), y, linewidth=2.5)
axes[1].set_xlabel(r'%s $\cdot$ phi_approx( $\mathcal{N}(%s, %s)$ )' %(scale,mu,sd))
axes[1].set_xticks(np.linspace(0,scale,5))
axes[1].set_ylabel('PDF')
axes[1].set_title(r'Inverse Temperature ($\beta$)')

## Plot learning rate: the transformed grid is used as-is.
axes[2].plot(phi_approx(x), y, linewidth=2.5)
axes[2].set_xlabel(r'phi_approx( $\mathcal{N}(%s, %s)$ )' %(mu,sd))
axes[2].set_ylabel('PDF')
axes[2].set_title(r'Learning Rate ($\eta$)')

## Tidy up the figure.
sns.despine()
plt.tight_layout()

### Baseline model: Standard RL (single subject)

In [None]:
import pystan
from scripts.plotting_utility import plot_toy_model

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Subject parameters.
## Each unique Datetime value is treated as one subject. The unique values
## are computed once here instead of on every loop iteration.
datetimes = data.Datetime.unique()
subjects = np.arange(datetimes.size)

## RL parameters.
## Passed to the Stan model as data entry `q`.
## NOTE(review): presumably the initial Q-value -- confirm against the model.
q = 0

## Sampling parameters.
model_name = 'moodRL_toy.stan'
samples = 1000
warmup = 750
chains = 4
thin = 1
n_jobs = 2
   
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Main loop.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    
## Locate stan model.
file = 'stan_models/%s' %model_name

## Compile the Stan model once, up front. pystan.stan(file=...) builds the
## model on every call, so calling it inside the loop would recompile the
## same program for each subject.
stan_model = pystan.StanModel(file=file)
    
for i, subject in enumerate(subjects):
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Prepare data.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

    ## Define subject.
    datetime = datetimes[subject]

    ## Reduce DataFrame to this subject's trials from blocks below 4.
    df = data[np.logical_and(data.Datetime==datetime, data.Block < 4)].copy()

    ## Drop trials with missing data.
    df = df[df.Choice.notnull()]

    ## Extract and prepare data.
    ## X holds the two options shown on each trial (columns M1, M2).
    ## Y is the 1-based position of the chosen option within that row;
    ## a Choice matching neither column also yields 1 (argmax of all-False).
    X = df[['M1','M2']].values
    Y = np.array([ np.argmax(x == y)+1 for x, y in zip(X, df.Choice.values) ])
    R = df.Outcome.values
    B = df.Block.values

    ## Define metadata (number of retained trials).
    T = R.size

    ## Organize data dictionary.
    dd = dict(T=T, X=X, Y=Y, R=R, B=B, q=q)

    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
    ### Model fitting.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

    ## Fit model, reusing the precompiled program.
    fit = stan_model.sampling(data=dd, iter=samples, warmup=warmup, thin=thin,
                              chains=chains, seed=47404, n_jobs=n_jobs)

    ## Plot.
    plot_toy_model('plots/%s/subj%s.png' %(model_name.replace('.stan',''), i), fit)
    
print('Done.')

### Baseline model: Standard RL (hierarchical)

In [None]:
import pystan, time

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Model parameters.
## Passed to the Stan model as data entry `q`.
q = 0

## Sampling parameters.
model_name = 'moodRL_base.stan'
samples, warmup = 100, 50
chains, thin, n_jobs = 2, 1, 2
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Reduce DataFrame to trials from blocks below 4.
df = data.loc[data.Block < 4].copy()

## Drop trials with missing data.
df = df.loc[df.Choice.notnull()]

## Extract and prepare data.
## Y is the 1-based position of the chosen option within each row of X.
X = df[['M1','M2']].values
Y = np.array([ np.argmax(options == choice) + 1
               for options, choice in zip(X, df.Choice.values) ])
R = df['Outcome'].values

## Define metadata.
## ix maps every trial to a 1-based subject index (one per unique Datetime);
## T is the number of trials and N the largest subject index.
_, ix = np.unique(df.Datetime, return_inverse=True)
ix = ix + 1
T, N = ix.size, ix.max()

## Organize data dictionary.
dd = {'T': T, 'N': N, 'ix': ix, 'X': X, 'Y': Y, 'R': R, 'q': q}

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Model fitting.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Report the number of post-warmup draws retained across all chains.
print('n_samples = %0.0f.' %((samples - warmup) * chains / thin), end=' ')

## Locate stan model, then time the whole pystan.stan call.
file = 'stan_models/%s' %model_name
st = time.time()
fit = pystan.stan(file=file, data=dd,
                  chains=chains, iter=samples, warmup=warmup, thin=thin,
                  seed=47404, n_jobs=n_jobs)
print('Elapsed time: %0.2f s.' %(time.time()-st), end='\n\n')

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Model results.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Extract results.
parameters = fit.extract()