In [1]:
import random
import copy
import numpy as np
import theano as T
import theano.tensor as tt
#import pymc3 as pm
import scipy.stats as st
import pandas as pd
import matplotlib.pyplot as plt


import rpy2
from rpy2.robjects import pandas2ri

pandas2ri.activate()
%load_ext rpy2.ipython

In [2]:
%%capture  --no-stderr --no-stdout --no-display
%%R

library(ggplot2)
library(reshape)
library(grid)
library(dplyr)
library(gridExtra)
library(lme4)

paper_theme <- theme_light() + theme( axis.title.x = element_text(size=18),
  axis.text.x=element_text(colour="black", 
                           size = 14), 
  axis.title.y = element_text(size = 18, vjust = 1),
  axis.text.y  = element_text(size = 14),
  strip.text=element_text(size=16),
  axis.line.x = element_line(colour = "black"), 
  axis.line.y = element_line(colour = "black"),
  legend.title=element_text(size=18),
  legend.text=element_text(size=16))  

  

paper_theme_2 <- theme_light() + theme(
                                       legend.text=element_text(size=14), 
                            legend.title=element_blank(),
                     legend.key = element_rect(color = "transparent", fill="transparent"),
                            legend.position=c(0.82,0.82),
                            strip.text=element_blank(), 

                            axis.text.x=element_text(size=16,color="black"),
                            axis.text.y=element_text(size=16,color="black"),
                            axis.title.x=element_text(face="plain", size=20),
                            axis.title.y=element_text(face="plain", size=20)) 


Attaching package: ‘dplyr’



    rename



    filter, lag



    intersect, setdiff, setequal, union


Attaching package: ‘gridExtra’



    combine



Attaching package: ‘Matrix’



    expand


Attaching package: ‘lme4’



    sigma




In [3]:
data = pd.read_csv("mc_data_big.csv")
subj_group = data.groupby(["subj","incr","n_query", "stim"])


subj_data = {}
for i, trial in subj_group:
    subj = list(trial["subj"])[0]
    if subj not in subj_data:
        subj_data[subj] = []
        
    cnt = [x.count(".") for x in list(trial["stim"])]
    
    tup = (list(trial["guess"])[0],
               list(trial["mod_pred"]), list(trial["cplx"]), list(trial["dist"]), cnt)
    subj_data[subj].append(tup)

    




In [65]:
keys = list(subj_data.keys())
subjs = []
hum_preds = []
mod_preds = []
cplxs = []
dists = []
seen = []
for key in keys:
    tups = subj_data[key]
    m1,m2,m3,m4,m5,m6 = [],[],[],[],[],[]
    if len(tups) == 160:
        for tup in tups:
            m1.append(key)
            m2.append(tup[0])
            m3.append(tup[1])
            m4.append(tup[2])
            m5.append(tup[3])
            m6.append(tup[4])
        subjs.append(m1)
        hum_preds.append(m2)
        mod_preds.append(m3)
        cplxs.append(m4)
        dists.append(m5)
        seen.append(m6)

subjs = np.array(subjs)
hum_preds = np.array(hum_preds)
mod_preds = np.array(mod_preds)
cplxs = np.array(cplxs)
dists = np.array(dists)
seen = np.array(seen)
seen = 16-seen

mod_shape = cplxs.shape
hum_shape = hum_preds.shape

In [5]:


x = np.ones((2,3))
y = np.ones(2)*2
y=y.reshape((2,1,1))

x * y
y

array([[[2.]],

       [[2.]]])

In [66]:
def compute_posterior(complexities,predictions,hum_data,dists,seen, alpha,beta,noise):
    theta = 1./beta
    


    probs = st.gamma.pdf(complexities,a=alpha,scale=theta)
    #print(probs)


    probs = probs * noise**dists
    probs = probs * ((1-noise)**(seen-dists))

    norm = np.sum(probs, axis=2,keepdims=True)
    
    probs = probs/norm
    probs = probs * predictions
    
    probs = np.sum(probs,axis=2)
    probs = probs + (0.5-probs)*noise
    posterior = st.bernoulli.logpmf(hum_data, probs)

    #print(prior)
    
    
    return(np.sum(posterior))

def compute_posterior_multiple(complexities,predictions,hum_data,dists,seen, alpha,beta,noise):
    theta = 1./beta
    
    noise_orig = copy.deepcopy(noise).reshape((len(predictions),1))
    

    alpha = alpha.reshape((len(predictions),1,1))
    theta = theta.reshape((len(predictions),1,1))
    noise = noise.reshape((len(predictions),1,1))

    probs = st.gamma.pdf(complexities,a=alpha,scale=theta)
   # print(probs)
    #print("*"*30)

    probs = probs * noise**dists
    probs = probs * ((1-noise)**(seen-dists))

    norm = np.sum(probs, axis=2,keepdims=True)
    
   # print (probs/norm)
    probs = probs/norm
    probs = probs * predictions
    
    probs = np.sum(probs,axis=2)
    probs = probs + (0.5-probs)*noise_orig
    posterior = st.bernoulli.logpmf(hum_data, probs)

    #print(prior)
    
    
    return(np.sum(posterior))
    

pr1 = compute_posterior_multiple(cplxs[:3], mod_preds[:3],hum_preds[:3],dists[:3],seen[:3],
                       np.array([1,1,1])*0.5,np.array([1,1,1]),np.array([0.1,0.1,0.1])*2)

pr2 = compute_posterior(cplxs[:3], mod_preds[:3],hum_preds[:3],dists[:3],seen[:3],0.5,1,0.2)



In [67]:
n_steps = 10000
curr_alpha = random.random() * 10.
curr_beta = random.random() * 10.
curr_noise = random.random()*0.5
curr_posterior = compute_posterior(cplxs,mod_preds,hum_preds,dists,seen, curr_alpha,curr_beta,curr_noise)


for step in range(n_steps):
    prop_alpha = np.exp(np.log(curr_alpha)+np.random.normal(0,0.1))
    prop_beta = np.exp(np.log(curr_beta)+np.random.normal(0,0.1))
    
    prop_noise = curr_noise + (random.random()-curr_noise)*0.1
    
    prop_posterior = compute_posterior(cplxs,mod_preds,hum_preds,dists,seen, prop_alpha,prop_beta,prop_noise)
    
    if (prop_posterior > curr_posterior) or (np.exp(prop_posterior - curr_posterior) > random.random()):
        curr_posterior = prop_posterior
        curr_alpha = prop_alpha
        curr_beta = prop_beta
        curr_noise = prop_noise
        
        print(step, curr_posterior, curr_alpha,curr_beta,curr_noise)
    
    
print("DONE!")
    


0 -10021.608276395391 9.148216601451722 0.24249652868976815 0.39079565767600616
3 -10014.388266788155 8.412602617100886 0.23890455903609098 0.44214239594049004
5 -9990.319819469296 8.090389249185316 0.22377601246881737 0.41926010751627024
8 -9941.010424032105 7.300831678108995 0.2308621122665153 0.39773777348579187
17 -9939.410432658453 7.762680215112081 0.2494093819555642 0.3650029225777257
21 -9922.315536326936 7.394568254443081 0.236391099827676 0.34228616168534637
24 -9917.2939201419 6.3411045203321015 0.1973367246093374 0.39913236984499517
25 -9882.04782275483 5.065844509353319 0.23965170407231395 0.4066975059730444
26 -9879.85613961969 5.6587024005000295 0.24356950458644047 0.39437242942706635
27 -9845.182630306786 4.994052789387855 0.2823597865417953 0.3828603313211807
33 -9819.342339455226 4.590162041801107 0.278752381403464 0.3654889305876062
35 -9772.564730333583 4.0978187252214395 0.30225055825455766 0.3313357595345456
41 -9756.01279208952 3.7936921134270647 0.25792496796882

KeyboardInterrupt: 

In [64]:

        

n_steps = 12500
n_chains = 2
burn_in = 2500
anti_spread = 10.
acceptance_temperature = 5.

prop_stds, curr_stds = 1,1

all_alphas,all_betas,all_noises,all_posts = [],[],[],[]


def prior_penalty(*args):
    stds = 0.
    for arg in args:
        stds +=  np.std(arg)**2.
    
    return - anti_spread * stds


for chain in range(n_chains):

    alphas,betas,noises,posts = [],[],[],[]


    curr_alpha = np.ones(len(mod_preds)) * random.random() * 4
    curr_beta = np.ones(len(mod_preds)) * random.random() * 4

    curr_noise_logistic = np.random.normal(-1.25,0.1,len(mod_preds))

    curr_noise = 1/(1+np.exp(-curr_noise_logistic))
    curr_stds = (np.std(curr_alpha) * np.std(curr_beta) * np.std(curr_noise)) 

    curr_posterior = compute_posterior_multiple(cplxs,mod_preds,hum_preds,dists,seen, curr_alpha,curr_beta,curr_noise)
    curr_posterior = curr_posterior + prior_penalty([curr_alpha,curr_beta,curr_noise])

    
    add_alpha,add_beta,add_noise = 0,0,0
    accepted= False


    for step in range(n_steps):
        if accepted == False:
            add_alpha = np.random.normal(np.random.normal(0,0.3),0.1,len(curr_alpha))
            add_beta = np.random.normal(np.random.normal(0,0.3),0.1,len(curr_beta))
            add_noise =np.random.normal(np.random.normal(0,0.3),0.1,len(mod_preds))
        
            
        else:
            add_alpha = np.random.normal(add_alpha,0.05,len(add_alpha))
            add_beta =  np.random.normal(add_beta,0.05,len(add_beta))
            add_noise =  np.random.normal(add_noise,0.05,len(add_noise))


        prop_alpha = np.abs(curr_alpha+add_alpha)
        prop_beta = np.abs(curr_beta + add_beta)

        prop_noise_logistic = curr_noise_logistic + add_noise
        prop_noise = 1/(1+np.exp(-prop_noise_logistic))
        
        

        prop_stds = prior_penalty([prop_alpha,prop_beta,prop_noise])

        prop_posterior = compute_posterior_multiple(cplxs,mod_preds,hum_preds,dists,seen, prop_alpha,prop_beta,prop_noise)
        prop_posterior = prop_posterior + prop_stds
        
        
      #  print(np.exp(prop_posterior - curr_posterior))

        if ((prop_posterior > curr_posterior) or 
            (np.exp(prop_posterior - curr_posterior) > random.random()/acceptance_temperature)):
            curr_posterior = prop_posterior
            curr_alpha = prop_alpha
            curr_beta = prop_beta
            curr_noise = prop_noise
            curr_noise_logistic=prop_noise_logistic
            curr_stds = prop_stds


            print(chain, step, np.round(curr_posterior), 
                  np.round(curr_alpha.mean(),1),np.round(curr_beta.mean(),1),np.round(curr_noise.mean(),1),
                  np.round(curr_alpha.std(),1),np.round(curr_beta.std(),1),
                  np.round(prop_stds,1))
            
            
 
        if prop_posterior > curr_posterior:
            accepted=True
        else:
            accepted = False

        if (step > burn_in) and (step % 10 == 0):
            alphas.append(list(curr_alpha))
            betas.append(list(curr_beta))
            noises.append(list(curr_noise))
            posts.append(curr_posterior)
            
    all_alphas.append(copy.deepcopy(alphas))
    all_betas.append(copy.deepcopy(betas))
    all_noises.append(copy.deepcopy(noises))
    all_posts.append(copy.deepcopy(posts))

    
all_alphas = np.array(all_alphas)
all_betas = np.array(all_betas)
all_noises = np.array(all_noises)
all_posts = np.array(all_posts)
    
print("DONE!")
    

0 1 -9950.0 2.9 3.1 0.4 0.1 0.1 -15.3
0 2 -9804.0 2.8 2.6 0.3 0.1 0.1 -12.5
0 6 -9689.0 2.8 2.3 0.3 0.2 0.2 -12.0
0 8 -9539.0 3.0 2.1 0.2 0.2 0.2 -13.6
0 9 -9531.0 3.2 1.8 0.3 0.2 0.2 -15.3
0 12 -9511.0 3.7 2.0 0.2 0.3 0.3 -20.4
0 32 -9513.0 3.7 2.0 0.2 0.3 0.2 -21.1
0 43 -9490.0 3.7 2.0 0.2 0.3 0.2 -21.4
0 74 -9471.0 3.7 1.9 0.2 0.3 0.3 -20.7
0 108 -9459.0 3.5 1.9 0.2 0.4 0.3 -19.0
0 116 -9456.0 3.3 1.6 0.2 0.4 0.3 -16.9
0 125 -9438.0 3.0 1.7 0.2 0.4 0.3 -14.1
0 151 -9434.0 2.9 1.7 0.2 0.4 0.3 -12.8
0 156 -9411.0 2.5 1.5 0.2 0.4 0.3 -9.5
0 179 -9411.0 2.9 1.4 0.2 0.4 0.4 -13.6
0 188 -9407.0 2.9 1.7 0.2 0.5 0.4 -13.7
0 262 -9407.0 2.7 1.7 0.2 0.5 0.4 -11.8
0 265 -9399.0 2.8 1.5 0.2 0.5 0.4 -12.1
0 270 -9400.0 2.7 1.6 0.2 0.5 0.4 -11.5
0 290 -9392.0 2.5 1.4 0.2 0.5 0.4 -9.9
0 409 -9392.0 1.8 1.2 0.2 0.5 0.4 -5.9
0 422 -9378.0 2.3 1.4 0.2 0.5 0.5 -8.7
0 470 -9372.0 2.1 1.2 0.2 0.5 0.5 -7.9
0 504 -9371.0 2.7 1.5 0.2 0.5 0.5 -11.8
0 600 -9370.0 3.0 1.6 0.2 0.5 0.5 -15.1
0 635 -9372.0 3.0 1

KeyboardInterrupt: 

In [95]:



best_post = float("-inf")
best_tup =None

alphas,betas,noises = np.arange(0.5,5,0.25),np.arange(0.5,5,0.25),np.arange(0.2,0.4,0.05)
posts = []
alpha_ps,beta_ps,noise_ps = np.zeros(len(alphas)), np.zeros(len(betas)),np.zeros(len(noises))
for n in range(len(noises)):
    noise = noises[n]
    for b in range(len(betas)):
        beta = betas[b]
        for a in range(len(alphas)):
            alpha = alphas[a]

            post = compute_posterior(cplxs,mod_preds,hum_preds,dists,seen, alpha,beta,noise)
            alpha_ps[a] += np.exp(post+9500)
            beta_ps[b] += np.exp(post+9500)
            noise_ps[n] += np.exp(post+9500)
            
            if str(np.exp(post+9500)) == "nan":
                assert(False)

            print(round(alpha,2),round(beta,2),round(noise,2),round(post))
            
     
            if post > best_post:
                best_tup = (round(alpha,2),round(beta,2),round(noise,2))
                best_post = post
                
                
    print("")
    print(alpha_ps/np.sum(alpha_ps))
    print(beta_ps/np.sum(beta_ps))
    print(noise_ps/np.sum(noise_ps))

    print("")



print("DONE")

0.5 0.5 0.2 -9427.0
0.75 0.5 0.2 -9440.0
1.0 0.5 0.2 -9463.0
1.25 0.5 0.2 -9494.0
1.5 0.5 0.2 -9526.0
1.75 0.5 0.2 -9557.0
2.0 0.5 0.2 -9585.0
2.25 0.5 0.2 -9609.0
2.5 0.5 0.2 -9628.0
2.75 0.5 0.2 -9643.0
3.0 0.5 0.2 -9655.0
3.25 0.5 0.2 -9666.0
3.5 0.5 0.2 -9675.0
3.75 0.5 0.2 -9683.0
4.0 0.5 0.2 -9692.0
4.25 0.5 0.2 -9700.0
4.5 0.5 0.2 -9709.0
4.75 0.5 0.2 -9719.0
0.5 0.75 0.2 -9453.0
0.75 0.75 0.2 -9428.0
1.0 0.75 0.2 -9421.0
1.25 0.75 0.2 -9427.0
1.5 0.75 0.2 -9446.0
1.75 0.75 0.2 -9472.0
2.0 0.75 0.2 -9502.0
2.25 0.75 0.2 -9533.0
2.5 0.75 0.2 -9561.0
2.75 0.75 0.2 -9586.0
3.0 0.75 0.2 -9605.0
3.25 0.75 0.2 -9621.0
3.5 0.75 0.2 -9632.0
3.75 0.75 0.2 -9641.0
4.0 0.75 0.2 -9648.0
4.25 0.75 0.2 -9653.0
4.5 0.75 0.2 -9659.0
4.75 0.75 0.2 -9664.0
0.5 1.0 0.2 -9553.0
0.75 1.0 0.2 -9498.0
1.0 1.0 0.2 -9457.0
1.25 1.0 0.2 -9431.0
1.5 1.0 0.2 -9421.0
1.75 1.0 0.2 -9425.0
2.0 1.0 0.2 -9440.0
2.25 1.0 0.2 -9464.0
2.5 1.0 0.2 -9493.0
2.75 1.0 0.2 -9523.0
3.0 1.0 0.2 -9552.0
3.25 1.0 0.2 -9577.

4.0 0.75 0.25 -9661.0
4.25 0.75 0.25 -9663.0
4.5 0.75 0.25 -9665.0
4.75 0.75 0.25 -9667.0
0.5 1.0 0.25 -9572.0
0.75 1.0 0.25 -9530.0
1.0 1.0 0.25 -9500.0
1.25 1.0 0.25 -9481.0
1.5 1.0 0.25 -9476.0
1.75 1.0 0.25 -9483.0
2.0 1.0 0.25 -9501.0
2.25 1.0 0.25 -9525.0
2.5 1.0 0.25 -9552.0
2.75 1.0 0.25 -9579.0
3.0 1.0 0.25 -9603.0
3.25 1.0 0.25 -9623.0
3.5 1.0 0.25 -9638.0
3.75 1.0 0.25 -9648.0
4.0 1.0 0.25 -9654.0
4.25 1.0 0.25 -9658.0
4.5 1.0 0.25 -9660.0
4.75 1.0 0.25 -9661.0
0.5 1.25 0.25 -9666.0
0.75 1.25 0.25 -9615.0
1.0 1.25 0.25 -9568.0
1.25 1.25 0.25 -9529.0
1.5 1.25 0.25 -9499.0
1.75 1.25 0.25 -9481.0
2.0 1.25 0.25 -9475.0
2.25 1.25 0.25 -9481.0
2.5 1.25 0.25 -9497.0
2.75 1.25 0.25 -9521.0
3.0 1.25 0.25 -9549.0
3.25 1.25 0.25 -9577.0
3.5 1.25 0.25 -9604.0
3.75 1.25 0.25 -9626.0
4.0 1.25 0.25 -9643.0
4.25 1.25 0.25 -9656.0
4.5 1.25 0.25 -9664.0
4.75 1.25 0.25 -9668.0
0.5 1.5 0.25 -9759.0
0.75 1.5 0.25 -9707.0
1.0 1.5 0.25 -9657.0
1.25 1.5 0.25 -9608.0
1.5 1.5 0.25 -9564.0
1.75 1.5 0.

4.25 1.0 0.3 -9721.0
4.5 1.0 0.3 -9719.0
4.75 1.0 0.3 -9717.0
0.5 1.25 0.3 -9730.0
0.75 1.25 0.3 -9694.0
1.0 1.25 0.3 -9662.0
1.25 1.25 0.3 -9635.0
1.5 1.25 0.3 -9615.0
1.75 1.25 0.3 -9605.0
2.0 1.25 0.3 -9603.0
2.25 1.25 0.3 -9611.0
2.5 1.25 0.3 -9626.0
2.75 1.25 0.3 -9645.0
3.0 1.25 0.3 -9667.0
3.25 1.25 0.3 -9688.0
3.5 1.25 0.3 -9706.0
3.75 1.25 0.3 -9719.0
4.0 1.25 0.3 -9729.0
4.25 1.25 0.3 -9734.0
4.5 1.25 0.3 -9736.0
4.75 1.25 0.3 -9736.0
0.5 1.5 0.3 -9794.0
0.75 1.5 0.3 -9756.0
1.0 1.5 0.3 -9719.0
1.25 1.5 0.3 -9686.0
1.5 1.5 0.3 -9655.0
1.75 1.5 0.3 -9631.0
2.0 1.5 0.3 -9613.0
2.25 1.5 0.3 -9604.0
2.5 1.5 0.3 -9604.0
2.75 1.5 0.3 -9612.0
3.0 1.5 0.3 -9627.0
3.25 1.5 0.3 -9648.0
3.5 1.5 0.3 -9671.0
3.75 1.5 0.3 -9693.0
4.0 1.5 0.3 -9714.0
4.25 1.5 0.3 -9730.0
4.5 1.5 0.3 -9742.0
4.75 1.5 0.3 -9749.0
0.5 1.75 0.3 -9853.0
0.75 1.75 0.3 -9816.0
1.0 1.75 0.3 -9779.0
1.25 1.75 0.3 -9743.0
1.5 1.75 0.3 -9708.0
1.75 1.75 0.3 -9676.0
2.0 1.75 0.3 -9649.0
2.25 1.75 0.3 -9627.0
2.5 1.75 0

2.75 1.5 0.35 -9790.0
3.0 1.5 0.35 -9802.0
3.25 1.5 0.35 -9816.0
3.5 1.5 0.35 -9830.0
3.75 1.5 0.35 -9842.0
4.0 1.5 0.35 -9851.0
4.25 1.5 0.35 -9856.0
4.5 1.5 0.35 -9859.0
4.75 1.5 0.35 -9858.0
0.5 1.75 0.35 -9963.0
0.75 1.75 0.35 -9933.0
1.0 1.75 0.35 -9904.0
1.25 1.75 0.35 -9876.0
1.5 1.75 0.35 -9849.0
1.75 1.75 0.35 -9826.0
2.0 1.75 0.35 -9806.0
2.25 1.75 0.35 -9792.0
2.5 1.75 0.35 -9784.0
2.75 1.75 0.35 -9781.0
3.0 1.75 0.35 -9785.0
3.25 1.75 0.35 -9795.0
3.5 1.75 0.35 -9808.0
3.75 1.75 0.35 -9824.0
4.0 1.75 0.35 -9840.0
4.25 1.75 0.35 -9854.0
4.5 1.75 0.35 -9866.0
4.75 1.75 0.35 -9873.0
0.5 2.0 0.35 -10007.0
0.75 2.0 0.35 -9978.0
1.0 2.0 0.35 -9948.0
1.25 2.0 0.35 -9919.0
1.5 2.0 0.35 -9890.0
1.75 2.0 0.35 -9863.0
2.0 2.0 0.35 -9838.0
2.25 2.0 0.35 -9817.0
2.5 2.0 0.35 -9800.0
2.75 2.0 0.35 -9788.0
3.0 2.0 0.35 -9782.0
3.25 2.0 0.35 -9782.0
3.5 2.0 0.35 -9789.0
3.75 2.0 0.35 -9800.0
4.0 2.0 0.35 -9816.0
4.25 2.0 0.35 -9833.0
4.5 2.0 0.35 -9851.0
4.75 2.0 0.35 -9867.0
0.5 2.25 0.35

In [96]:
print(np.sum(alphas * alpha_ps/np.sum(alpha_ps)))

1.2339233824112568
