In [None]:
%load_ext autoreload
%autoreload 2
import pickle
import os, sys
from copy import deepcopy

import numpy as np
import seaborn as sns
import statsmodels.api as sm
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from scipy.spatial.distance import squareform,pdist
from scipy.stats import zscore, norm, ttest_ind
from scipy.io import loadmat
import torch

root_path = os.path.realpath('../')
sys.path.append(root_path)

from utils import eval, choicemodel, plotting, data
from hebbcl.parameters import parser

# hebbcl package directory, presumably needed to unpickle the model.pkl checkpoints
# loaded in the readout-magnitude cells; appended after the project imports so it
# cannot shadow `utils` / `hebbcl` resolution above
sys.path.append(os.path.join(root_path, 'hebbcl'))

# Figure export settings: keep text as text (not paths) in SVG output, and restrict
# PDF output to the 14 standard PDF core fonts instead of embedding font files.
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams['pdf.use14corefonts'] = True


## Figure 1: Experiment design and baselines

In [None]:
# Basic training statistics for the interleaved and blocked baseline networks
plotting.plot_basicstats(
    models=[
        "baseline_interleaved_new_select",
        "baseline_blocked_new_select",
    ]
)

In [None]:
# MDS embedding of hidden-layer activity (layer "all_y_hidden") for the interleaved
# baseline; `thetas` presumably sets the view rotation — TODO confirm in plotting.plot_mds
plotting.plot_mds(
    filename_embedding="mds_embedding_baseline_int_new",
    filename_runs="baseline_interleaved_new_select",
    thetas=(40, 0, 10),
    layer="all_y_hidden",
    n_runs=50,
    resultsdir="../results/",
)


In [None]:

# MDS embedding of hidden-layer activity (layer "all_y_hidden") for the blocked baseline
plotting.plot_mds(
    filename_embedding="mds_embedding_baseline_blocked_new",
    filename_runs="baseline_blocked_new_select",
    thetas=(-10, -10, -70),
    layer="all_y_hidden",
    n_runs=50,
    resultsdir="../results/",
)



## Figure 2: The cost of interleaving

In [None]:
# choice matrix models

%matplotlib inline
_,_,cmats = eval.gen_behav_models()

f,ax = plt.subplots(2,2, figsize=(2,2),dpi=300)
for i in range(2):
    for j in range(2):
        ax[i,j].imshow(np.flipud(cmats[i,j,:,:]))
        if j<1:
            ax[i,j].set(xlabel='rel',ylabel='irrel')
        else:
            ax[i,j].set(xlabel='irrel',ylabel='rel')
        ax[i,j].set_xticks([])
        ax[i,j].set_yticks([])


In [None]:
# accuracy (sluggishness)
%matplotlib inline

plotting.plot_sluggish_results(filename="sluggish_baseline_int_select_sv")

## Figure 3: Continual learning with manual gating

In [None]:
# Basic training statistics for the manually gated blocked network.
# `models` is passed as a list, like every other plot_basicstats call in this notebook
# (a bare string would be iterated character by character if the helper loops over it).
plotting.plot_basicstats(models=["gated_blocked_new_select_cent"])

In [None]:
# MDS embedding of hidden-layer activity for the manually gated blocked network
plotting.plot_mds(
    filename_embedding="mds_embedding_gated_blocked_new",
    filename_runs="gated_blocked_new_select_cent",
    thetas=(-20, -20, -150),
    layer="all_y_hidden",
    n_runs=50,
    resultsdir="../results/",
)


## Figure 4: Hebbian learning of context weights 

In [None]:
# Biplot of the "blobs" dataset with context inputs scaled by 6
plotting.biplot_dataset(
    ds="blobs",
    ctx_scaling=6,
)

In [None]:
# Oja's rule with a single hidden unit on the "blobs" dataset
plotting.plot_oja(
    n_hidden=1,
    ds="blobs",
)

In [None]:
# Oja's rule with a single hidden unit on the "trees" dataset
plotting.plot_oja(
    n_hidden=1,
    ds="trees",
)

In [None]:
# Basic training statistics for the blocked network with Hebbian (Oja) context weights
plotting.plot_basicstats(
    models=["oja_blocked_new_select_halfcenter"],
)

In [None]:
# MDS embedding for the blocked Oja network (remaining plot_mds arguments at their defaults)
plotting.plot_mds(
    filename_embedding="mds_embedding_oja_blocked_new_select_halfcenter",
    filename_runs="oja_blocked_new_select_halfcenter",
    thetas=(-25, 305, 20),
)


## Figure 5: Modelling human learning with Oja + EMA

### Hyperparameter Grid Search

In [None]:
# model validation: fit to idealised choice matrices 
%matplotlib inline
_,_,mats = eval.gen_behav_models()
tmp_b = mats[0].ravel()[:,np.newaxis]
tmp_i = mats[1].ravel()[:,np.newaxis]
plt.subplot(1,2,1)
plt.imshow(tmp_b.reshape(10,5))
plt.title('blocked model')
plt.subplot(1,2,2)
plt.imshow(tmp_i.reshape(10,5))
plt.title('interleaved model')

mses_b = choicemodel.gridsearch_modelparams(tmp_b, curriculum="blocked")
plt.figure()
plt.imshow(np.fliplr(np.array(mses_b).reshape(20,20)))
plt.xlabel('sluggishness')
plt.ylabel('slope')
plt.title('blocked model')
plt.colorbar()
mses_i = choicemodel.gridsearch_modelparams(tmp_i, curriculum="interleaved")
plt.figure()
plt.imshow(np.fliplr(np.array(mses_i).reshape(20,20)))
plt.xlabel('sluggishness')
plt.ylabel('slope')
plt.title('interleaved model')
plt.colorbar()

plt.figure(figsize=(15,5))
plt.subplot(1,2,1)
plt.plot(np.array(mses_b).reshape(20,20).mean(1),color='lightgreen',linestyle='-')
plt.plot(np.array(mses_i).reshape(20,20).mean(1),color='orange',linestyle='-')
plt.scatter(np.where(np.array(mses_b).reshape(20,20).mean(1)==np.min(np.array(mses_b).reshape(20,20).mean(1)))[0][0],np.min(np.array(mses_b).reshape(20,20).mean(1)),marker='d',s=100,color='lightgreen')
plt.scatter(np.where(np.array(mses_i).reshape(20,20).mean(1)==np.min(np.array(mses_i).reshape(20,20).mean(1)))[0][0],np.min(np.array(mses_i).reshape(20,20).mean(1)),marker='d',s=100,color='orange')
plt.legend(('blocked','interleaved'))
plt.xlabel('slope')
plt.ylabel('mse')
plt.subplot(1,2,2)
plt.plot(np.flip(np.array(mses_b).reshape(20,20).mean(0)),color='lightgreen',linestyle='-')
plt.plot(np.flip(np.array(mses_i).reshape(20,20).mean(0)),color='orange',linestyle='-')
plt.scatter(np.where(np.flip(np.array(mses_b).reshape(20,20).mean(0))==np.min(np.flip(np.array(mses_b).reshape(20,20).mean(0))))[0][0],np.min(np.flip(np.array(mses_b).reshape(20,20).mean(0))),marker='d',s=100,color='lightgreen')
plt.scatter(np.where(np.flip(np.array(mses_i).reshape(20,20).mean(0))==np.min(np.flip(np.array(mses_i).reshape(20,20).mean(0))))[0][0],np.min(np.flip(np.array(mses_i).reshape(20,20).mean(0))),marker='d',s=100,color='orange')
plt.legend(('blocked','interleaved'))
plt.xlabel('sluggishness')
plt.ylabel('mse')

In [None]:
# Grid search at single-subject level: per-subject MSE grids, averaged over subjects
gs_results = choicemodel.wrapper_gridsearch_modelparams()
subject_mean_mse = {
    key: gs_results[key].reshape(-1, 20, 20).mean(0) for key in ('cmat_b', 'cmat_i')
}

# heatmaps: slope on y, sluggishness on x (left-right flipped)
for key, title in (('cmat_b', 'single subject lvl, blocked'), ('cmat_i', 'single subject lvl, interleaved')):
    plt.figure()
    plt.imshow(np.fliplr(subject_mean_mse[key]))
    plt.xlabel('sluggishness')
    plt.ylabel('slope')
    plt.title(title)
    plt.colorbar()

# averages: marginal MSE profiles with their minimum marked (lines first, then markers)
curves = (('cmat_b', 'lightgreen'), ('cmat_i', 'orange'))
plt.figure(figsize=(10, 5))
for panel, (avg_axis, flip, title) in enumerate(((1, False, 'sigmoid slope'), (0, True, 'sluggishness')), start=1):
    plt.subplot(1, 2, panel)
    profiles = []
    for key, colour in curves:
        profile = subject_mean_mse[key].mean(avg_axis)
        if flip:
            profile = np.flip(profile)
        profiles.append((profile, colour))
        plt.plot(profile, color=colour, linestyle='-')
    for profile, colour in profiles:
        lowest = np.min(profile)
        plt.scatter(np.flatnonzero(profile == lowest)[0], lowest, marker='d', s=100, color=colour)
    plt.title(title)
    plt.xlabel('param val')
    plt.ylabel('mse')
    plt.legend(('blocked', 'interleaved'))
plt.tight_layout()

In [None]:
# Best-fitting grid indices from the subject-averaged MSE grids (argmin of each marginal).
# The indices refer to the unflipped grid order; the plots above flip the sluggishness axis.
avg_grid = {key: gs_results[key].reshape(-1, 20, 20).mean(0) for key in ('cmat_i', 'cmat_b')}
print(f"estimated sluggishness (idx), interleaved: {np.argmin(avg_grid['cmat_i'].mean(0))}")
print(f"estimated sluggishness (idx), blocked: {np.argmin(avg_grid['cmat_b'].mean(0))}")
print(f"estimated slope (idx), interleaved: {np.argmin(avg_grid['cmat_i'].mean(1))}")
print(f"estimated slope (idx), blocked: {np.argmin(avg_grid['cmat_b'].mean(1))}")
# np.logspace(np.log(0.1), np.log(4), 20)[12]
# NOTE(review): np.logspace expects base-10 exponents, so np.log (natural log) does not
# span 0.1..4 — np.geomspace(0.1, 4, 20) may be what was meant; confirm against choicemodel.

In [None]:
%matplotlib inline
# grid search at group level 
gs_m_results = choicemodel.wrapper_gridsearch_modelparams(single_subs=False)
plt.figure()
plt.imshow(np.fliplr(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0)))
plt.xlabel('sluggishness')
plt.ylabel('slope')
plt.title('single subject lvl, blocked')
plt.colorbar()

plt.figure()
plt.imshow(np.fliplr(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0)))
plt.xlabel('sluggishness')
plt.ylabel('slope')
plt.title('single subject lvl, interleaved')
plt.colorbar()


# averages
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.plot(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(1),color='lightgreen',linestyle='-')
plt.plot(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(1),color='orange',linestyle='-')
plt.scatter(np.where(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(1)==np.min(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(1)))[0][0],np.min(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(1)),marker='d',s=100,color='lightgreen')
plt.scatter(np.where(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(1)==np.min(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(1)))[0][0],np.min(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(1)),marker='d',s=100,color='orange')
plt.title('sigmoid slope')
plt.xlabel('param val')
plt.ylabel('mse')
plt.legend(('blocked','interleaved'))
plt.subplot(1,2,2)
plt.plot(np.flip(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(0)),color='lightgreen',linestyle='-')
plt.plot(np.flip(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(0)),color='orange',linestyle='-')
plt.scatter(np.where(np.flip(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(0))==np.min(np.flip(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(0))))[0][0],np.min(np.flip(gs_m_results['cmat_b'].reshape(-1,20,20).mean(0).mean(0))),marker='d',s=100,color='lightgreen')
plt.scatter(np.where(np.flip(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(0))==np.min(np.flip(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(0))))[0][0],np.min(np.flip(gs_m_results['cmat_i'].reshape(-1,20,20).mean(0).mean(0))),marker='d',s=100,color='orange')
plt.title('sluggishness')
plt.xlabel('param val')
plt.ylabel('mse')
plt.legend(('blocked','interleaved'))
plt.tight_layout()

In [None]:
# Best-fitting grid indices of the group-level fit (argmin of each marginal MSE profile)
group_grid = {key: gs_m_results[key].reshape(20, 20) for key in ('cmat_i', 'cmat_b')}
print(f"estimated sluggishness (idx), interleaved: {np.argmin(group_grid['cmat_i'].mean(0))}")
print(f"estimated sluggishness (idx), blocked: {np.argmin(group_grid['cmat_b'].mean(0))}")
print(f"estimated slope (idx), interleaved: {np.argmin(group_grid['cmat_i'].mean(1))}")
print(f"estimated slope (idx), blocked: {np.argmin(group_grid['cmat_b'].mean(1))}")

### Accuracy

In [None]:
# Accuracy of the interleaved (sluggish) Oja model: average the per-run choice matrices
# obtained by passing the network outputs through the psychophysical choice sigmoid.
# NOTE(review): idx = 1 here, whereas the Figure 6 cells use alpha = 0.05 (idx = 0)
# for the same checkpoint family — confirm which sluggishness level is intended.
n_runs = 20
idx = 1
tempval_interleaved = 12
cmats_a = []
cmats_b = []
for r in np.arange(0, n_runs):
    with open(
        "../checkpoints/sluggish_oja_int_select_sv"
        + str(idx)
        + "/run_"
        + str(r)
        + "/results.pkl",
        "rb",
    ) as f:
        results = pickle.load(f)
    # clip to +/-709.78 (~ln of the float64 maximum) so exp() inside the sigmoid cannot overflow
    cc = np.clip(results["all_y_out"][1, :], -709.78, 709.78).astype(np.float64)
    choices = choicemodel.choice_sigmoid(cc, T=tempval_interleaved)
    # first 25 outputs -> task a, remaining 25 -> task b, each as a 5x5 matrix
    cmats_a.append(choices[:25].reshape(5, 5))
    cmats_b.append(choices[25:].reshape(5, 5))

cmats_a = np.array(cmats_a).mean(0)
cmats_b = np.array(cmats_b).mean(0)
plt.figure()
for panel, cmat in enumerate((cmats_a, cmats_b), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(cmat)
    plt.colorbar()
acc_est = choicemodel.compute_sampled_accuracy(cmats_a, cmats_b)
# these are the interleaved-model checkpoints, so label the estimate accordingly
print(f"accuracy interleaved: {acc_est:.2f}")


In [None]:
# Accuracy comparison between models and human learners (plotting helper)
plotting.plot_modelcomparison_accuracy()

### Sigmoids

In [None]:
# Sigmoid fits; the returned values feed the beta comparison in the next cell
betas = plotting.plot_modelcomparison_sigmoids()

In [None]:
# Compare the sigmoid parameters returned by plot_modelcomparison_sigmoids
plotting.plot_modelcomparison_betas(betas=betas) if False else plotting.plot_modelcomparison_betas(betas)

### Choice Matrices

In [None]:
# Choice matrices per model; `cmats` is reused by the congruency-effect cell below
cmats = plotting.plot_modelcomparison_choicemats()

### Psychophysical Model

In [None]:
# Psychophysical-model fits compared across models (plotting helper)
plotting.plot_modelcomparison_choicemodel()

### Congruency Effect

In [None]:
# Congruency effect computed from the choice matrices returned two cells above
plotting.plot_modelcomparison_congruency(cmats)

## Figure 6: Neural predictions

### Hidden Layer RSA
Fit the grid, orthogonal and diagonal model RDMs to the hidden-layer RDM of each run.  
Prediction: the orthogonal model fits best in the blocked group, and the diagonal model fits best in the interleaved group.

In [None]:
# Regress each run's hidden-layer RDM (lower triangle, z-scored) onto the model RDM
# design matrix `dmat`, for the blocked and the interleaved Oja networks.
n_runs = 20
sluggish_vals = np.round(np.linspace(0.05, 1, 20), 2)
alpha = 0.05
idx = np.flatnonzero(sluggish_vals == alpha)[0]

betas_blocked = []
betas_int = []
rdms, dmat = eval.gen_modelrdms(ctx_offset=1)
# entries strictly below the diagonal of a 50x50 distance matrix
lower_tri = np.tril_indices(50, k=-1)

for ckpt_dir, betas_list in (
    ('../checkpoints/oja_blocked_new_select_halfcenter', betas_blocked),
    ('../checkpoints/sluggish_oja_int_select_sv' + str(idx), betas_int),
):
    for r in np.arange(0, n_runs):
        with open(ckpt_dir + '/run_' + str(r) + '/results.pkl', 'rb') as f:
            results = pickle.load(f)
        hidden_rdm = squareform(pdist(results['all_y_hidden'][1, :, :]))
        y = zscore(hidden_rdm[lower_tri]).flatten()
        lr = LinearRegression()
        lr.fit(dmat, y)
        betas_list.append(lr.coef_)

betas_blocked = np.asarray(betas_blocked)
betas_int = np.asarray(betas_int)

In [None]:
%matplotlib inline
f, axs = plt.subplots(1,1,figsize=(2,2), dpi = 300)

b1 = axs.bar(0-0.2,betas_blocked[:,0].mean(0),yerr=np.std(betas_blocked[:,0],0)/np.sqrt(n_runs),color=(39/255, 140/255, 145/255),width=0.2)
b2 = axs.bar(0,betas_blocked[:,1].mean(0),yerr=np.std(betas_blocked[:,1],0)/np.sqrt(n_runs),color=(34/255, 76/255, 128/255),width=0.2)
b3 = axs.bar(0+0.2,betas_blocked[:,2].mean(0),yerr=np.std(betas_blocked[:,2],0)/np.sqrt(n_runs),color=(159/255, 45/255, 235/255),width=0.2)

b1 = axs.bar(1-0.2,betas_int[:,0].mean(0),yerr=np.std(betas_int[:,0],0)/np.sqrt(n_runs),color=(39/255, 140/255, 145/255),width=0.2)
b2 = axs.bar(1,betas_int[:,1].mean(0),yerr=np.std(betas_int[:,1],0)/np.sqrt(n_runs),color=(34/255, 76/255, 128/255),width=0.2)
b3 = axs.bar(1+0.2,betas_int[:,2].mean(0),yerr=np.std(betas_int[:,2],0)/np.sqrt(n_runs),color=(159/255, 45/255, 235/255),width=0.2)
axs.set_xticks((0,1))
axs.set_yticks((0,0.25,0.5,0.75,1))
axs.set_xticklabels(('blocked','interleaved'),rotation=0,fontsize=6)
axs.set_yticklabels([0,0.25,0.5,0.75,1],fontsize=6)
axs.set_title('Model RSA, Hidden Layer',fontsize=6)
axs.set_ylabel(r'$\beta$ estimate (a.u.)',fontsize=6)
axs.legend([b1,b2,b3],['grid','orthogonal','diagonal'],fontsize=6,frameon=False)

res = ttest_ind(betas_blocked[:,0].ravel(),betas_blocked[:,1].ravel())
z = res.statistic
print(f"blocked grid vs orth: t={z:.2f}, p={res.pvalue:.4f}")
if res.pvalue >= 0.05:
    sigstar ='n.s.'
elif res.pvalue < 0.001:
    sigstar ='*'*3
elif res.pvalue < 0.01:
    sigstar ='*'*2
elif res.pvalue <0.05:
    sigstar ='*'
plt.plot([-0.2,0],[0.75,0.75],'k-', linewidth=1)
plt.text(-0.1,0.75,sigstar,ha='center',fontsize=6)

res = ttest_ind(betas_blocked[:,1].ravel(),betas_blocked[:,2].ravel())
z = res.statistic
print(f"blocked orth vs diag: t={z:.2f}, p={res.pvalue:.4f}")
if res.pvalue >= 0.05:
    sigstar ='n.s.'
elif res.pvalue < 0.001:
    sigstar ='*'*3
elif res.pvalue < 0.01:
    sigstar ='*'*2
elif res.pvalue <0.05:
    sigstar ='*'
plt.plot([0, 0.2],[0.7,0.7],'k-', linewidth=1)
plt.text(0.1,0.7,sigstar,ha='center',fontsize=6)

res = ttest_ind(betas_int[:,0].ravel(),betas_int[:,2].ravel())
z = res.statistic
print(f"int grid vs diag: t={z:.2f}, p={res.pvalue:.4f}")
if res.pvalue >= 0.05:
    sigstar ='n.s.'
elif res.pvalue < 0.001:
    sigstar ='*'*3
elif res.pvalue < 0.01:
    sigstar ='*'*2
elif res.pvalue <0.05:
    sigstar ='*'
plt.plot([0.8,1.2],[0.94,0.94],'k-', linewidth=1)
plt.text(1,0.94,sigstar,ha='center',fontsize=6)

res = ttest_ind(betas_int[:,1].ravel(),betas_int[:,2].ravel())
z = res.statistic
print(f"int orth vs diag: t={z:.2f}, p={res.pvalue:.4f}")
if res.pvalue >= 0.05:
    sigstar ='n.s.'
elif res.pvalue < 0.001:
    sigstar ='*'*3
elif res.pvalue < 0.01:
    sigstar ='*'*2
elif res.pvalue <0.05:
    sigstar ='*'
plt.plot([1,1.2],[0.88,0.88],'k-', linewidth=1)
plt.text(1.1,0.88,sigstar,ha='center',fontsize=6)


sns.despine()
plt.tight_layout()

### Task Selectivity (%)

In [None]:
# Hidden-unit task selectivity (% of hidden units) at the end of training, per curriculum.
# Relies on `n_runs` and `idx` set in the Hidden Layer RSA cell above.
# NOTE(review): the blocked checkpoints here are "oja_blocked_new_select", whereas the
# RSA and readout analyses use "oja_blocked_new_select_halfcenter" — confirm intended.
selectivity_by_group = {}
for group, ckpt_dir in (
    ("blocked", "../checkpoints/oja_blocked_new_select"),
    ("int", "../checkpoints/sluggish_oja_int_select_sv" + str(idx)),
):
    only_a, only_b = [], []
    for r in np.arange(0, n_runs):
        with open(ckpt_dir + "/run_" + str(r) + "/results.pkl", "rb") as f:
            results = pickle.load(f)
        # last entry, presumably the value at the end of training
        only_a.append(results["n_only_a_regr"][-1])
        only_b.append(results["n_only_b_regr"][-1])
    only_a = np.asarray(only_a)
    only_b = np.asarray(only_b)
    # percentages: whatever is not selective for exactly one task is task agnostic
    mixed = 100 - only_a - only_b
    selectivity_by_group[group] = (only_a, only_b, mixed)
    print(f"{group}, task-specific: {only_a.mean()+only_b.mean()}")
    print(f"{group}, task-agnostic: {mixed.mean()}")

blocked_n_only_a, blocked_n_only_b, blocked_n_mixed = selectivity_by_group["blocked"]
int_n_only_a, int_n_only_b, int_n_mixed = selectivity_by_group["int"]

n_a = np.stack((blocked_n_only_a, int_n_only_a), axis=1)
n_b = np.stack((blocked_n_only_b, int_n_only_b), axis=1)
n_m = np.stack((blocked_n_mixed, int_n_mixed), axis=1)
f, ax = plt.subplots(figsize=(2, 2), dpi=300)
# stacked bars; each segment's error bar is the SEM of its own category
b1 = ax.bar(
    ["blocked", "interleaved"],
    n_a.mean(0),
    yerr=np.std(n_a, 0, ddof=1) / np.sqrt(n_a.shape[0]),
    width=0.2,
)
b2 = ax.bar(
    ["blocked", "interleaved"],
    n_b.mean(0),
    yerr=np.std(n_b, 0, ddof=1) / np.sqrt(n_b.shape[0]),
    bottom=n_a.mean(0),
    width=0.2,
)
b3 = ax.bar(
    ["blocked", "interleaved"],
    n_m.mean(0),
    yerr=np.std(n_m, 0, ddof=1) / np.sqrt(n_m.shape[0]),
    bottom=n_b.mean(0) + n_a.mean(0),
    width=0.2,
)
ax.set_ylabel("hidden units (%)", fontsize=6)
ax.legend(
    [b1, b2, b3],
    ["1st task, rel. dim.", "2nd task, rel. dim.", "task agnostic"],
    fontsize=6,
    frameon=False,
)
ax.set_title("Hidden Unit Selectivity", fontsize=6)
sns.despine()

for item in (
    [ax.title, ax.xaxis.label, ax.yaxis.label]
    + ax.get_xticklabels()
    + ax.get_yticklabels()
    + ax.get_legend().get_texts()
):
    item.set_fontsize(6)


### Readout Magnitude

In [None]:
# Readout-weight magnitude of task-selective vs task-agnostic hidden units for the
# interleaved Oja model (sluggishness alpha = 0.05).
n_runs = 20
sluggish_vals = np.round(np.linspace(0.05,1,20),2)
alpha = 0.05
idx = np.where(alpha==sluggish_vals)[0][0]
# named blobs_data so the `data` module imported at the top is not shadowed
blobs_data = eval.make_blobs_dataset()

dmat = eval.make_dmat(blobs_data['f_all'])

# one row per loaded run; np.zeros (not np.empty) so no row is ever left uninitialised
# and later means / t-tests only see real runs
readout_magnitude = np.zeros((n_runs, 3))
for r in np.arange(0, n_runs):
    run_dir = "../checkpoints/sluggish_oja_int_select_sv" + str(idx) + "/run_" + str(r)
    # NOTE: pickle.load can execute arbitrary code — only load trusted checkpoints
    with open(run_dir + "/results.pkl", "rb") as f:
        results = pickle.load(f)

    yh = results['all_y_hidden'][1, :, :]
    n_hidden = yh.shape[1]
    # indicator of each unit's single significant regressor (assumes dmat has 6 columns)
    selectivity_matrix = np.zeros((n_hidden, 6))
    for i_neuron in range(n_hidden):
        regr_results = sm.OLS(zscore(yh[:, i_neuron]), dmat).fit()
        tvals = np.asarray(regr_results.tvalues)
        # a unit counts as selective only if exactly one regressor has t > 1.96
        if np.sum(tvals > 1.96) == 1:
            selectivity_matrix[i_neuron, np.nanargmax(tvals)] = 1
    # selective for regressor 1 only -> task a; regressor 2 only -> task b
    i_task_a = (
        (selectivity_matrix[:, 0] == 0)
        & (selectivity_matrix[:, 1] == 1)
        & (selectivity_matrix[:, 2] == 0)
        & (selectivity_matrix[:, 3] == 0)
    )
    i_task_b = (
        (selectivity_matrix[:, 0] == 0)
        & (selectivity_matrix[:, 1] == 0)
        & (selectivity_matrix[:, 2] == 1)
        & (selectivity_matrix[:, 3] == 0)
    )
    i_agnostic = ~(i_task_a | i_task_b)

    with open(run_dir + "/model.pkl", "rb") as f:
        model = pickle.load(f)
    wo = model.W_o.cpu().detach().numpy()

    # mean |readout weight| per unit class; 0 for an empty class (avoids mean-of-empty NaN)
    readout_magnitude[r, 0] = np.abs(wo[i_task_a]).mean() if i_task_a.any() else 0
    readout_magnitude[r, 1] = np.abs(wo[i_task_b]).mean() if i_task_b.any() else 0
    readout_magnitude[r, 2] = np.abs(wo[i_agnostic]).mean() if i_agnostic.any() else 0
readout_magnitude_int = readout_magnitude


In [None]:
# Readout-weight magnitude of task-selective vs task-agnostic hidden units for the
# blocked Oja model (oja_blocked_new_select_halfcenter).
# Relies on `n_runs` from the interleaved readout cell above.
# NOTE(review): this cell previously allocated 50 rows (and the plot divided by sqrt(50))
# while only `n_runs` runs were loaded — confirm how many blocked runs should be used.
# named blobs_data so the `data` module imported at the top is not shadowed
blobs_data = eval.make_blobs_dataset()

dmat = eval.make_dmat(blobs_data["f_all"])

# one row per loaded run; np.zeros (not np.empty) so no row is ever left uninitialised
readout_magnitude = np.zeros((n_runs, 3))
for r in np.arange(0, n_runs):
    run_dir = "../checkpoints/oja_blocked_new_select_halfcenter/run_" + str(r)
    # NOTE: pickle.load can execute arbitrary code — only load trusted checkpoints
    with open(run_dir + "/results.pkl", "rb") as f:
        results = pickle.load(f)

    yh = results["all_y_hidden"][1, :, :]
    n_hidden = yh.shape[1]
    # indicator of each unit's single significant regressor (assumes dmat has 6 columns)
    selectivity_matrix = np.zeros((n_hidden, 6))
    for i_neuron in range(n_hidden):
        regr_results = sm.OLS(zscore(yh[:, i_neuron]), dmat).fit()
        tvals = np.asarray(regr_results.tvalues)
        # a unit counts as selective only if exactly one regressor has t > 1.96
        if np.sum(tvals > 1.96) == 1:
            selectivity_matrix[i_neuron, np.nanargmax(tvals)] = 1
    # selective for regressor 1 only -> task a; regressor 2 only -> task b
    i_task_a = (
        (selectivity_matrix[:, 0] == 0)
        & (selectivity_matrix[:, 1] == 1)
        & (selectivity_matrix[:, 2] == 0)
        & (selectivity_matrix[:, 3] == 0)
    )
    i_task_b = (
        (selectivity_matrix[:, 0] == 0)
        & (selectivity_matrix[:, 1] == 0)
        & (selectivity_matrix[:, 2] == 1)
        & (selectivity_matrix[:, 3] == 0)
    )
    i_agnostic = ~(i_task_a | i_task_b)

    with open(run_dir + "/model.pkl", "rb") as f:
        model = pickle.load(f)
    wo = model.W_o.cpu().detach().numpy()

    # mean |readout weight| per unit class; 0 for an empty class (avoids mean-of-empty NaN)
    readout_magnitude[r, 0] = np.abs(wo[i_task_a]).mean() if i_task_a.any() else 0
    readout_magnitude[r, 1] = np.abs(wo[i_task_b]).mean() if i_task_b.any() else 0
    readout_magnitude[r, 2] = np.abs(wo[i_agnostic]).mean() if i_agnostic.any() else 0

readout_magnitude_blocked = readout_magnitude


In [None]:
# Readout-weight magnitudes (mean +/- SEM over runs) per unit class and curriculum,
# with t-test brackets for the interleaved model.
f, axs = plt.subplots(1, 1, figsize=(2, 2), dpi=300)

readout_colours = (
    (39/255, 140/255, 145/255),   # 1st task
    (34/255, 76/255, 128/255),    # 2nd task
    (159/255, 45/255, 235/255),   # task agnostic
)
for x_centre, magnitudes in ((0, readout_magnitude_blocked), (1, readout_magnitude_int)):
    # SEM over the runs actually stored in each array (its number of rows)
    sem = np.std(magnitudes, 0) / np.sqrt(magnitudes.shape[0])
    b1, b2, b3 = (
        axs.bar(x_centre + offset, magnitudes[:, k].mean(0), yerr=sem[k],
                color=readout_colours[k], width=0.2)
        for k, offset in enumerate((-0.2, 0, 0.2))
    )
axs.set_xticks((0, 1))
axs.set_xticklabels(('blocked', 'interleaved'), rotation=0, fontsize=6)
axs.tick_params(axis='y', labelsize=6)
axs.set_title('Readout Weights', fontsize=6)
axs.set_ylabel(r'weight magnitude (a.u.)', fontsize=6)
axs.legend([b1, b2, b3], ['1st task', '2nd task', 'task agnostic'], fontsize=6, frameon=False)

# (printed label, sample 1, sample 2, bracket x-extent, bracket height in data units)
readout_comparisons = (
    ("int 1st vs agnostic", readout_magnitude_int[:, 0], readout_magnitude_int[:, 2], (0.8, 1.2), 0.22),
    ("int 2nd vs agnostic", readout_magnitude_int[:, 1], readout_magnitude_int[:, 2], (1, 1.2), 0.19),
)
for label, sample_1, sample_2, (x_left, x_right), height in readout_comparisons:
    res = ttest_ind(sample_1.ravel(), sample_2.ravel())
    print(f"{label}: t={res.statistic:.2f}, p={res.pvalue:.4f}")
    # `not p < 0.05` also maps a NaN p-value (e.g. all-zero magnitudes) to 'n.s.'
    # instead of silently reusing the previous comparison's stars
    if not res.pvalue < 0.05:
        sigstar = 'n.s.'
    elif res.pvalue < 0.001:
        sigstar = '*' * 3
    elif res.pvalue < 0.01:
        sigstar = '*' * 2
    else:
        sigstar = '*'
    axs.plot([x_left, x_right], [height, height], 'k-', linewidth=1)
    axs.text((x_left + x_right) / 2, height, sigstar, ha='center', fontsize=6)
sns.despine()
plt.tight_layout()