In [None]:
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sns

from random import *
import itertools 

In [None]:
def prob(y):
    """Draw one Bernoulli outcome: 1 (reward) with probability ``y``, else 0.

    Parameters
    ----------
    y : float
        Reward probability. Values <= 0 never produce a reward and values
        >= 1 always do.

    Returns
    -------
    int
        1 if the trial is rewarded, 0 otherwise.
    """
    # random() is uniform on the half-open interval [0.0, 1.0), so the strict
    # comparison yields a reward with probability exactly y. Using "<=" would
    # let prob(0) pay out whenever random() happens to return 0.0.
    return 1 if random() < y else 0

In [None]:
def trialgenerator(n, y, n_trials=1000, min_block=10, max_block=32):
    """Generate one session of reward outcomes for a two-lever reversal task.

    The session is built from alternating blocks. In the first block of each
    pair the left lever pays out with probability ``n`` and the right lever
    with probability ``y``; in the second block the two probabilities are
    swapped. Each block length is drawn with ``randint(min_block, max_block)``
    (both ends inclusive).

    Parameters
    ----------
    n, y : float
        Reward probabilities of the left and right lever in the first block
        of every pair.
    n_trials : int, optional
        Number of trials to return (default 1000).
    min_block, max_block : int, optional
        Inclusive bounds on the length of a block (defaults 10 and 32).

    Returns
    -------
    tuple of four lists, each ``n_trials`` long
        ``left_reward``, ``right_reward`` -- 0/1 outcome each lever would give
        on that trial; ``left_correct``, ``right_correct`` -- 1 when that
        lever currently has the higher reward probability, 0 otherwise.

    Raises
    ------
    ValueError
        If the block bounds do not satisfy ``1 <= min_block <= max_block``.
    """
    if min_block < 1 or max_block < min_block:
        raise ValueError("block lengths must satisfy 1 <= min_block <= max_block")

    # "Correct" marks the lever with the higher reward probability. The labels
    # follow the actual probabilities rather than assuming n > y; on a tie the
    # left lever keeps the label in the first block of each pair.
    left_first = 1 if n >= y else 0
    right_first = 1 - left_first

    # Enough block pairs that even minimum-length blocks reach n_trials
    # (ceil division; 50 pairs with the default arguments).
    n_pairs = -(-n_trials // (2 * min_block))

    left_reward, right_reward = [], []
    left_correct, right_correct = [], []
    for _ in range(n_pairs):
        # Block A: left lever uses n, right lever uses y.
        for _ in range(randint(min_block, max_block)):
            left_reward.append(prob(n))
            right_reward.append(prob(y))
            left_correct.append(left_first)
            right_correct.append(right_first)
        # Block B: probabilities (and therefore the correct lever) are swapped.
        for _ in range(randint(min_block, max_block)):
            left_reward.append(prob(y))
            right_reward.append(prob(n))
            left_correct.append(right_first)
            right_correct.append(left_first)

    # Blocks usually overshoot; keep only the first n_trials trials.
    return (left_reward[:n_trials], right_reward[:n_trials],
            left_correct[:n_trials], right_correct[:n_trials])

In [None]:
def lever_update(alpha, value, reward):
    """Move a lever's value estimate toward the latest reward.

    Computes ``value + alpha * (reward - value)``: ``alpha`` is the step size
    and ``reward - value`` is the prediction error for this trial.

    Returns the updated value.
    """
    prediction_error = reward - value
    value += alpha * prediction_error
    return value

In [None]:
def softmax(beta, temp_value1, temp_value2):
    """Two-option softmax: probability of choosing the left lever.

    Parameters
    ----------
    beta : float
        Multiplier applied to both values before exponentiation; larger
        values make the choice more deterministic.
    temp_value1 : float
        Current value of the left lever.
    temp_value2 : float
        Current value of the right lever.

    Returns
    -------
    float
        ``exp(beta*v1) / (exp(beta*v1) + exp(beta*v2))``.
    """
    scaled_left = temp_value1 * beta
    scaled_right = temp_value2 * beta
    # Subtracting the larger exponent leaves the ratio unchanged but keeps
    # np.exp from overflowing to inf (which would give inf/inf = nan).
    shift = np.maximum(scaled_left, scaled_right)
    exp_left = np.exp(scaled_left - shift)
    exp_right = np.exp(scaled_right - shift)
    return exp_left / (exp_left + exp_right)

In [None]:
def RW_sim(alpha, beta, temp_value1, temp_value2, left_reward, right_reward, left_correct, right_correct):
    """Simulate one trial: choose a lever via softmax and update its value.

    Parameters
    ----------
    alpha : step size forwarded to ``lever_update``.
    beta : multiplier forwarded to ``softmax``.
    temp_value1, temp_value2 : current values of the left / right lever.
    left_reward, right_reward : outcome each lever would deliver this trial.
    left_correct, right_correct : correctness flag of each lever this trial.

    Returns
    -------
    tuple
        ``(leftlever, rightlever, softmaxv, choice, correct, reward, transition)``
        where ``choice`` is 1 for left and 0 for right, ``correct`` / ``reward``
        belong to the chosen lever, and ``softmaxv`` is the probability of
        choosing left that was used for the draw.
    """
    p_left = softmax(beta, temp_value1, temp_value2)

    if random() <= p_left:
        # Left lever chosen: only its value is updated.
        choice, correct, reward = 1, left_correct, left_reward
        leftlever = lever_update(alpha, temp_value1, left_reward)
        rightlever = temp_value2
    else:
        # Right lever chosen: only its value is updated.
        choice, correct, reward = 0, right_correct, right_reward
        leftlever = temp_value1
        rightlever = lever_update(alpha, temp_value2, right_reward)

    # transition is left_correct whichever lever was chosen, so it tracks the
    # current block rather than the agent's behaviour.
    return leftlever, rightlever, p_left, choice, correct, reward, left_correct

In [None]:
# Simulation settings for this cell.
N_SESSIONS = 50                   # independent simulated sessions
INITIAL_VALUE = 0.5               # starting value estimate for both levers
ALPHA = 0.65                      # step size handed to RW_sim / lever_update
BETA = 4.9                        # multiplier handed to RW_sim / softmax
HIGH_PROB, LOW_PROB = 0.7, 0.1    # reward probabilities given to trialgenerator

# Per-trial records, concatenated across all sessions in order. The lists are
# re-created here, so re-running the cell starts from scratch. This cell does
# not seed the random number generator, so results differ between runs.
leftvalues = []      # left lever value after each trial's update
rightvalues = []     # right lever value after each trial's update
softmaxvalues = []   # probability of choosing left used on each trial
choices = [] # +1 is left lever; 0 is right
correct = []
reward = []
transition = []

for session in range(N_SESSIONS):
    # Every session starts with fresh value estimates and a fresh trial list.
    leftlever = INITIAL_VALUE
    rightlever = INITIAL_VALUE
    left_reward, right_reward, left_correct, right_correct = trialgenerator(HIGH_PROB, LOW_PROB)
    for lr, rr, lc, rc in zip(left_reward, right_reward, left_correct, right_correct):
        leftlever, rightlever, softmaxv, choice, c, r, t = RW_sim(ALPHA, BETA, leftlever, rightlever, lr, rr, lc, rc)
        # These three lists were declared but never filled before.
        leftvalues.append(leftlever)
        rightvalues.append(rightlever)
        softmaxvalues.append(softmaxv)
        choices.append(choice)
        correct.append(c)
        reward.append(r)
        transition.append(t)

In [None]:
# Cumulative trial counts at the end of each session: 1000, 2000, ..., 50000.
temp_list = [*np.arange(1000, 51000, 1000)]

In [None]:
# Index of the last trial of each session in the concatenated records
# (one less than each cumulative count in temp_list).
switch = [session_total - 1 for session_total in temp_list]

In [None]:
win_stay = []
lose_switch = []

# Walk over consecutive trial pairs. After a rewarded trial, record whether
# the agent repeated its choice (win-stay); after an unrewarded trial, record
# whether it changed its choice (lose-switch). choices and reward only ever
# hold 0 or 1, so "same choice as before" covers both levers at once.
for index, (current, following) in enumerate(zip(choices, choices[1:])):
    if index in switch:
        # Last trial of a session: the next trial belongs to another session,
        # so insert a marker that is later used to split the lists by session.
        lose_switch.append('end')
        win_stay.append('end')
        continue

    stayed = following == current
    if reward[index] == 1:      # win
        win_stay.append(1 if stayed else 0)
    elif reward[index] == 0:    # loss
        lose_switch.append(0 if stayed else 1)

In [None]:
# Split win_stay into one list per session, dropping the 'end' markers.
n = 'end'
wssplit = [
    list(group)
    for is_marker, group in itertools.groupby(win_stay, key=lambda item: item == n)
    if not is_marker
]

In [None]:
# Split lose_switch into one list per session, dropping the 'end' markers.
w = 'end'
lssplit = [
    list(group)
    for is_marker, group in itertools.groupby(lose_switch, key=lambda item: item == w)
    if not is_marker
]

In [None]:
# Mean of the 0/1 flags in each session = fraction of win-stay / lose-switch.
ws = [np.mean(session_flags) for session_flags in wssplit]
ls = [np.mean(session_flags) for session_flags in lssplit]

In [None]:
# Label every per-session value with its action. The label counts follow the
# actual lengths of ws and ls, so labels and values stay aligned even if the
# number of sessions is not 50.
x_axis = ['win stay'] * len(ws) + ['lose switch'] * len(ls)
y_axis = ws + ls
# Fractions -> percentages for plotting.
temp_list2 = [fraction * 100 for fraction in y_axis]

In [None]:
# Long-format table: one row per (session, action) with its percentage.
df = pd.DataFrame(
    data=list(zip(temp_list2, x_axis)),
    columns=['% of action per session', 'action'],
)

In [None]:
# Box plot of the per-session percentages for each action, with the
# individual sessions overlaid as a swarm, saved to boxplotagent.png.
sns.set_style("whitegrid")
ax = sns.boxplot(
    data=df,
    x="action",
    y="% of action per session",
    width=0.6,
    medianprops={"zorder": 3},
)
ax = sns.swarmplot(
    data=df,
    x="action",
    y="% of action per session",
    color=".25",
    alpha=0.5,
)
ax.set(xlabel=None, ylim=(0, 105))
ax.set_ylabel("% of action per session", fontsize=12)
ax.tick_params(labelsize=10)
plt.savefig('boxplotagent.png', bbox_inches='tight', dpi=300)

In [None]:
# One bucket per trial offset around a block switch. highprobK holds the
# "correct" flag K trials before the switch; hl_switchK holds the flag for the
# K-th trial from the switch onward (hl_switch1 is the switch trial itself).
(highprob8, highprob7, highprob6, highprob5,
 highprob4, highprob3, highprob2, highprob1) = ([] for _ in range(8))
(hl_switch1, hl_switch2, hl_switch3, hl_switch4,
 hl_switch5, hl_switch6, hl_switch7, hl_switch8) = ([] for _ in range(8))

In [None]:
# Indices whose transition value differs from the previous trial's, i.e. the
# first trial after the higher-probability lever changed sides.
# NOTE: transition is concatenated across sessions, so the first trial of a
# session is also listed when it differs from the previous session's last trial.
block_change = [
    position
    for position, (before, after) in enumerate(zip(transition, transition[1:]), start=1)
    if after != before
]

In [None]:
# Collect the "correct" flags in a 16-trial window around every block switch:
# offsets -8..-1 go to highprob8..highprob1, offsets 0..+7 to hl_switch1..8.
_window_buckets = [
    highprob8, highprob7, highprob6, highprob5,
    highprob4, highprob3, highprob2, highprob1,
    hl_switch1, hl_switch2, hl_switch3, hl_switch4,
    hl_switch5, hl_switch6, hl_switch7, hl_switch8,
]

# Start from empty buckets so that re-running this cell does not append the
# same windows a second time.
for _bucket in _window_buckets:
    _bucket.clear()

# switch[0] is the last index of the first session, so +1 gives the (constant)
# number of trials per session.
session_len = int(switch[0]) + 1 if len(switch) else len(correct)

# block_change is already in ascending order, so iterating it directly visits
# the same switches in the same order as scanning every trial index would.
for index in block_change:
    session_start = (index // session_len) * session_len
    session_stop = min(session_start + session_len, len(correct))
    # Keep only windows that lie entirely inside one session. This drops the
    # spurious "switch" at a session's first trial, keeps windows from mixing
    # two sessions, and prevents indexing past the end of the data.
    if index - 8 < session_start or index + 7 >= session_stop:
        continue
    for offset, _bucket in zip(range(-8, 8), _window_buckets):
        _bucket.append(correct[index + offset])

In [None]:
# x-coordinate for every entry of the concatenated window buckets: each offset
# from -8 to 7 is repeated once per recorded block switch.
temp_list3 = [offset for offset in range(-8, 8) for _ in range(len(highprob8))]

In [None]:
# Concatenate the buckets in offset order (-8 ... +7) so the values line up
# with temp_list3.
y_axis = [
    flag
    for bucket in (
        highprob8, highprob7, highprob6, highprob5,
        highprob4, highprob3, highprob2, highprob1,
        hl_switch1, hl_switch2, hl_switch3, hl_switch4,
        hl_switch5, hl_switch6, hl_switch7, hl_switch8,
    )
    for flag in bucket
]

In [None]:
# Scale the 0/1 flags to percentages (0 or 100).
# NOTE: this rebinds y_axis from its own previous value, so running the cell
# twice without re-running the cell that builds y_axis multiplies by 100 again.
y_axis = [flag * 100 for flag in y_axis]

In [None]:
# Long-format table: one row per (block switch, offset) observation.
switch_df = pd.DataFrame(
    data=list(zip(temp_list3, y_axis)),
    columns=['Trials from switch', '% of choice for higher reward probability'],
)

In [None]:
# Mean percentage of higher-probability choices at each offset from a block
# switch, with a dashed red line at offset 0; saved to switchbehavagent.png.
# NOTE(review): ci=68 requests a 68% interval band; newer seaborn releases
# deprecate `ci` in favour of `errorbar` -- confirm against the installed version.
ax = sns.lineplot(
    data=switch_df,
    x='Trials from switch',
    y='% of choice for higher reward probability',
    color='darkblue',
    ci=68,
)
ax.set_ylim(0, 100)
sns.despine()
ax.vlines(x=[0], ymin=0, ymax=100, linestyles='dashed', colors='red')
plt.savefig('switchbehavagent.png', bbox_inches='tight', dpi=300)