# Scratchpad for paper revisions

In [1]:
%load_ext autoreload
%autoreload 2
import pickle
import os, sys
root_path = os.path.realpath('../')
sys.path.append(root_path)

import torch
from pathlib import Path
import numpy as np
import random


from utils.data import make_blobs_dataset, make_trees_dataset
from utils.nnet import get_device

from hebbcl.logger import LoggerFactory
from hebbcl.model import Nnet, ScaledNet2Hidden
from hebbcl.trainer import Optimiser, train_on_blobs, train_on_trees
from hebbcl.parameters import parser
from hebbcl.tuner import HPOTuner

## Hyperparameter optimisation
HPO on networks trained with fewer episodes.

### HPO: blocked trials with oja_ctx

In [None]:
# HPO on blocked trials with oja_ctx (short training run: 8 episodes)
args = parser.parse_args(args=[])
args.n_episodes = 8
args.hpo_fixedseed = True
args.hpo_scheduler = "bohb"
args.hpo_searcher = "bohb"
# dict(sorted(vars(args).items(),key=lambda k: k[0]))
args.ctx_avg = False

# init tuner
# NOTE(review): time_budget is presumably in seconds (60*15 = 15 min) -- confirm against HPOTuner
tuner = HPOTuner(args, time_budget=60*15, metric="loss")

tuner.tune(n_samples=500)

# keep the columns of interest for finished trials only, best (lowest loss) first
df = tuner.results
df = df[["mean_loss", "mean_acc", "config.lrate_sgd","config.lrate_hebb", "config.ctx_scaling","config.seed","done"]]
df = df[df["done"]==True]
df = df.drop(columns=["done"])
df = df.dropna()
df = df.sort_values("mean_loss",ascending=True)

# reset_index() returns a new frame; the result has to be assigned, otherwise
# the call has no effect. The previous index is kept as a column (drop=False)
# so that the identity of each trial is not lost.
df = df.reset_index()
print(df.head(15))

print(tuner.best_cfg)

# persist the ranked table; row 0 is the best configuration
with open("../results/raytune_oja_ctx_blocked_8episodes.pkl", "wb") as f:
    pickle.dump(df, f)

In [None]:
# Reload the ranked HPO table for the blocked schedule and display its first row.
results_file = "../results/raytune_oja_ctx_blocked_8episodes.pkl"
with open(results_file, "rb") as fh:
    df = pickle.load(fh)

df.iloc[0]

In [None]:
# Verify results: retrain a single network with the configuration stored in the
# first row of the blocked-schedule HPO table and report terminal accuracy/loss.
results_file = "../results/raytune_oja_ctx_blocked_8episodes.pkl"
with open(results_file, "rb") as fh:
    df = pickle.load(fh)

# first row of the stored table and the seed recorded for it
best = df.iloc[0]
seed = int(best["config.seed"])

# default parameters
args = parser.parse_args(args=[])

# checkpoint directory
save_dir = Path("checkpoints") / "test_allhebb"

# device (gpu/cpu)
args.device = get_device(args.cuda)[0]

# replace the defaults with the tuned values
args.n_episodes = 8
args.lrate_hebb = best["config.lrate_hebb"]
args.lrate_sgd = best["config.lrate_sgd"]
args.ctx_scaling = best["config.ctx_scaling"]
args.ctx_avg = False

# seed numpy, python's random and torch with the trial's seed
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(seed)

# dataset
dataset = make_blobs_dataset(args)

# logger, model and optimiser
logger = LoggerFactory.create(args, save_dir)
model = Nnet(args)
optimiser = Optimiser(args)

# move the model to the selected device
model = model.to(args.device)

# training
train_on_blobs(args, model, optimiser, dataset, logger)

# summary: tuned hyperparameters and the last logged accuracy / loss
print(f"config: lrate_sgd: {args.lrate_sgd:.4f}, lrate_hebb: {args.lrate_hebb:.4f}, context offset: {args.ctx_scaling}")
print(f"terminal accuracy: {logger.results['acc_total'][-1]:.2f}, loss: {logger.results['losses_total'][-1]:.2f}")

### HPO: Interleaved trials

In [None]:
# HPO on interleaved trials with oja_ctx (short training run: 8 episodes)
args = parser.parse_args(args=[])
args.n_episodes = 8
args.hpo_fixedseed = True
args.hpo_scheduler = "bohb"
args.hpo_searcher = "bohb"
args.training_schedule = "interleaved"
# dict(sorted(vars(args).items(),key=lambda k: k[0]))
args.ctx_avg = False

# init tuner
# NOTE(review): time_budget is presumably in seconds (60*15 = 15 min) -- confirm against HPOTuner
tuner = HPOTuner(args, time_budget=60*15, metric="loss")

tuner.tune(n_samples=500)

# keep the columns of interest for finished trials only, best (lowest loss) first
df = tuner.results
df = df[["mean_loss", "mean_acc", "config.lrate_sgd","config.lrate_hebb", "config.ctx_scaling","config.seed","done"]]
df = df[df["done"]==True]
df = df.drop(columns=["done"])
df = df.dropna()
df = df.sort_values("mean_loss",ascending=True)

# reset_index() returns a new frame; the result has to be assigned, otherwise
# the call has no effect. The previous index is kept as a column (drop=False)
# so that the identity of each trial is not lost.
df = df.reset_index()
print(df.head(15))

print(tuner.best_cfg)

# persist the ranked table; row 0 is the best configuration
with open("../results/raytune_oja_ctx_interleaved_8episodes.pkl", "wb") as f:
    pickle.dump(df, f)

In [None]:
# verify results: retrain a single network with the best interleaved HPO
# configuration and report its terminal accuracy and loss.

# Load the interleaved results explicitly instead of relying on whichever `df`
# happens to be in kernel memory: the blocked-schedule cells also assign `df`,
# so running them first would silently verify the wrong configuration here.
with open("../results/raytune_oja_ctx_interleaved_8episodes.pkl", "rb") as f:
    df = pickle.load(f)

# obtain params
args = parser.parse_args(args=[])

# set checkpoint directory
save_dir = Path("checkpoints") / "test_allhebb"

# get device (gpu/cpu)
args.device = get_device(args.cuda)[0]

# override defaults with the values of the first (lowest mean_loss) row
best = df.iloc[0]
args.n_episodes = 8
args.lrate_hebb = best["config.lrate_hebb"]
args.lrate_sgd = best["config.lrate_sgd"]
args.ctx_scaling = best["config.ctx_scaling"]
args.ctx_avg = False
args.training_schedule = "interleaved"

# seed numpy, python's random and torch with the trial's seed
seed = int(best["config.seed"])
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

# create dataset
dataset = make_blobs_dataset(args)

# instantiate logger, model and optimiser:
logger = LoggerFactory.create(args, save_dir)
model = Nnet(args)
optimiser = Optimiser(args)

# send model to device (GPU?)
model = model.to(args.device)

# train model
train_on_blobs(args, model, optimiser, dataset, logger)

# summary: tuned hyperparameters and the last logged accuracy / loss
print(f"config: lrate_sgd: {args.lrate_sgd:.4f}, lrate_hebb: {args.lrate_hebb:.4f}, context offset: {args.ctx_scaling}")
print(f"terminal accuracy: {logger.results['acc_total'][-1]:.2f}, loss: {logger.results['losses_total'][-1]:.2f}")

### HPO: all_oja

In [33]:
# Load the stored 200-episode tuner output, which keeps its table under the
# "df" key, and show the 15 trials with the lowest mean loss.
results_file = "../results/raytune_blobs_asha_200episodes_blocked_vanilla_1ctx.pkl"
with open(results_file, "rb") as fh:
    stored = pickle.load(fh)
df = stored["df"].sort_values("mean_loss").head(15)
df

Unnamed: 0_level_0,mean_loss,mean_acc,done,config.lrate_sgd,config.ctx_scaling,config.seed
trial_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
a325e_02551,-28.044544,0.95,True,0.076368,7,2894
a325e_01146,-28.042593,0.975,True,0.08474,7,3005
a325e_01281,-27.508484,0.95,True,0.091234,7,8542
a325e_02688,-27.327234,0.95,True,0.091137,6,490
a325e_00146,-27.02043,0.95,True,0.078242,6,7215
a325e_01302,-26.922705,0.95,True,0.09816,6,6678
a325e_00715,-26.824409,0.95,True,0.041682,7,7111
a325e_01985,-26.337194,0.9,True,0.05085,7,9028
a325e_01519,-26.333696,0.95,True,0.069805,7,7054
a325e_00126,-26.329458,0.95,True,0.063359,6,393


In [116]:
# Verify results: train a two-hidden-layer network on the trees dataset using
# the hyperparameters in the first row of the `df` currently in memory.
# NOTE(review): this cell reads "config.lrate_hebb", but the table displayed for
# the preceding cell has no such column -- confirm which results file `df` is
# expected to come from before re-running.
best = df.iloc[0]
seed = int(best["config.seed"])

# default parameters
args = parser.parse_args(args=[])

# checkpoint directory
save_dir = Path("checkpoints") / "test_allhebb"

# device (gpu/cpu)
args.device = get_device(args.cuda)[0]

# replace the defaults
args.n_episodes = 200
args.n_layers = 2
args.n_features = 974
args.lrate_hebb = best["config.lrate_hebb"]
args.lrate_sgd = best["config.lrate_sgd"]
args.ctx_scaling = best["config.ctx_scaling"]
args.ctx_avg = False
args.training_schedule = "interleaved"

# seed numpy, python's random and torch with the trial's seed
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(seed)

# dataset
dataset = make_trees_dataset(args)

# logger, model and optimiser
logger = LoggerFactory.create(args, save_dir)
model = ScaledNet2Hidden(args)
optimiser = Optimiser(args)

# move the model to the selected device
model = model.to(args.device)

# training
train_on_trees(args, model, optimiser, dataset, logger)

# summary: hyperparameters used and the last logged accuracy / loss
print(f"config: lrate_sgd: {args.lrate_sgd:.4f}, lrate_hebb: {args.lrate_hebb:.4f}, context offset: {args.ctx_scaling}")
print(f"terminal accuracy: {logger.results['acc_total'][-1]:.2f}, loss: {logger.results['losses_total'][-1]:.2f}")

5000
step 0, loss: task a -0.4216, task b 0.3495 | acc: task a 0.5000, task b 0.5000
...1st hidden: n_a: 5 n_b: 4
... 2nd hidden: n_a: 7 n_b: 7
step 50, loss: task a 0.1933, task b -1.9896 | acc: task a 0.5000, task b 0.5000
...1st hidden: n_a: 8 n_b: 6
... 2nd hidden: n_a: 1 n_b: 9
step 100, loss: task a -1.1461, task b -12.2039 | acc: task a 0.5000, task b 0.5000
...1st hidden: n_a: 6 n_b: 7
... 2nd hidden: n_a: 2 n_b: 3
step 150, loss: task a -7.9329, task b -36.3379 | acc: task a 0.5000, task b 0.5000
...1st hidden: n_a: 3 n_b: 2
... 2nd hidden: n_a: 1 n_b: 2
step 200, loss: task a -8.0407, task b -44.7073 | acc: task a 0.5000, task b 0.5000
...1st hidden: n_a: 3 n_b: 2
... 2nd hidden: n_a: 1 n_b: 1
step 250, loss: task a -34.5834, task b -113.7178 | acc: task a 0.5000, task b 0.5000
...1st hidden: n_a: 1 n_b: 3
... 2nd hidden: n_a: 1 n_b: 0
step 300, loss: task a -122.8599, task b -120.6186 | acc: task a 0.5000, task b 0.5000
...1st hidden: n_a: 1 n_b: 3
... 2nd hidden: n_a: 1 n_b

In [3]:
from hebbcl.tuner import validate_tuner_results

# Call validate_tuner_results for trials 1-4 of every (config, episode count)
# combination; the results file name is built from both.
n_episodes = [8, 200]
configs = ["blocked_ojaall_1ctx"]

for cfg in configs:
    for ep in n_episodes:
        results_name = f"blobs_asha_{ep}episodes_{cfg}"
        for trial in range(1, 5):
            validate_tuner_results(
                filename=results_name,
                filepath="../results/",
                datapath="../datasets/",
                whichtrial=trial,
                njobs=6,
            )

[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done   1 tasks      | elapsed:    8.6s
[Parallel(n_jobs=6)]: Done   6 tasks      | elapsed:    9.3s
[Parallel(n_jobs=6)]: Done  13 tasks      | elapsed:   13.1s
[Parallel(n_jobs=6)]: Done  20 tasks      | elapsed:   15.7s
[Parallel(n_jobs=6)]: Done  29 tasks      | elapsed:   18.5s
[Parallel(n_jobs=6)]: Done  38 tasks      | elapsed:   22.3s
[Parallel(n_jobs=6)]: Done  45 out of  50 | elapsed:   24.6s remaining:    2.6s
[Parallel(n_jobs=6)]: Done  50 out of  50 | elapsed:   26.1s finished
[Parallel(n_jobs=6)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=6)]: Done   1 tasks      | elapsed:    1.9s
[Parallel(n_jobs=6)]: Done   6 tasks      | elapsed:    2.1s
[Parallel(n_jobs=6)]: Done  13 tasks      | elapsed:    6.3s
[Parallel(n_jobs=6)]: Done  20 tasks      | elapsed:    8.7s
[Parallel(n_jobs=6)]: Done  29 tasks      | elapsed:   11.0s
[Parallel(n_jobs=6)]: Done  3

In [71]:
import matplotlib.pyplot as plt
import seaborn as sns
# %matplotlib qt

# Summary figures for the validated HPO runs: for each model (one subplot row
# each) plot accuracy, unit allocation, context-weight correlation and the
# average choice matrices, aggregated over n_runs checkpointed runs.
n_runs = 50
n_episodes = 8
models = ['blobs_asha_8episodes_interleaved_ojaall_1ctx','blobs_asha_8episodes_blocked_ojaall_1ctx_1']


def _plot_mean_sem(ax, values, line_color, band_color=None, label=None):
    """Plot the across-run mean of ``values`` on ``ax`` with a +/- SEM band.

    ``values`` holds one row per run and one column per episode. The SEM is
    ``np.std`` over runs (ddof=0) divided by ``sqrt(number of runs)``.
    ``band_color`` defaults to ``line_color``; ``label`` is attached to the
    mean line only, so the band never shows up in a legend.
    """
    x = np.arange(values.shape[1])
    mean = values.mean(0)
    sem = np.std(values, 0) / np.sqrt(values.shape[0])
    ax.plot(x, mean, color=line_color, label=label)
    ax.fill_between(
        x,
        mean - sem,
        mean + sem,
        alpha=0.5,
        color=line_color if band_color is None else band_color,
        edgecolor=None,
    )


# acc
f1, axs1 = plt.subplots(2,1,figsize=(2.7,3),dpi=300)
# unit alloc
f2, axs2 = plt.subplots(2,1,figsize=(2.7,3),dpi=300)
# context corr
f3, axs3 = plt.subplots(2,1,figsize=(2.7,3),dpi=300)
# choice matrices
f4, axs4 = plt.subplots(2,2,figsize=(5,5),dpi=300)
# # hidden layer MDS, interleaved
# f5a, axs5a = plt.subplots(1,2,figsize=(10,3),dpi=300)
# # hidden layer MDS, blocked
# f5a, axs5a = plt.subplots(1,2,figsize=(10,3),dpi=300)


for i,m in enumerate(models):
    # The schedule is read from the model name. The previous title expression,
    # m.split('_')[1], evaluates to 'asha' for both names in `models`, so both
    # panels ended up with the same, uninformative title.
    is_blocked = 'interleaved' not in m
    schedule = 'blocked' if is_blocked else 'interleaved'

    t_a = np.empty((n_runs,n_episodes))
    t_b = np.empty((n_runs,n_episodes))
    t_d = np.empty((n_runs,n_episodes))
    t_mixed = np.empty((n_runs,n_episodes))
    acc_1st = np.empty((n_runs,n_episodes))
    acc_2nd = np.empty((n_runs,n_episodes))
    contextcorr = np.empty((n_runs,n_episodes))
    cmats_a = []
    cmats_b = []

    for r in range(n_runs):
        with open(f'../checkpoints/{m}/run_{r}/results.pkl', 'rb') as f:
            results = pickle.load(f)

        # accuracy:
        acc_1st[r,:] = results['acc_1st']
        acc_2nd[r,:] = results['acc_2nd']
        # task factorisation, as fractions:
        # NOTE(review): the divisor 100 is presumably the number of hidden units -- confirm.
        # NOTE(review): t_a is filled from 'n_only_b_regr' and t_b from
        # 'n_only_a_regr' -- confirm that this crossing is intended.
        t_a[r,:] = results['n_only_b_regr']/100
        t_b[r,:] = results['n_only_a_regr']/100
        t_d[r,:] = results['n_dead']/100
        t_mixed[r,:] = 1-t_a[r,:]-t_b[r,:]-t_d[r,:]
        # context correlation:
        contextcorr[r,:] = results['w_context_corr']
        # sigmoid of row 1 of the stored outputs; clipping to +/-709.78 keeps
        # np.exp finite in float64
        cc = np.clip(results['all_y_out'][1,:], -709.78, 709.78).astype(np.float64)
        choices = 1/(1+np.exp(-cc))
        # first 25 entries -> 5x5 matrix for the 1st task, remainder for the 2nd
        cmats_a.append(choices[:25].reshape(5,5))
        cmats_b.append(choices[25:].reshape(5,5))

    cmats_a = np.array(cmats_a)
    cmats_b = np.array(cmats_b)

    # accuracy
    # Labels are attached to the lines themselves: legend([...]) assigns labels
    # by artist order, which can hand '2nd task' to the first SEM band.
    _plot_mean_sem(axs1[i], acc_1st, 'orange', label='1st task')
    _plot_mean_sem(axs1[i], acc_2nd, 'blue', label='2nd task')
    axs1[i].set_ylim([0.4,1.05])
    axs1[i].set(xlabel='trial', ylabel='accuracy')
    axs1[i].legend(frameon=False)
    if is_blocked:
        # dashed marker at the midpoint of training
        axs1[i].plot([n_episodes/2, n_episodes/2],[0,1],'k--',alpha=0.5)
    axs1[i].set_title(schedule)

    # unit allocation (task factorisation)
    _plot_mean_sem(axs2[i], t_b, 'orange', label='1st task')
    _plot_mean_sem(axs2[i], t_a, 'blue', label='2nd task')
    axs2[i].set_yticks([0,0.5,1])
    ticks = axs2[i].get_yticks()
    # show fractions as percentages; pass a concrete list of strings rather
    # than a generator
    axs2[i].set_yticklabels([str(int(x)) for x in ticks*100])
    axs2[i].set(xlabel='trial',ylabel='task-sel (%)')
    axs2[i].legend(frameon=False)
    if is_blocked:
        axs2[i].plot([n_episodes/2, n_episodes/2],[0,1],'k--',alpha=0.5)
    axs2[i].set_title(schedule)
    axs2[i].set_ylim([0,1.05])

    # context corr
    _plot_mean_sem(axs3[i], contextcorr, 'k', band_color='magenta')
    axs3[i].set_ylim([-1.1,1.05])
    axs3[i].set(xlabel='trial',ylabel=r'$w_{context}$ corr ')
    if is_blocked:
        axs3[i].plot([n_episodes/2, n_episodes/2],[-1,1],'k--',alpha=0.5)
    axs3[i].set_title(schedule)

    # choice matrices (mean over runs)
    axs4[i,0].imshow(cmats_a.mean(0))
    axs4[i,0].set_title('1st task')
    axs4[i,0].set(xticks=[0,2,4],yticks=[0,2,4],xlabel='irrel',ylabel='rel')
    axs4[i,1].imshow(cmats_b.mean(0))
    axs4[i,1].set(xticks=[0,2,4],yticks=[0,2,4],xlabel='rel',ylabel='irrel')
    axs4[i,1].set_title('2nd task')
    # PCM=axs4[i,1].get_children()[-2] #get the mappable, the 1st and the 2nd are the x and y axes

    # plt.subplots_adjust(bottom=0.1, right=0.8, top=0.9)
    # cax = plt.axes([0.85, 0.1, 0.075, 0.8])
    # plt.colorbar(PCM,cax=cax)


    # hidden layer MDS

# despine and lay out each figure once, after all panels are drawn
for fig in (f1, f2, f3):
    sns.despine(fig)
f1.tight_layout()
f2.tight_layout()
f3.tight_layout()
f4.tight_layout()



In [63]:
# last column of acc_2nd, i.e. one value per row
# NOTE(review): acc_2nd is whatever the most recently executed cell left in memory -- confirm which model it belongs to
acc_2nd[:,-1]

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])