In [1]:
%load_ext autoreload
%autoreload 2

import sys
import warnings
sys.path.append('..')
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import pingouin as pg
import plotly.express as px
import torch

import experiment1
import utils

# Setup

In [2]:
# Simulation hyperparameters passed to every experiment1 call below (read as attributes, e.g. params.seed).
params = utils.Map(
    # --- experiment size ---
    n_participants=58,
    n_simulations=100,               # rollouts simulated for each participant
    n_steps=3,                       # steps taken within each rollout
    # --- representation sizes ---
    state_d=7,                       # length of the state vector
    context_d=7,                     # length of the context vector
    time_d=25,                       # length of the time vector
    # --- context update ---
    self_excitation=.25,             # share of the previous context carried over into the new context
    input_weight=.45,                # share of the current state integrated into the new context
    retrieved_context_weight=.3,     # share of the context retrieved from episodic memory (EM) integrated into the new context
    time_noise=.01,                  # std of the time integrator's noise (its drift is fixed at 0)
    # --- memory retrieval ---
    state_weight=.5,                 # weight given to the state during memory retrieval
    context_weight=.3,               # weight given to the context during memory retrieval
    time_weight=.2,                  # weight given to the time during memory retrieval
    temperature=.05,                 # retrieval softmax temperature; smaller values behave more like argmax
    # --- reproducibility ---
    seed=1234,                       # random seed for the simulation
)

# Run experiment

In [3]:
# Seed the simulation with params.seed via the project helper (presumably seeds all RNGs it uses — see utils)
# so that the rollouts below are reproducible.
utils.set_random_seed(params.seed)
# Simulate the full experiment. The result is indexed below as revaluation_scores[:, 0] (reward revaluation)
# and revaluation_scores[:, 1] (transition revaluation); rows presumably correspond to simulated
# participants — TODO confirm against experiment1.run_experiment.
revaluation_scores = experiment1.run_experiment(params)

# Plot results

In [4]:
# Behavioral (human) results per condition: mean revaluation score and its SEM.
# The human reward-revaluation mean is also the target used to rescale the model scores for display.
CONDITIONS = ['Reward<br>Revaluation', 'Transition<br>Revaluation']
HUMAN_MEANS = [.5199, .4503]
HUMAN_SEMS = [.0203, .0229]
CONDITION_COLORS = ['#0099C6', '#FF9900']


def plot_condition_bars(summary, title):
    """Show a bar chart of the revaluation score per condition with SEM error bars.

    Args:
        summary: DataFrame with 'Condition', 'Revaluation Score' and 'error' columns.
        title: Figure title.

    Returns:
        The formatted plotly figure (already displayed).
    """
    fig = px.bar(summary, range_y=[0, 1], x='Condition', y='Revaluation Score', color='Condition',
                 error_y='error', color_discrete_sequence=CONDITION_COLORS, title=title)
    fig.update_layout(showlegend=False)
    fig = utils.format_figure(fig, width=600, height=600)
    fig.show()
    return fig


# Model data: long format, one row per (score, condition).
model_scores = (
    pd.DataFrame(revaluation_scores[:, :2], columns=CONDITIONS)
    .melt(var_name='Condition', value_name='Revaluation Score')
)
# NOTE(review): this is an independent-samples test. If each simulated participant contributes a score to
# both conditions, a paired test (pingouin `within=` / `subject=`) may be more appropriate — confirm the design.
t_test_results = pg.pairwise_tests(model_scores, between='Condition', dv='Revaluation Score')

# Rescale the model scores so that the model's reward-revaluation mean matches the human one. This is for
# display only; the t statistic above is unaffected by a common scale factor.
# float() turns a numpy or torch scalar into a plain Python float.
model_reward_mean = float(revaluation_scores[:, 0].mean())
if model_reward_mean == 0:
    # Warnings are globally silenced in this notebook, so a division by zero would silently plot inf.
    raise ValueError('Model reward-revaluation mean is 0; cannot rescale to the behavioral data.')
scaling_factor = HUMAN_MEANS[0] / model_reward_mean
model_summary = (
    model_scores
    .assign(**{'Revaluation Score': model_scores['Revaluation Score'] * scaling_factor})
    .groupby(['Condition'])
    .agg(**{'Revaluation Score': ('Revaluation Score', 'mean'), 'error': ('Revaluation Score', 'sem')})
    .reset_index()
)
fig = plot_condition_bars(model_summary, 'Model Performance')
print('T-test for model results:')
display(t_test_results)

# Behavioral data
human_summary = pd.DataFrame({'Condition': CONDITIONS, 'Revaluation Score': HUMAN_MEANS, 'error': HUMAN_SEMS})
fig = plot_condition_bars(human_summary, 'Human Performance')

T-test for model results:


Unnamed: 0,Contrast,A,B,Paired,Parametric,T,dof,alternative,p-unc,BF10,hedges
0,Condition,Reward<br>Revaluation,Transition<br>Revaluation,False,True,3.908416,114.0,two-sided,0.000158,145.323,0.720989


# Plot context representations

In [41]:
utils.set_random_seed(params.seed)

# Generate sample trials for the baseline phase and for each revaluation phase
visited_states_baseline, rewards_baseline = experiment1.gen_baseline_trials(params)
visited_states_reward_revaluation, rewards_reward_revaluation = experiment1.gen_reward_revaluation_trials(params)
visited_states_transition_revaluation, rewards_transition_revaluation = experiment1.gen_transition_revaluation_trials(params)

# Generate memories for each condition: the baseline trials alone, or the baseline trials followed by the
# trials of one revaluation phase
initial_memories = experiment1.gen_memories(visited_states_baseline, rewards_baseline, params)
reward_revaluation_memories = experiment1.gen_memories(
    torch.cat([visited_states_baseline, visited_states_reward_revaluation]),
    torch.cat([rewards_baseline, rewards_reward_revaluation]),
    params,
)
transition_revaluation_memories = experiment1.gen_memories(
    torch.cat([visited_states_baseline, visited_states_transition_revaluation]),
    torch.cat([rewards_baseline, rewards_transition_revaluation]),
    params,
)

# Memory window that is averaged for each condition. A memory set longer than N_BASELINE_MEMORIES is a
# revaluation set; for it, only the N_REVAL_MEMORIES stored right after the first N_BASELINE_MEMORIES are used.
# NOTE(review): the values come from the original slicing and presumably match the trial counts produced by
# experiment1's generators — confirm.
N_BASELINE_MEMORIES = 60
N_REVAL_MEMORIES = 10
STATE_LABELS = [f'S<sub>{i}</sub>' for i in range(1, 7)]
# `title_font` replaces `titlefont`, which plotly deprecated and then removed in plotly 6.
AXIS_FONTS = dict(tickfont=dict(size=26), title_font=dict(size=24))


def average_context_by_state(memories):
    """Average the stored context vectors per state over one condition's memory window.

    Args:
        memories: Sequence of torch tensors sliced along their first axis. Element 0 holds the stored state
            vectors (the state id is their argmax along the last axis). Element 1 holds the matching context
            vectors.

    Returns:
        DataFrame with a 'State' column followed by one 'Context i' column per context dimension, and one
        row per state present in the window.
    """
    if len(memories[0]) > N_BASELINE_MEMORIES:
        window = slice(N_BASELINE_MEMORIES, N_BASELINE_MEMORIES + N_REVAL_MEMORIES)
    else:
        window = slice(None, N_BASELINE_MEMORIES)
    states = memories[0][window].argmax(-1).numpy()
    contexts = memories[1][window].numpy()
    # Derive the column names from the context width instead of hard-coding 7 of them.
    context_columns = [f'Context {i}' for i in range(contexts.shape[-1])]
    frame = pd.DataFrame(np.concatenate([contexts, states[:, np.newaxis]], axis=-1),
                         columns=context_columns + ['State'])
    return frame.groupby(['State'], as_index=False).mean()


def show_context_heatmap(matrix, title, x_labels, y_labels, **imshow_kwargs):
    """Display `matrix` as a plotly heatmap with enlarged axis fonts and return the figure."""
    fig = px.imshow(matrix, x=x_labels, y=y_labels, title=title, **imshow_kwargs)
    fig.update_xaxes(**AXIS_FONTS)
    fig.update_yaxes(**AXIS_FONTS)
    fig.show()
    return fig


# Average memories for each condition
memory_averages = [
    average_context_by_state(memories)
    for memories in (initial_memories, reward_revaluation_memories, transition_revaluation_memories)
]
average_initial_contexts, average_reward_reval_contexts, average_trans_reval_contexts = memory_averages

# Drop columns 0-1 ('State' and 'Context 0') and transpose, so rows are context dimensions 1.. and columns
# are states. NOTE(review): assumes state/context index 0 is unused in this task — confirm.
initial_matrix = average_initial_contexts.values[:, 2:].T
# The revaluation matrices reuse the initial-phase columns for the first two states (S1, S2).
initial_first_states = initial_matrix[:, :2]
reward_matrix = np.concatenate([initial_first_states, average_reward_reval_contexts.values[:, 2:].T], axis=1)
transition_matrix = np.concatenate([initial_first_states, average_trans_reval_contexts.values[:, 2:].T], axis=1)

# Plot average memories for each condition
print('Full plots')
for name, plot_matrix in zip(['Initial', 'Reward Reval', 'Transition Reval'],
                             [initial_matrix, reward_matrix, transition_matrix]):
    fig = show_context_heatmap(plot_matrix, name, STATE_LABELS, STATE_LABELS)

print('Plot insets')
for name, plot_matrix in zip(['Reward Reval', 'Transition Reval'], [reward_matrix, transition_matrix]):
    # Inset: the first two rows (S1, S2) against the last two state columns (S5, S6), on a fixed colour scale.
    fig = show_context_heatmap(plot_matrix[:2, -2:], name, STATE_LABELS[-2:], STATE_LABELS[:2],
                               range_color=(0, .2))

Full plots


Plot insets
