In [None]:
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from random import *

In [None]:
def prob(y):
    """Draw a single Bernoulli outcome: return 1 with probability ``y``, else 0.

    ``random()`` is uniform on [0.0, 1.0), so the strict comparison
    ``random() < y`` pays out with exactly probability ``y``: ``y=0`` never
    pays and ``y=1`` always pays. A ``<=`` test would also pay on a draw of
    exactly ``y`` (e.g. 0.0 when ``y=0``).
    """
    return 1 if random() < y else 0

In [None]:
def trialgenerator(n, y, n_trials=1000, min_block=10, max_block=32):
    """Generate a two-lever reversal schedule with per-trial rewards and labels.

    The levers alternate between reward probabilities ``n`` and ``y`` in blocks
    whose length is drawn with ``randint(min_block, max_block)`` (inclusive).
    The left lever starts with probability ``n``. Each trial's reward for each
    lever is drawn independently with ``prob``.

    Parameters
    ----------
    n, y : float
        Reward probabilities the two levers swap between.
    n_trials : int
        Number of trials returned (default 1000).
    min_block, max_block : int
        Inclusive bounds on the number of trials before the levers switch.

    Returns
    -------
    left_reward, right_reward : list of int
        Per-trial 0/1 reward outcome for each lever.
    left_correct, right_correct : list of int
        1 on trials where that lever has the higher reward probability, else 0.
        When ``n == y`` the lever currently holding ``n`` is labelled correct.

    Raises
    ------
    ValueError
        If ``not 1 <= min_block <= max_block`` (a zero-length block could loop forever).
    """
    if not 1 <= min_block <= max_block:
        raise ValueError("block lengths must satisfy 1 <= min_block <= max_block")

    left_reward, right_reward = [], []
    left_correct, right_correct = [], []
    # Decide which probability is "better" from the values themselves instead of
    # assuming n > y, so the labels stay right when the caller passes (low, high).
    n_is_better = n >= y
    left_has_n = True
    while len(left_reward) < n_trials:
        p_left, p_right = (n, y) if left_has_n else (y, n)
        left_best = n_is_better if left_has_n else not n_is_better
        for _ in range(randint(min_block, max_block)):
            left_reward.append(prob(p_left))
            right_reward.append(prob(p_right))
            left_correct.append(int(left_best))
            right_correct.append(int(not left_best))
        left_has_n = not left_has_n  # reversal: the levers swap probabilities
    # The last block may overshoot; trim every list to exactly n_trials.
    return (left_reward[:n_trials], right_reward[:n_trials],
            left_correct[:n_trials], right_correct[:n_trials])

In [None]:
def lever_update(alpha, value, reward):
    """Delta-rule value update for a single lever.

    Moves ``value`` toward the observed ``reward`` by a fraction ``alpha`` of
    the prediction error ``reward - value`` and returns the updated value.
    """
    prediction_error = reward - value
    value += alpha * prediction_error
    return value

In [None]:
def softmax(beta, temp_value1, temp_value2):
    """Probability of choosing the left lever under a two-option softmax.

    Mathematically equal to
    exp(beta*v1) / (exp(beta*v1) + exp(beta*v2)), evaluated in the logistic
    form 1 / (1 + exp(beta*(v2 - v1))). This form needs only one exponential
    and cannot produce inf/inf = nan when ``beta * value`` is large.

    Parameters
    ----------
    beta : float
        Inverse temperature; ``beta = 0`` gives 0.5 regardless of the values.
    temp_value1 : float
        Current value of the left lever.
    temp_value2 : float
        Current value of the right lever.

    Returns
    -------
    float
        Probability in [0, 1] of choosing the left lever.
    """
    diff = beta * (temp_value2 - temp_value1)
    # A large positive diff overflows exp to inf; 1 / (1 + inf) == 0.0 is the
    # correct limit, so silence the overflow warning rather than propagate nan.
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(diff))

In [None]:
def sim(alpha, beta, left_reward, right_reward, left_correct, right_correct):
    """Simulate a softmax-choice / delta-rule agent on a two-lever schedule.

    Both lever values start at 0.5. On each trial the agent picks the left lever
    with probability ``softmax(beta, left_value, right_value)``, receives that
    lever's reward for the trial, and updates only the chosen lever's value with
    ``lever_update(alpha, ...)``.

    Parameters
    ----------
    alpha : float
        Learning rate passed to ``lever_update``.
    beta : float
        Inverse temperature passed to ``softmax``.
    left_reward, right_reward : sequence of int
        Per-trial reward each lever delivers if chosen.
    left_correct, right_correct : sequence of int
        Per-trial flags marking whether each lever is the better option.

    Returns
    -------
    list
        For each trial, the ``*_correct`` flag of the lever that was chosen.

    Raises
    ------
    ValueError
        If the four input sequences do not all have the same length.
    """
    n_trials = len(left_reward)
    if not (len(right_reward) == len(left_correct) == len(right_correct) == n_trials):
        raise ValueError("reward and correct sequences must all have the same length")

    # Only the latest values are ever read, so keep scalars instead of full histories.
    left_value = right_value = 0.5
    correct = []
    for l_rew, r_rew, l_ok, r_ok in zip(left_reward, right_reward, left_correct, right_correct):
        # random() is uniform on [0, 1), so `<` picks left with exactly the softmax probability.
        if random() < softmax(beta, left_value, right_value):
            left_value = lever_update(alpha, left_value, l_rew)
            correct.append(l_ok)
        else:
            right_value = lever_update(alpha, right_value, r_rew)
            correct.append(r_ok)
    return correct

In [None]:
# Parameter sweep over learning rate (alpha) and inverse temperature (beta).
# correct_choices[alpha][beta] holds N_REPS lists (one per simulated session) of the
# per-trial 0/1 "chose the better lever" flags returned by sim().
SEED = 2024
N_REPS = 10                # independent sessions per (alpha, beta) pair
P_HIGH, P_LOW = 0.7, 0.1   # reward probabilities the levers swap between; left starts at P_HIGH
# linspace with an explicit count avoids np.arange's float-step endpoint ambiguity;
# rounding gives clean dict keys (0.15 rather than 0.15000000000000002).
ALPHAS = [round(a, 3) for a in np.linspace(0, 1, 21)]   # 0.00, 0.05, ..., 1.00
BETAS = [round(b, 3) for b in np.linspace(0, 5, 51)]    # 0.0, 0.1, ..., 5.0

# trialgenerator / prob / sim draw from the `random` module's global RNG, so seeding
# it makes the whole sweep reproducible under Restart & Run All.
seed(SEED)

correct_choices = {}
for alpha in ALPHAS:
    correct_choices[alpha] = {}
    for beta in BETAS:
        sessions = []
        for _ in range(N_REPS):
            left_reward, right_reward, left_correct, right_correct = trialgenerator(P_HIGH, P_LOW)
            sessions.append(sim(alpha, beta, left_reward, right_reward, left_correct, right_correct))
        correct_choices[alpha][beta] = sessions

In [None]:
# Average each (alpha, beta) cell's sessions into a fraction of correct choices.
# Rows are beta and columns are alpha. Both grids are read back from correct_choices
# (insertion order; every alpha shares the same beta grid as built by the sweep), so
# this cell cannot drift out of sync with the sweep's parameter ranges.
alpha_keys = list(correct_choices)
beta_keys = list(next(iter(correct_choices.values()), {}))
heatmap = [[np.mean(correct_choices[a][b]) for a in alpha_keys] for b in beta_keys]

In [None]:
# Rows are beta values and columns alpha values, matching how `heatmap` was built.
# Labels come from the sweep's keys instead of 72 hand-typed numbers; pd.DataFrame
# raises if their lengths disagree with `heatmap`.
alpha_labels = list(correct_choices)
beta_labels = list(next(iter(correct_choices.values()), {}))
df_3a = pd.DataFrame(heatmap, index=beta_labels, columns=alpha_labels) * 100  # fraction -> percent

In [None]:
# Heatmap of % optimal choices: label every 2nd alpha column and every 5th beta row.
ax = sns.heatmap(df_3a, xticklabels=2, yticklabels=5)
fig = ax.figure
ax.set_xlabel("α values", fontsize=12)
ax.set_ylabel("β values", fontsize=12)
ax.invert_yaxis()  # put the smallest beta at the bottom
cbar = ax.collections[0].colorbar
cbar.set_label('% optimal action', labelpad=10, fontsize=11)
fig.savefig('70v10.png', bbox_inches='tight', dpi=300)