In [1]:
import random
import copy
import numpy as np
import theano as T
import theano.tensor as tt
#import pymc3 as pm
import scipy.stats as st
from scipy.special import logsumexp, xlog1py, xlogy
import pandas as pd
import matplotlib.pyplot as plt


import rpy2
from rpy2.robjects import pandas2ri

pandas2ri.activate()
%load_ext rpy2.ipython

In [2]:
%%capture  --no-stderr --no-stdout --no-display
%%R

library(ggplot2)
library(reshape)
library(grid)
library(dplyr)
library(gridExtra)
library(lme4)

# Figure theme for the paper: theme_light() with larger axis, strip and legend
# text, black x tick labels and solid black axis lines.
paper_theme <- theme_light() +
  theme(
    axis.title.x = element_text(size = 18),
    axis.text.x  = element_text(colour = "black", size = 14),
    axis.title.y = element_text(size = 18, vjust = 1),
    axis.text.y  = element_text(size = 14),
    strip.text   = element_text(size = 16),
    axis.line.x  = element_line(colour = "black"),
    axis.line.y  = element_line(colour = "black"),
    legend.title = element_text(size = 18),
    legend.text  = element_text(size = 16)
  )

# Variant: no legend title, transparent legend keys, legend placed at
# (0.82, 0.82) in relative plot coordinates, facet strip labels hidden,
# and larger black axis text.
paper_theme_2 <- theme_light() +
  theme(
    legend.text     = element_text(size = 14),
    legend.title    = element_blank(),
    legend.key      = element_rect(color = "transparent", fill = "transparent"),
    legend.position = c(0.82, 0.82),
    strip.text      = element_blank(),
    axis.text.x     = element_text(size = 16, color = "black"),
    axis.text.y     = element_text(size = 16, color = "black"),
    axis.title.x    = element_text(face = "plain", size = 20),
    axis.title.y    = element_text(face = "plain", size = 20)
  )


Attaching package: ‘dplyr’



    rename



    filter, lag



    intersect, setdiff, setequal, union


Attaching package: ‘gridExtra’



    combine



Attaching package: ‘Matrix’



    expand


Attaching package: ‘lme4’



    sigma




In [3]:
data = pd.read_csv("mc_data_big.csv")
subj_group = data.groupby(["subj","incr","n_query", "stim"])

# subj_data maps each subject to one tuple per (subj, incr, n_query, stim) group:
#   (first "guess", all "mod_pred", all "cplx", all "dist",
#    number of "." characters in each "stim" value)
subj_data = {}
for group_key, trial in subj_group:
    trial_tuple = (
        list(trial["guess"])[0],
        list(trial["mod_pred"]),
        list(trial["cplx"]),
        list(trial["dist"]),
        [stim.count(".") for stim in trial["stim"]],
    )
    subj_data.setdefault(list(trial["subj"])[0], []).append(trial_tuple)

    




In [65]:
# Keep only subjects with exactly 160 trials so every per-subject list has the
# same length and the stacks below form rectangular arrays.
keys = list(subj_data.keys())
subjs, hum_preds, mod_preds, cplxs, dists, seen = [], [], [], [], [], []
for key in keys:
    tups = subj_data[key]
    if len(tups) != 160:
        continue
    subjs.append([key] * len(tups))
    hum_preds.append([tup[0] for tup in tups])
    mod_preds.append([tup[1] for tup in tups])
    cplxs.append([tup[2] for tup in tups])
    dists.append([tup[3] for tup in tups])
    seen.append([tup[4] for tup in tups])

subjs = np.array(subjs)
hum_preds = np.array(hum_preds)
mod_preds = np.array(mod_preds)
cplxs = np.array(cplxs)
dists = np.array(dists)
# `seen` starts as the per-stimulus count of "." characters; 16 minus that
# count is what the likelihood uses as the number of positions compared.
# NOTE(review): presumably stimuli have 16 positions and "." marks an unseen
# one -- confirm against the data.
seen = 16 - np.array(seen)

mod_shape = cplxs.shape
hum_shape = hum_preds.shape

In [5]:


# Scratch check of broadcasting: a (2, 1, 1) array times a (2, 3) array
# broadcasts to shape (2, 2, 3).
x = np.ones((2, 3))
y = np.full((2, 1, 1), 2.0)
x * y
y

array([[[2.]],

       [[2.]]])

In [66]:
def compute_posterior(complexities,predictions,hum_data,dists,seen, alpha,beta,noise):
    """Log-likelihood of the human responses under one shared (alpha, beta, noise).

    Each observation's predicted probability is a weighted average of the
    candidate predictions along axis 2. A candidate's unnormalised weight is

        Gamma(complexity; shape=alpha, rate=beta)
        * noise**dist * (1 - noise)**(seen - dist)

    and the averaged prediction is pulled towards 0.5 by `noise` before the
    Bernoulli log-pmf of `hum_data` is taken. No prior over the parameters is
    included, so despite the name the result is a log-likelihood.

    Parameters
    ----------
    complexities, predictions, dists, seen : array_like
        Per-candidate arrays; normalisation runs over axis 2 (presumably
        (subject, trial, candidate) -- TODO confirm).
    hum_data : array_like
        Observed Bernoulli outcomes, broadcastable against the axis-2 sums.
    alpha, beta : float
        Gamma shape and rate (scale = 1 / beta).
    noise : float
        Value in [0, 1] used both in the candidate weights and as the rate of
        falling back to a 0.5 guess.

    Returns
    -------
    numpy.float64
        Sum of the Bernoulli log-pmf over all observations.
    """
    dists = np.asarray(dists)
    seen = np.asarray(seen)

    # Build the weights in log space and normalise with logsumexp. The raw
    # gamma densities can underflow to 0 for every candidate, which turned a
    # direct `probs / probs.sum()` into 0/0 = nan. xlogy / xlog1py keep the
    # noise == 0 (or 1) case exact, matching 0**0 == 1.
    log_w = (st.gamma.logpdf(complexities, a=alpha, scale=1. / beta)
             + xlogy(dists, noise)             # dists * log(noise)
             + xlog1py(seen - dists, -noise))  # (seen - dists) * log(1 - noise)
    weights = np.exp(log_w - logsumexp(log_w, axis=2, keepdims=True))

    p_model = np.sum(weights * predictions, axis=2)
    # With rate `noise` the response is replaced by a coin flip.
    p_response = p_model + (0.5 - p_model) * noise
    return np.sum(st.bernoulli.logpmf(hum_data, p_response))

def compute_posterior_multiple(complexities,predictions,hum_data,dists,seen, alpha,beta,noise):
    """Log-likelihood with a separate (alpha, beta, noise) for each subject.

    Same model as the shared-parameter likelihood: candidate weights
    Gamma(complexity; alpha, rate=beta) * noise**dist * (1-noise)**(seen-dist)
    are normalised over axis 2, the weighted prediction is pulled towards 0.5
    by `noise`, and the Bernoulli log-pmf of `hum_data` is summed. Axis 0 of
    the data arrays indexes the parameter entries (one per subject).

    Parameters
    ----------
    complexities, predictions, dists, seen : array_like
        Arrays whose axis 0 has len(predictions) entries and whose axis 2 is
        the candidate axis that gets normalised.
    hum_data : array_like
        Observed Bernoulli outcomes, shaped like the axis-2 sums.
    alpha, beta, noise : array_like or float
        One value per subject (any array with len(predictions) elements), or
        a single value shared by all subjects. beta is a rate (scale = 1/beta).

    Returns
    -------
    numpy.float64
        Sum of the Bernoulli log-pmf over all subjects and observations.
    """
    n_subj = len(predictions)

    def _per_subject(param):
        # Reshape to (n_subj, 1, 1) so it broadcasts against the data arrays;
        # a single value is repeated for every subject.
        arr = np.asarray(param, dtype=float)
        if arr.size == 1:
            arr = np.full(n_subj, arr.item())
        return arr.reshape(n_subj, 1, 1)

    alpha3 = _per_subject(alpha)
    beta3 = _per_subject(beta)
    noise3 = _per_subject(noise)
    dists = np.asarray(dists)
    seen = np.asarray(seen)

    # Log-space weights + logsumexp normalisation: the raw densities can all
    # underflow to 0, which made a direct division evaluate to 0/0 = nan.
    log_w = (st.gamma.logpdf(complexities, a=alpha3, scale=1. / beta3)
             + xlogy(dists, noise3)             # dists * log(noise)
             + xlog1py(seen - dists, -noise3))  # (seen - dists) * log(1 - noise)
    weights = np.exp(log_w - logsumexp(log_w, axis=2, keepdims=True))

    p_model = np.sum(weights * predictions, axis=2)
    # After summing axis 2 the data is 2-D, so the lapse rate needs shape
    # (n_subj, 1) to line up with it row by row.
    p_response = p_model + (0.5 - p_model) * noise3.reshape(n_subj, 1)
    return np.sum(st.bernoulli.logpmf(hum_data, p_response))
    

# Consistency check: with identical parameters for every subject, the
# per-subject likelihood must equal the shared-parameter likelihood.
pr1 = compute_posterior_multiple(cplxs[:3], mod_preds[:3], hum_preds[:3], dists[:3], seen[:3],
                                 np.array([1, 1, 1]) * 0.5, np.array([1, 1, 1]), np.array([0.1, 0.1, 0.1]) * 2)

pr2 = compute_posterior(cplxs[:3], mod_preds[:3], hum_preds[:3], dists[:3], seen[:3], 0.5, 1, 0.2)

assert np.isclose(pr1, pr2), (pr1, pr2)



In [67]:
# Single-chain Metropolis-Hastings over one shared (alpha, beta, noise).
# alpha and beta take multiplicative random-walk steps (Gaussian in log space)
# and noise takes a Gaussian step on the logit scale, keeping all three in
# range. Those steps are symmetric only in the transformed space, so the
# acceptance ratio includes the Hastings correction log(x'/x) for alpha and
# beta and log(p'(1-p') / (p(1-p))) for noise. (A proposal of the form
# 0.9 * noise + 0.1 * U(0, 1) is not reversible -- e.g. 0.4 -> 0.46 is
# reachable but 0.46 -> 0.4 is not -- so it cannot satisfy detailed balance.)
n_steps = 10000
step_size = 0.1

curr_alpha = random.random() * 10.
curr_beta = random.random() * 10.
curr_noise = random.random()*0.5
curr_posterior = compute_posterior(cplxs,mod_preds,hum_preds,dists,seen, curr_alpha,curr_beta,curr_noise)

# (alpha, beta, noise, log-likelihood) after every step, so an interrupted
# run still leaves usable samples.
trace = []

for step in range(n_steps):
    prop_alpha = np.exp(np.log(curr_alpha) + np.random.normal(0, step_size))
    prop_beta = np.exp(np.log(curr_beta) + np.random.normal(0, step_size))
    curr_logit = np.log(curr_noise) - np.log1p(-curr_noise)
    prop_noise = 1. / (1. + np.exp(-(curr_logit + np.random.normal(0, step_size))))

    prop_posterior = compute_posterior(cplxs,mod_preds,hum_preds,dists,seen, prop_alpha,prop_beta,prop_noise)

    # log q(curr | prop) - log q(prop | curr) for the log / logit random walks.
    log_hastings = (np.log(prop_alpha / curr_alpha)
                    + np.log(prop_beta / curr_beta)
                    + np.log(prop_noise * (1. - prop_noise))
                    - np.log(curr_noise * (1. - curr_noise)))
    log_ratio = prop_posterior - curr_posterior + log_hastings

    # A nan log_ratio compares False, i.e. the proposal is rejected.
    if np.log(random.random()) < log_ratio:
        curr_posterior = prop_posterior
        curr_alpha = prop_alpha
        curr_beta = prop_beta
        curr_noise = prop_noise

        print(step, curr_posterior, curr_alpha,curr_beta,curr_noise)

    trace.append((curr_alpha, curr_beta, curr_noise, curr_posterior))

print("DONE!")
    


0 -10021.608276395391 9.148216601451722 0.24249652868976815 0.39079565767600616
3 -10014.388266788155 8.412602617100886 0.23890455903609098 0.44214239594049004
5 -9990.319819469296 8.090389249185316 0.22377601246881737 0.41926010751627024
8 -9941.010424032105 7.300831678108995 0.2308621122665153 0.39773777348579187
17 -9939.410432658453 7.762680215112081 0.2494093819555642 0.3650029225777257
21 -9922.315536326936 7.394568254443081 0.236391099827676 0.34228616168534637
24 -9917.2939201419 6.3411045203321015 0.1973367246093374 0.39913236984499517
25 -9882.04782275483 5.065844509353319 0.23965170407231395 0.4066975059730444
26 -9879.85613961969 5.6587024005000295 0.24356950458644047 0.39437242942706635
27 -9845.182630306786 4.994052789387855 0.2823597865417953 0.3828603313211807
33 -9819.342339455226 4.590162041801107 0.278752381403464 0.3654889305876062
35 -9772.564730333583 4.0978187252214395 0.30225055825455766 0.3313357595345456
41 -9756.01279208952 3.7936921134270647 0.25792496796882

KeyboardInterrupt: 

In [122]:

        

# Settings for the per-subject sampler below.
n_chains = 2
n_steps = 12500   # iterations per chain
burn_in = 2500    # samples are recorded only for steps after this one
acceptance_temperature = 5.
anti_spread = 10.  # only referenced by a commented-out spread penalty

# Prior-penalty value of the latest proposal / of the current state.
prop_stds = 1
curr_stds = 1

# One entry per chain, appended by the sampling loop.
all_alphas = []
all_betas = []
all_noises = []
all_posts = []


def prior_penalty(lst):
    """Shrinkage log-density that pulls each array towards its own mean.

    Every element of each array in `lst` is scored under
    Normal(mean of that array, sd=1); all the log-densities are summed.

    Parameters
    ----------
    lst : iterable of numpy arrays

    Returns
    -------
    float
        0. for an empty `lst`, otherwise the summed log-density.
    """
    return sum((np.sum(st.norm.logpdf(arr, arr.mean(), 1)) for arr in lst), 0.)


for chain in range(n_chains):
    # Samples kept for this chain (after burn-in, every 10th step).
    alphas,betas,noises,posts = [],[],[],[]

    # All subjects start from one random alpha and one random beta in [0, 8);
    # noise is parameterised on the logit scale and starts near
    # logistic(-1.25) ~= 0.22 with small per-subject jitter.
    curr_alpha = np.ones(len(mod_preds)) * random.random() * 8
    curr_beta = np.ones(len(mod_preds)) * random.random() * 8

    curr_noise_logistic = np.random.normal(-1.25,0.1,len(mod_preds))
    curr_noise = 1/(1+np.exp(-curr_noise_logistic))

    # Target = per-subject log-likelihood + penalty shrinking each parameter
    # vector towards its mean across subjects.
    curr_posterior = compute_posterior_multiple(cplxs,mod_preds,hum_preds,dists,seen, curr_alpha,curr_beta,curr_noise)
    curr_posterior = curr_posterior + prior_penalty([curr_alpha,curr_beta,curr_noise])

    add_alpha,add_beta,add_noise = 0,0,0
    # Whether the previous proposal improved the target; if so the next step
    # re-uses a slightly jittered copy of it.
    # NOTE(review): this makes the proposal depend on the chain's history, so
    # the sampler is a heuristic search rather than exact Metropolis-Hastings.
    improved = False

    for step in range(n_steps):
        if not improved:
            # Fresh step: one shift shared by all subjects plus per-subject jitter.
            add_alpha = np.random.normal(np.random.normal(0,0.1),0.1,len(curr_alpha))
            add_beta = np.random.normal(np.random.normal(0,0.1),0.1,len(curr_beta))
            add_noise = np.random.normal(np.random.normal(0,0.1),0.1,len(mod_preds))
        else:
            add_alpha = np.random.normal(add_alpha,0.01,len(add_alpha))
            add_beta = np.random.normal(add_beta,0.01,len(add_beta))
            add_noise = np.random.normal(add_noise,0.01,len(add_noise))

        # abs() reflects alpha/beta at zero so they stay non-negative.
        prop_alpha = np.abs(curr_alpha+add_alpha)
        prop_beta = np.abs(curr_beta + add_beta)

        prop_noise_logistic = curr_noise_logistic + add_noise
        prop_noise = 1/(1+np.exp(-prop_noise_logistic))

        prop_stds = prior_penalty([prop_alpha,prop_beta,prop_noise])

        prop_posterior = compute_posterior_multiple(cplxs,mod_preds,hum_preds,dists,seen, prop_alpha,prop_beta,prop_noise)
        prop_posterior = prop_posterior + prop_stds

        # Must be evaluated before curr_posterior is overwritten below:
        # comparing after the update always yields False, which left the
        # step-reuse branch above unreachable.
        improved = prop_posterior > curr_posterior

        # NOTE(review): dividing the uniform draw by acceptance_temperature
        # accepts a worse proposal with probability min(1, T * exp(delta)).
        # That is not tempering (exp(delta / T)), so the chain does not sample
        # the posterior exactly -- confirm this is intended.
        if improved or (np.exp(prop_posterior - curr_posterior) > random.random()/acceptance_temperature):
            curr_posterior = prop_posterior
            curr_alpha = prop_alpha
            curr_beta = prop_beta
            curr_noise = prop_noise
            curr_noise_logistic = prop_noise_logistic
            curr_stds = prop_stds

            # Log-likelihood part (penalty removed), parameter means, penalty.
            print(chain, step, np.round(curr_posterior-curr_stds),
                  np.round(curr_alpha.mean(),1),np.round(curr_beta.mean(),1),np.round(curr_noise.mean(),2),
                  np.round(curr_stds,1))

        # Record the current state every 10th step after burn-in.
        if (step > burn_in) and (step % 10 == 0):
            alphas.append(list(curr_alpha))
            betas.append(list(curr_beta))
            noises.append(list(curr_noise))
            posts.append(curr_posterior)

    all_alphas.append(copy.deepcopy(alphas))
    all_betas.append(copy.deepcopy(betas))
    all_noises.append(copy.deepcopy(noises))
    all_posts.append(copy.deepcopy(posts))


all_alphas = np.array(all_alphas)
all_betas = np.array(all_betas)
all_noises = np.array(all_noises)
all_posts = np.array(all_posts)

print("DONE!")
    

0 0 -10411.0 2.2 6.0 0.24 -279.5
0 2 -10409.0 2.3 6.0 0.25 -280.6
0 4 -10385.0 2.4 6.0 0.26 -281.7
0 9 -10358.0 2.6 6.0 0.26 -282.4
0 11 -10357.0 2.5 5.9 0.27 -283.5
0 13 -10334.0 2.6 5.9 0.29 -284.0
0 14 -10332.0 2.6 6.0 0.32 -284.6
0 18 -10330.0 2.8 5.8 0.33 -286.3
0 19 -10323.0 2.8 5.7 0.33 -287.5
0 20 -10319.0 2.7 5.6 0.32 -288.6
0 29 -10311.0 3.0 5.5 0.32 -288.9
0 31 -10290.0 3.1 5.4 0.3 -290.3
0 32 -10292.0 3.2 5.4 0.3 -290.6
0 34 -10290.0 3.3 5.3 0.32 -291.2
0 39 -10246.0 3.5 5.3 0.33 -292.2
0 41 -10244.0 3.4 5.0 0.35 -293.5
0 43 -10215.0 3.5 4.9 0.31 -294.1
0 44 -10217.0 3.5 4.9 0.32 -294.6
0 45 -10219.0 3.5 4.9 0.33 -296.3
0 46 -10208.0 3.6 4.8 0.33 -296.8
0 52 -10188.0 3.7 4.7 0.31 -298.6
0 56 -10190.0 3.6 4.6 0.33 -298.9
0 58 -10186.0 3.6 4.7 0.3 -300.2
0 60 -10175.0 3.5 4.6 0.29 -300.8
0 64 -10124.0 3.8 4.5 0.29 -301.9
0 65 -10116.0 3.9 4.5 0.32 -302.4
0 66 -10099.0 4.0 4.4 0.33 -304.2
0 71 -10062.0 4.1 4.3 0.32 -305.1
0 72 -10011.0 4.1 4.0 0.31 -305.1
0 73 -9960.0 4.2 4.0 

0 1475 -9164.0 4.3 2.2 0.23 -400.8
0 1482 -9166.0 4.3 2.1 0.24 -399.9
0 1484 -9165.0 4.2 2.1 0.25 -402.4
0 1488 -9166.0 4.3 2.2 0.24 -399.8
0 1495 -9162.0 4.1 2.2 0.24 -399.8
0 1523 -9160.0 4.0 2.1 0.25 -400.0
0 1526 -9156.0 4.0 2.1 0.24 -400.6
0 1547 -9155.0 4.0 2.2 0.24 -399.8
0 1549 -9145.0 4.1 2.2 0.24 -402.6
0 1563 -9150.0 4.2 2.3 0.25 -402.5
0 1603 -9150.0 4.2 2.2 0.25 -402.7
0 1630 -9153.0 4.1 2.2 0.23 -403.1
0 1639 -9156.0 4.0 2.1 0.23 -401.9
0 1643 -9152.0 4.0 2.2 0.23 -403.0
0 1662 -9154.0 4.0 2.1 0.23 -403.6
0 1668 -9154.0 4.1 2.1 0.24 -403.2
0 1693 -9153.0 4.1 2.1 0.25 -401.7
0 1706 -9155.0 4.0 2.2 0.24 -399.7
0 1709 -9157.0 4.0 2.1 0.24 -397.0
0 1735 -9160.0 4.0 2.3 0.23 -395.2
0 1740 -9154.0 4.2 2.3 0.24 -395.3
0 1751 -9154.0 4.1 2.2 0.24 -394.1
0 1770 -9160.0 4.1 2.2 0.24 -391.1
0 1781 -9164.0 4.0 2.3 0.25 -390.0
0 1782 -9168.0 4.0 2.2 0.22 -389.0
0 1783 -9158.0 4.1 2.1 0.24 -389.5
0 1824 -9158.0 4.1 2.2 0.24 -391.3
0 1853 -9159.0 4.1 2.1 0.24 -392.7
0 1906 -9153.0 3.9 2

0 5545 -9106.0 4.2 2.0 0.25 -401.7
0 5593 -9109.0 4.1 2.0 0.24 -398.9
0 5618 -9111.0 4.1 2.0 0.26 -397.6
0 5620 -9115.0 4.2 2.0 0.26 -396.7
0 5622 -9109.0 4.2 2.1 0.25 -395.9
0 5628 -9101.0 4.2 2.1 0.24 -397.7
0 5633 -9101.0 4.1 2.1 0.24 -397.3
0 5639 -9108.0 4.0 2.0 0.26 -396.1
0 5642 -9112.0 3.9 2.0 0.26 -393.8
0 5643 -9115.0 4.0 2.0 0.24 -395.2
0 5648 -9114.0 4.0 2.0 0.24 -395.2
0 5659 -9109.0 4.0 2.0 0.24 -395.7
0 5670 -9111.0 3.8 2.0 0.24 -395.7
0 5689 -9112.0 3.8 1.9 0.24 -396.6
0 5709 -9114.0 3.9 1.9 0.25 -393.8
0 5731 -9114.0 3.9 2.0 0.24 -394.5
0 5773 -9113.0 3.8 1.9 0.24 -393.2
0 5793 -9113.0 3.8 1.9 0.25 -395.5
0 5797 -9107.0 3.8 2.0 0.24 -396.7
0 5805 -9109.0 3.8 1.9 0.25 -397.0
0 5809 -9113.0 3.9 1.9 0.24 -395.6
0 5886 -9114.0 4.0 2.0 0.24 -395.1
0 5890 -9111.0 4.0 2.0 0.24 -395.5
0 5900 -9112.0 4.0 2.0 0.25 -395.4
0 5912 -9114.0 4.2 2.1 0.25 -393.1
0 5948 -9114.0 4.1 2.1 0.25 -393.1
0 5949 -9107.0 4.0 2.1 0.24 -392.4
0 5955 -9102.0 4.0 2.0 0.24 -393.1
0 5969 -9100.0 4.0 2

0 9500 -9086.0 3.5 1.7 0.27 -403.1
0 9506 -9089.0 3.4 1.6 0.25 -402.0
0 9507 -9088.0 3.3 1.6 0.27 -403.4
0 9512 -9082.0 3.2 1.5 0.27 -402.8
0 9539 -9080.0 3.2 1.6 0.27 -403.5
0 9551 -9084.0 3.1 1.6 0.27 -402.4
0 9559 -9088.0 3.2 1.6 0.27 -400.6
0 9571 -9087.0 3.1 1.6 0.26 -402.0
0 9575 -9089.0 3.3 1.6 0.27 -404.3
0 9597 -9088.0 3.2 1.6 0.26 -405.6
0 9598 -9092.0 3.1 1.7 0.25 -404.4
0 9600 -9083.0 3.2 1.6 0.26 -406.0
0 9606 -9082.0 3.1 1.5 0.26 -406.5
0 9609 -9086.0 3.2 1.5 0.26 -405.3
0 9615 -9085.0 3.0 1.6 0.26 -405.1
0 9619 -9079.0 3.2 1.6 0.26 -404.7
0 9623 -9084.0 3.3 1.5 0.26 -401.2
0 9633 -9084.0 3.3 1.5 0.26 -402.3
0 9670 -9090.0 3.2 1.5 0.27 -400.1
0 9673 -9091.0 3.2 1.5 0.27 -397.8
0 9699 -9095.0 3.2 1.5 0.28 -397.3
0 9702 -9093.0 3.2 1.5 0.27 -397.7
0 9707 -9096.0 3.2 1.5 0.27 -398.9
0 9711 -9091.0 3.1 1.5 0.26 -399.2
0 9733 -9085.0 3.0 1.5 0.25 -398.8
0 9769 -9084.0 3.0 1.4 0.26 -398.7
0 9807 -9084.0 3.1 1.5 0.27 -399.6
0 9846 -9084.0 3.1 1.5 0.25 -400.4
0 9858 -9093.0 2.9 1

1 138 -9682.0 6.9 4.1 0.24 -323.9
1 139 -9681.0 6.9 3.9 0.23 -324.9
1 143 -9679.0 7.1 3.9 0.24 -325.3
1 150 -9668.0 7.1 3.9 0.22 -327.0
1 153 -9669.0 7.2 3.7 0.23 -327.7
1 164 -9664.0 7.1 3.7 0.22 -329.5
1 165 -9661.0 7.1 3.7 0.23 -329.5
1 167 -9654.0 6.9 3.8 0.24 -330.9
1 176 -9649.0 6.9 3.8 0.24 -332.8
1 178 -9647.0 6.8 3.9 0.25 -335.5
1 195 -9644.0 6.8 4.0 0.25 -335.5
1 200 -9640.0 6.8 3.8 0.24 -337.6
1 208 -9641.0 6.9 3.8 0.22 -337.0
1 209 -9632.0 6.9 3.6 0.24 -338.7
1 212 -9631.0 6.8 3.6 0.22 -339.2
1 217 -9631.0 6.7 3.7 0.22 -339.2
1 230 -9627.0 6.5 3.8 0.23 -340.1
1 233 -9629.0 6.6 3.7 0.23 -341.1
1 234 -9619.0 6.6 3.8 0.23 -343.4
1 238 -9607.0 6.5 3.7 0.22 -344.1
1 243 -9609.0 6.6 3.8 0.23 -344.5
1 252 -9600.0 6.7 3.8 0.23 -346.7
1 253 -9597.0 6.6 3.8 0.24 -345.2
1 257 -9597.0 6.6 3.8 0.25 -345.4
1 262 -9587.0 6.7 3.8 0.24 -346.4
1 264 -9585.0 6.8 3.8 0.24 -346.9
1 266 -9580.0 6.9 3.8 0.26 -345.4
1 268 -9572.0 6.7 3.7 0.24 -346.1
1 279 -9567.0 6.8 3.8 0.24 -346.7
1 288 -9558.0 

1 2315 -9107.0 3.7 1.9 0.23 -415.0
1 2316 -9108.0 3.8 2.0 0.22 -416.1
1 2321 -9108.0 3.9 1.9 0.23 -418.6
1 2337 -9106.0 4.0 2.0 0.23 -420.9
1 2384 -9108.0 4.0 2.0 0.23 -421.3
1 2392 -9108.0 4.0 2.1 0.24 -423.3
1 2400 -9110.0 3.9 2.1 0.24 -423.1
1 2402 -9108.0 4.0 2.1 0.24 -423.7
1 2404 -9107.0 4.0 2.0 0.23 -426.2
1 2408 -9108.0 4.0 2.0 0.23 -424.9
1 2416 -9103.0 4.0 2.1 0.24 -425.6
1 2426 -9104.0 4.1 2.1 0.24 -424.4
1 2448 -9107.0 4.1 2.0 0.25 -423.3
1 2457 -9108.0 4.1 2.1 0.24 -424.5
1 2459 -9108.0 4.1 2.1 0.25 -426.4
1 2475 -9112.0 4.0 2.1 0.26 -426.3
1 2480 -9113.0 4.1 2.3 0.25 -427.6
1 2486 -9111.0 4.3 2.3 0.25 -428.6
1 2487 -9112.0 4.4 2.3 0.23 -428.9
1 2488 -9116.0 4.4 2.2 0.24 -427.6
1 2494 -9121.0 4.4 2.3 0.22 -425.0
1 2498 -9120.0 4.5 2.3 0.23 -423.8
1 2499 -9115.0 4.5 2.2 0.23 -421.9
1 2512 -9113.0 4.3 2.2 0.24 -423.9
1 2530 -9116.0 4.3 2.1 0.25 -421.5
1 2551 -9113.0 4.2 2.0 0.23 -422.9
1 2567 -9116.0 4.2 2.1 0.23 -421.4
1 2586 -9119.0 4.2 2.1 0.25 -419.7
1 2601 -9119.0 4.3 2

1 5369 -9075.0 3.6 2.0 0.24 -414.8
1 5417 -9072.0 3.6 2.0 0.25 -416.8
1 5482 -9074.0 3.4 2.0 0.26 -416.0
1 5483 -9069.0 3.5 2.1 0.25 -416.3
1 5486 -9070.0 3.4 1.9 0.25 -417.4
1 5489 -9071.0 3.4 1.9 0.25 -417.2
1 5498 -9071.0 3.2 1.9 0.24 -416.6
1 5524 -9072.0 3.2 1.9 0.26 -416.6
1 5527 -9072.0 3.3 1.9 0.24 -416.7
1 5530 -9069.0 3.5 2.0 0.24 -418.9
1 5540 -9068.0 3.3 1.9 0.25 -419.8
1 5665 -9069.0 3.3 1.9 0.25 -417.0
1 5666 -9072.0 3.2 1.8 0.25 -416.7
1 5668 -9071.0 3.2 1.9 0.26 -417.5
1 5693 -9072.0 3.3 1.8 0.25 -418.3
1 5725 -9064.0 3.2 1.9 0.25 -418.5
1 5730 -9065.0 3.3 1.9 0.25 -421.4
1 5733 -9067.0 3.4 1.9 0.26 -420.2
1 5742 -9066.0 3.4 2.0 0.24 -420.1
1 5756 -9067.0 3.4 2.0 0.25 -422.9
1 5760 -9067.0 3.4 1.9 0.26 -424.7
1 5767 -9067.0 3.4 2.0 0.24 -424.5
1 5776 -9065.0 3.6 2.1 0.25 -422.1
1 5786 -9066.0 3.5 2.0 0.25 -422.3
1 5811 -9066.0 3.5 2.0 0.23 -422.4
1 5812 -9052.0 3.6 2.1 0.24 -422.6
1 5816 -9056.0 3.6 2.1 0.25 -419.5
1 5827 -9057.0 3.5 2.0 0.25 -417.6
1 5834 -9053.0 3.4 2

1 10111 -9091.0 3.4 1.8 0.26 -413.9
1 10177 -9093.0 3.6 1.9 0.27 -414.0
1 10249 -9095.0 3.6 1.8 0.27 -417.4
1 10253 -9095.0 3.5 1.8 0.26 -417.8
1 10337 -9096.0 3.3 1.8 0.26 -416.9
1 10344 -9100.0 3.2 1.9 0.26 -415.6
1 10357 -9089.0 3.4 1.8 0.27 -417.0
1 10375 -9092.0 3.3 1.8 0.26 -415.7
1 10379 -9094.0 3.2 1.7 0.27 -413.5
1 10417 -9096.0 3.2 1.8 0.28 -413.6
1 10421 -9097.0 3.2 1.7 0.27 -412.5
1 10428 -9097.0 3.3 1.8 0.27 -412.1
1 10492 -9099.0 3.5 1.8 0.27 -412.8
1 10502 -9101.0 3.3 1.7 0.28 -411.8
1 10503 -9096.0 3.3 1.8 0.28 -412.8
1 10505 -9097.0 3.3 1.8 0.28 -412.3
1 10527 -9096.0 3.2 1.7 0.28 -411.6
1 10554 -9097.0 3.2 1.7 0.28 -412.6
1 10558 -9089.0 3.3 1.7 0.27 -412.9
1 10590 -9090.0 3.1 1.8 0.28 -413.7
1 10603 -9085.0 3.2 1.8 0.27 -411.6
1 10614 -9086.0 3.3 1.7 0.27 -409.0
1 10627 -9085.0 3.2 1.7 0.27 -410.6
1 10666 -9083.0 3.3 1.8 0.26 -410.4
1 10681 -9083.0 3.4 1.7 0.29 -411.0
1 10685 -9081.0 3.3 1.7 0.27 -414.3
1 10691 -9080.0 3.3 1.7 0.27 -416.8
1 10711 -9079.0 3.3 1.8 0.28

-15.555555555555554
0.816496580927726


In [101]:



# Brute-force grid over (alpha, beta, noise). For every grid point the
# log-likelihood is printed, the best point is tracked, and the marginal
# weight of each alpha / beta / noise value (sum of likelihoods over the other
# two axes, i.e. a uniform prior on the grid) is accumulated in log space with
# np.logaddexp. Summing np.exp(post + const) instead overflows to inf or
# underflows to 0 once post drifts ~700 units from -const.
best_post = float("-inf")
best_tup =None

alphas,betas,noises = np.arange(0.5,5,0.25),np.arange(0.5,2,0.1),np.arange(0.15,0.25,0.02)
posts = []
log_alpha_ps = np.full(len(alphas), -np.inf)
log_beta_ps = np.full(len(betas), -np.inf)
log_noise_ps = np.full(len(noises), -np.inf)


def _normalised(log_ps):
    """Convert log marginal weights to probabilities (-inf entries -> 0)."""
    weights = np.exp(log_ps - np.max(log_ps))
    return weights / weights.sum()


for n, noise in enumerate(noises):
    for b, beta in enumerate(betas):
        for a, alpha in enumerate(alphas):
            post = compute_posterior(cplxs,mod_preds,hum_preds,dists,seen, alpha,beta,noise)
            assert not np.isnan(post), (alpha, beta, noise)

            log_alpha_ps[a] = np.logaddexp(log_alpha_ps[a], post)
            log_beta_ps[b] = np.logaddexp(log_beta_ps[b], post)
            log_noise_ps[n] = np.logaddexp(log_noise_ps[n], post)

            print(round(alpha,2),round(beta,2),round(noise,2),round(post))

            if post > best_post:
                best_tup = (round(alpha,2),round(beta,2),round(noise,2))
                best_post = post

        # Running marginals over the grid points evaluated so far.
        print("")
        print(np.round(_normalised(log_alpha_ps), 4).tolist())
        print(np.round(_normalised(log_beta_ps), 4).tolist())
        print(np.round(_normalised(log_noise_ps), 4).tolist())
        print("")

# Unnormalised marginal weights scaled so the largest is 1; only their ratios
# matter because downstream code divides by their sum.
alpha_ps = np.exp(log_alpha_ps - log_alpha_ps.max())
beta_ps = np.exp(log_beta_ps - log_beta_ps.max())
noise_ps = np.exp(log_noise_ps - log_noise_ps.max())

print("DONE")

0.5 0.5 0.15 -9458.0
0.75 0.5 0.15 -9464.0
1.0 0.5 0.15 -9482.0
1.25 0.5 0.15 -9510.0
1.5 0.5 0.15 -9542.0
1.75 0.5 0.15 -9576.0
2.0 0.5 0.15 -9609.0
2.25 0.5 0.15 -9639.0
2.5 0.5 0.15 -9665.0
2.75 0.5 0.15 -9686.0
3.0 0.5 0.15 -9704.0
3.25 0.5 0.15 -9719.0
3.5 0.5 0.15 -9733.0
3.75 0.5 0.15 -9746.0
4.0 0.5 0.15 -9758.0
4.25 0.5 0.15 -9770.0
4.5 0.5 0.15 -9783.0
4.75 0.5 0.15 -9796.0

[0.9959, 0.0041, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 0.0, 0.0, 0.0, 0.0]

0.5 0.6 0.15 -9469.0
0.75 0.6 0.15 -9458.0
1.0 0.6 0.15 -9463.0
1.25 0.6 0.15 -9480.0
1.5 0.6 0.15 -9506.0
1.75 0.6 0.15 -9537.0
2.0 0.6 0.15 -9569.0
2.25 0.6 0.15 -9600.0
2.5 0.6 0.15 -9629.0
2.75 0.6 0.15 -9653.0
3.0 0.6 0.15 -9674.0
3.25 0.6 0.15 -9690.0
3.5 0.6 0.15 -9704.0
3.75 0.6 0.15 -9716.0
4.0 0.6 0.15 -9727.0
4.25 0.6 0.15 -9737.0
4.5 0.6 0.15 -9748.0
4.75 0.6 0.15 -9758.0

[0.4527, 0.543, 0.0044, 

0.5 1.9 0.15 -10263.0
0.75 1.9 0.15 -10180.0
1.0 1.9 0.15 -10093.0
1.25 1.9 0.15 -10006.0
1.5 1.9 0.15 -9919.0
1.75 1.9 0.15 -9836.0
2.0 1.9 0.15 -9757.0
2.25 1.9 0.15 -9687.0
2.5 1.9 0.15 -9628.0
2.75 1.9 0.15 -9581.0
3.0 1.9 0.15 -9548.0
3.25 1.9 0.15 -9531.0
3.5 1.9 0.15 -9527.0
3.75 1.9 0.15 -9536.0
4.0 1.9 0.15 -9554.0
4.25 1.9 0.15 -9579.0
4.5 1.9 0.15 -9609.0
4.75 1.9 0.15 -9639.0

[0.4163, 0.4993, 0.0813, 0.0031, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.418, 0.5016, 0.0779, 0.0024, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 0.0, 0.0, 0.0, 0.0]

0.5 0.5 0.17 -9432.0
0.75 0.5 0.17 -9441.0
1.0 0.5 0.17 -9463.0
1.25 0.5 0.17 -9492.0
1.5 0.5 0.17 -9525.0
1.75 0.5 0.17 -9559.0
2.0 0.5 0.17 -9590.0
2.25 0.5 0.17 -9618.0
2.5 0.5 0.17 -9641.0
2.75 0.5 0.17 -9660.0
3.0 0.5 0.17 -9676.0
3.25 0.5 0.17 -9689.0
3.5 0.5 0.17 -9701.0
3.75 0.5 0.17 -9711.0
4.0 0.5 0.17 -9722.0
4.25 0.5 0.17 -9733.0
4.5 0.5 0.17 -9744.0
4.75 0.5 0.17 -9756.0

[0.

3.75 1.7 0.17 -9508.0
4.0 1.7 0.17 -9537.0
4.25 1.7 0.17 -9568.0
4.5 1.7 0.17 -9599.0
4.75 1.7 0.17 -9628.0

[0.0837, 0.4383, 0.3882, 0.0838, 0.0058, 0.0002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0828, 0.4376, 0.3886, 0.0846, 0.0062, 0.0002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 1.0, 0.0, 0.0, 0.0]

0.5 1.8 0.17 -10088.0
0.75 1.8 0.17 -10010.0
1.0 1.8 0.17 -9931.0
1.25 1.8 0.17 -9851.0
1.5 1.8 0.17 -9772.0
1.75 1.8 0.17 -9697.0
2.0 1.8 0.17 -9629.0
2.25 1.8 0.17 -9570.0
2.5 1.8 0.17 -9523.0
2.75 1.8 0.17 -9489.0
3.0 1.8 0.17 -9470.0
3.25 1.8 0.17 -9466.0
3.5 1.8 0.17 -9474.0
3.75 1.8 0.17 -9492.0
4.0 1.8 0.17 -9518.0
4.25 1.8 0.17 -9549.0
4.5 1.8 0.17 -9581.0
4.75 1.8 0.17 -9613.0

[0.0837, 0.4383, 0.3882, 0.0838, 0.0058, 0.0002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0828, 0.4376, 0.3886, 0.0846, 0.0062, 0.0002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 1.0, 0.0, 0.0, 0.0]

0.5 1.9 0.17 -10140.0
0.75 1.9 0.17 -1006

0.75 1.6 0.19 -9826.0
1.0 1.6 0.19 -9752.0
1.25 1.6 0.19 -9680.0
1.5 1.6 0.19 -9613.0
1.75 1.6 0.19 -9553.0
2.0 1.6 0.19 -9503.0
2.25 1.6 0.19 -9465.0
2.5 1.6 0.19 -9442.0
2.75 1.6 0.19 -9433.0
3.0 1.6 0.19 -9437.0
3.25 1.6 0.19 -9453.0
3.5 1.6 0.19 -9477.0
3.75 1.6 0.19 -9507.0
4.0 1.6 0.19 -9539.0
4.25 1.6 0.19 -9571.0
4.5 1.6 0.19 -9600.0
4.75 1.6 0.19 -9626.0

[0.0128, 0.1258, 0.3426, 0.3233, 0.146, 0.0412, 0.0075, 0.0008, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0074, 0.0981, 0.285, 0.2896, 0.1805, 0.0964, 0.0351, 0.0071, 0.0008, 0.0001, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 1.0, 0.0, 0.0]

0.5 1.7 0.19 -9950.0
0.75 1.7 0.19 -9879.0
1.0 1.7 0.19 -9807.0
1.25 1.7 0.19 -9734.0
1.5 1.7 0.19 -9664.0
1.75 1.7 0.19 -9599.0
2.0 1.7 0.19 -9542.0
2.25 1.7 0.19 -9495.0
2.5 1.7 0.19 -9461.0
2.75 1.7 0.19 -9441.0
3.0 1.7 0.19 -9435.0
3.25 1.7 0.19 -9443.0
3.5 1.7 0.19 -9460.0
3.75 1.7 0.19 -9486.0
4.0 1.7 0.19 -9517.0
4.25 1.7 0.19 -9550.0
4.5 1.7 0.19 -9582.0
4.75 1.7 0.19 -9612.0


0.75 1.4 0.21 -9683.0
1.0 1.4 0.21 -9619.0
1.25 1.4 0.21 -9559.0
1.5 1.4 0.21 -9508.0
1.75 1.4 0.21 -9468.0
2.0 1.4 0.21 -9441.0
2.25 1.4 0.21 -9428.0
2.5 1.4 0.21 -9429.0
2.75 1.4 0.21 -9442.0
3.0 1.4 0.21 -9464.0
3.25 1.4 0.21 -9492.0
3.5 1.4 0.21 -9523.0
3.75 1.4 0.21 -9554.0
4.0 1.4 0.21 -9583.0
4.25 1.4 0.21 -9607.0
4.5 1.4 0.21 -9627.0
4.75 1.4 0.21 -9641.0

[0.0125, 0.1233, 0.3365, 0.3205, 0.1496, 0.0461, 0.0099, 0.0015, 0.0001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0073, 0.096, 0.2791, 0.2848, 0.1807, 0.1009, 0.0401, 0.0095, 0.0014, 0.0002, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.9779, 0.022, 0.0]

0.5 1.5 0.21 -9799.0
0.75 1.5 0.21 -9733.0
1.0 1.5 0.21 -9667.0
1.25 1.5 0.21 -9605.0
1.5 1.5 0.21 -9548.0
1.75 1.5 0.21 -9499.0
2.0 1.5 0.21 -9462.0
2.25 1.5 0.21 -9439.0
2.5 1.5 0.21 -9429.0
2.75 1.5 0.21 -9432.0
3.0 1.5 0.21 -9447.0
3.25 1.5 0.21 -9471.0
3.5 1.5 0.21 -9500.0
3.75 1.5 0.21 -9532.0
4.0 1.5 0.21 -9563.0
4.25 1.5 0.21 -9592.0
4.5 1.5 0.21 -9616.0
4.75 1.5 0.21

0.75 1.2 0.23 -9583.0
1.0 1.2 0.23 -9533.0
1.25 1.2 0.23 -9491.0
1.5 1.2 0.23 -9462.0
1.75 1.2 0.23 -9446.0
2.0 1.2 0.23 -9444.0
2.25 1.2 0.23 -9454.0
2.5 1.2 0.23 -9473.0
2.75 1.2 0.23 -9500.0
3.0 1.2 0.23 -9529.0
3.25 1.2 0.23 -9559.0
3.5 1.2 0.23 -9585.0
3.75 1.2 0.23 -9608.0
4.0 1.2 0.23 -9625.0
4.25 1.2 0.23 -9638.0
4.5 1.2 0.23 -9646.0
4.75 1.2 0.23 -9650.0

[0.0125, 0.1233, 0.3365, 0.3204, 0.1496, 0.0461, 0.0099, 0.0015, 0.0002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0073, 0.096, 0.279, 0.2847, 0.1807, 0.1009, 0.04, 0.0095, 0.0014, 0.0002, 0.0001, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.9778, 0.0221, 0.0]

0.5 1.3 0.23 -9685.0
0.75 1.3 0.23 -9626.0
1.0 1.3 0.23 -9571.0
1.25 1.3 0.23 -9523.0
1.5 1.3 0.23 -9485.0
1.75 1.3 0.23 -9458.0
2.0 1.3 0.23 -9445.0
2.25 1.3 0.23 -9445.0
2.5 1.3 0.23 -9457.0
2.75 1.3 0.23 -9478.0
3.0 1.3 0.23 -9505.0
3.25 1.3 0.23 -9535.0
3.5 1.3 0.23 -9565.0
3.75 1.3 0.23 -9592.0
4.0 1.3 0.23 -9614.0
4.25 1.3 0.23 -9632.0
4.5 1.3 0.23 -9644.0
4.75 1.3 0.2

In [102]:
# Posterior mean of each parameter under its normalised grid marginal.
for grid, weights in ((alphas, alpha_ps), (betas, beta_ps), (noises, noise_ps)):
    print(np.sum(grid * weights / np.sum(weights)))

1.1645188934543294
0.8057481620491
0.19044158567051403
