In [1]:
%load_ext autoreload
%autoreload 2

import sys
import warnings
sys.path.append('..')
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import pingouin as pg
import plotly.express as px
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr
import torch

import experiment3
import utils

# Setup

In [2]:
# Model hyper-parameters consumed by experiment3.run_experiment.
params = utils.Map(**{
    'n_participants': 20,   # number of simulated participants (presumably models per paradigm — TODO confirm)
    'state_d': 14,          # dimensionality of the state input
    'context_d': 4,         # dimensionality of the learned context representations
    'semantic_d': 32,       # dimensionality of the context-dependent semantic representations
    'output_d': 5,          # dimensionality of the output layer
    'semantic_lr': 2,       # learning rate for the semantic pathway
    'episodic_lr': 0.5,     # learning rate for the episodic pathway
    # NOTE(review): 'persistance' (sic) is kept as spelled — presumably the key experiment3 reads.
    'persistance': 1,       # bias towards memory retention in the recurrent context module
    'temperature': 0.05,    # temperature for EM retrieval (lower is more argmax-like)
})

# Run experiment

In [3]:
# Run the simulation configured by `params`.
# - performance_data: records with (at least) 'trial', 'paradigm', 'pathway', 'seed'
#   and 'accuracy' fields, as read by the plotting cells below.
# - model_data: one entry per trained model; each exposes 'paradigm' and 'seed' and is
#   passed to experiment3.get_embedding in the representation-analysis cells.
# NOTE(review): presumably n_participants models per paradigm — confirm in experiment3.
performance_data, model_data = experiment3.run_experiment(params)

# Plot Results

## Model Results

In [4]:
# Constants shared by the training-curve figures below.
TRIAL_STEP = 50       # plot one point every TRIAL_STEP trials
N_TRAIN_TRIALS = 400  # trials 1..400 are plotted as training; trials > 400 are the test phase
BLOCK_SWITCH = 200    # boundary between the two training blocks of the Blocked paradigm


def _training_curve(data, extra_keys=()):
    """Summarise training-phase accuracy for a line plot.

    Groups ``data`` by trial, paradigm and any ``extra_keys``. Within each group it
    computes the mean accuracy and its standard error (``sem``). It then keeps every
    ``TRIAL_STEP``-th trial in ``1..N_TRAIN_TRIALS``. Blocked rows after
    ``BLOCK_SWITCH`` are relabelled ``'Blocked2'`` so that plotly draws the two blocks
    as separate traces, with no line across the block boundary (following
    Flesch et al., 2018).

    Returns a new DataFrame with ``accuracy``/``error`` plus the display-named
    ``Accuracy``/``Trial`` columns.
    """
    curve = (data.groupby(['trial', 'paradigm', *extra_keys])
             .agg(accuracy=('accuracy', 'mean'), error=('accuracy', 'sem'))
             .reset_index())
    keep = ((curve.trial % TRIAL_STEP == 0)
            & (curve.trial > 0)
            & (curve.trial <= N_TRAIN_TRIALS))
    # .copy(): the .loc write below must target an owned frame, not a slice of `curve`.
    curve = curve[keep].copy()
    curve.loc[(curve.paradigm == 'Blocked') & (curve.trial > BLOCK_SWITCH), 'paradigm'] = 'Blocked2'
    return curve.assign(Accuracy=curve['accuracy'], Trial=curve['trial'])


def _plot_training_curve(curve, colors, **line_kwargs):
    """Plot a ``_training_curve`` result with the shared figure styling.

    ``colors`` is handed to ``utils.change_figure_colors`` (presumably applied in trace
    order). Plotly express creates traces in order of first appearance in ``curve``,
    e.g. Blocked, Interleaved, then Blocked2. Extra keyword arguments are forwarded
    to ``px.line``.
    """
    fig = px.line(curve, x='Trial', y='Accuracy', color='paradigm', error_y='error',
                  markers=True, range_y=[.5, 1], range_x=[0, 410], **line_kwargs)
    fig = utils.format_figure(fig, width=1000, height=500, showlegend=False)
    fig = utils.change_figure_colors(fig, colors)
    fig.update_traces(line=dict(width=4), marker=dict(size=10))
    fig.show()


performance_df = pd.DataFrame(performance_data)
combined = performance_df[performance_df.pathway == 'Combined']

# Train data: combined-pathway accuracy over training.
_plot_training_curve(_training_curve(combined), ['#95DFCD', '#D985CB', '#95DFCD'])

# Test data: average each seed's accuracy over the test phase, then compare paradigms.
# Only 'accuracy' is averaged. Calling .mean() on the whole frame raises a TypeError
# in pandas >= 2.0 because of the string 'pathway' column (older pandas only warned,
# and that warning is silenced at the top of the notebook).
test_acc = (combined[combined.trial > N_TRAIN_TRIALS]
            .groupby(['seed', 'paradigm'], as_index=False)['accuracy']
            .mean())
display(pg.pairwise_tests(test_acc, dv='accuracy', between='paradigm'))
test_summary = (test_acc.groupby('paradigm')
                .agg(accuracy=('accuracy', 'mean'), error=('accuracy', 'sem'))
                .reset_index())
test_summary['Accuracy'] = test_summary['accuracy']
test_summary['Training Type'] = test_summary['paradigm']
fig = px.bar(test_summary, x='Training Type', color='Training Type', y='Accuracy',
             error_y='error', range_y=[.5, 1])
fig = utils.format_figure(fig, width=500, height=500, showlegend=False)
fig = utils.change_figure_colors(fig, ['#95DFCD', '#D985CB'])
fig.update_yaxes(title=None, showticklabels=False, tickmode='array', tickvals=[])
fig.show()

# EM vs SM data: each non-combined pathway plotted separately (line dash = pathway).
# SEM now uses 'sem' like the plots above, instead of std / sqrt(n_participants).
# The two agree when every group has n_participants rows, and 'sem' stays correct
# if a group has fewer.
pathway_df = performance_df[performance_df.pathway != 'Combined']
_plot_training_curve(
    _training_curve(pathway_df, extra_keys=['pathway']),
    ['#95DFCD', '#95DFCD', '#D985CB', '#D985CB', '#95DFCD', '#95DFCD'],
    line_dash='pathway',
)

Unnamed: 0,Contrast,A,B,Paired,Parametric,T,dof,alternative,p-unc,BF10,hedges
0,paradigm,Blocked,Interleaved,False,True,2.521542,38.0,two-sided,0.015997,3.471,0.781539


## Human Results

In [5]:
"""

@Declan to replicate above plots for behavioral results from Flesch et al., 2018

"""

'\n\n@Declan to replicate above plots for behavioral results from Flesch et al., 2018\n\n'

# Analyze semantic representations

In [7]:
# Representational similarity analysis: correlate each model's embedding similarity
# structure with the 'separate' and 'shared' template RDMs.
xs = torch.tensor(np.load('data/exp3_reps.npy')).float()
separate_sim = experiment3.get_template_rdm('separate')
shared_sim = experiment3.get_template_rdm('shared')
templates = {'Separate': separate_sim, 'Shared': shared_sim}

rdms = []       # per-model dissimilarity matrices (negated cosine similarity)
corr_data = []
for row in model_data:
    embedding = experiment3.get_embedding(row, xs)
    # Compute the similarity matrix once; it feeds both the RDM and the correlations.
    embedding_cos = cosine_similarity(embedding)
    rdms.append(-embedding_cos)
    # Upper triangle without the diagonal. Presumably this matches the layout of the
    # template vectors returned by get_template_rdm (TODO confirm).
    embedding_sim = embedding_cos[np.triu_indices(len(embedding), k=1)]
    for rep_type, template in templates.items():
        corr_data.append({'paradigm': row['paradigm'],
                          'Correlation': pearsonr(template, embedding_sim)[0],
                          'Representation Type': rep_type,
                          'seed': row['seed']})
corr_data = pd.DataFrame(corr_data)
# pairwise_tests is pingouin's current name for the older pairwise_ttests (now a
# deprecated alias). It is also the function used for the test-accuracy comparison.
# NOTE(review): every model contributes both a Separate and a Shared correlation, so
# 'Representation Type' is a within-model factor, yet between= treats the rows as
# independent. Consider within='Representation Type' with a subject column, once seeds
# are confirmed unique across paradigms.
display(pg.pairwise_tests(corr_data, dv='Correlation', between=['Representation Type', 'paradigm'],
                          parametric=False))
corr_data = (corr_data.groupby(['paradigm', 'Representation Type'])
             .agg(Correlation=('Correlation', 'mean'), error=('Correlation', 'sem'))
             .reset_index())
fig = px.bar(corr_data, x='Representation Type', y='Correlation', error_y='error',
             color='paradigm', barmode='group', range_y=[0, 1])
fig = utils.format_figure(fig, width=1000, height=500)
fig = utils.change_figure_colors(fig, ['#95DFCD', '#D985CB'])
fig.show()

Unnamed: 0,Contrast,Representation Type,A,B,Paired,Parametric,U-val,alternative,p-unc,hedges
0,Representation Type,-,Separate,Shared,False,False,651.0,two-sided,0.153021,-0.258082
1,paradigm,-,Blocked,Interleaved,False,False,851.0,two-sided,0.627012,0.24477
2,Representation Type * paradigm,Separate,Blocked,Interleaved,False,False,305.0,two-sided,0.004703,1.035028
3,Representation Type * paradigm,Shared,Blocked,Interleaved,False,False,96.0,two-sided,0.005115,-0.609967


In [9]:
# Per-item stimulus features and context labels used to style the MDS scatter plots.
all_features = np.load('data/exp3_features.npy')
context = np.load('data/exp3_context.npy') - 1  # shift labels down by one (presumably 1-based on disk — TODO confirm)
ys1, ys2 = all_features[:, 0], all_features[:, 1]  # feature 0 -> marker size, feature 1 -> colour

# Visualize embeddings for two example models, one blocked and one interleaved.
# NOTE(review): assumes model_data[1] is Blocked and model_data[-1] is Interleaved.
# Each model is paired with the rotation mode passed to get_rotated_mds.
example_models = {'angle': model_data[1], 'flip': model_data[-1]}
for rot_type, m in example_models.items():
    embedding = experiment3.get_embedding(m, xs)
    mds_reps = experiment3.get_rotated_mds(embedding, rot_type)
    # MDS dimension 2 on the y-axis against dimensions 0 and 1 on the x-axis.
    for x_dim in (0, 1):
        scatter = px.scatter(mds_reps, x=x_dim, y=2, color=ys2, size=ys1, symbol=context)
        utils.format_figure(scatter, width=500, height=500).show()
    px.imshow(cosine_similarity(embedding)).show()