# Scratchpad for paper revisions

In [None]:
%load_ext autoreload
%autoreload 2
import pickle
import os, sys
# Make the parent of the notebook's working directory importable, so that the
# `utils` and `hebbcl` imports below resolve to the local packages.
root_path = os.path.realpath('../')
# Prepend rather than append so these local packages take precedence over any
# installed package of the same name (e.g. a generic `utils`), and skip the
# insert when the cell is re-run so sys.path does not accumulate duplicates.
if root_path not in sys.path:
    sys.path.insert(0, root_path)

import torch
from pathlib import Path

from utils.data import make_dataset
from utils.nnet import get_device

from hebbcl.logger import MetricLogger
from hebbcl.model import Nnet
from hebbcl.trainer import Optimiser, train_model
from hebbcl.parameters import parser
from hebbcl.tuner import HPOTuner

In [None]:
# Load the default hyperparameters (empty argv: nothing to parse inside a notebook).
args = parser.parse_args(args=[])

# Checkpoint directory for this experiment.
save_dir = Path("checkpoints", "test_allhebb")

# Select the compute device from args.cuda; the first element returned is used.
args.device = get_device(args.cuda)[0]

# Display every parameter, ordered alphabetically by name.
dict(sorted(vars(args).items()))


## Hyperparameter optimisation
Hyperparameter optimisation (HPO) on a network trained with fewer episodes (`n_episodes = 8`).

In [None]:
# Fresh default params for the HPO runs: fewer (8) episodes, deterministic=True.
# NOTE(review): unlike the previous cell, args.device is not set on these args
# before they are handed to the tuner — confirm HPOTuner selects the device itself.
args = parser.parse_args(args=[])
vars(args).update(n_episodes=8, deterministic=True)
dict(sorted(vars(args).items()))

In [None]:
# init tuner over the HPO args above, with metric="acc"
# NOTE(review): time_budget=60 — the unit is not visible here (presumably
# seconds); confirm in HPOTuner
tuner = HPOTuner(args, metric="acc", time_budget=60)

In [None]:
# HPO on blocked with oja (all units)
# Run the search. n_samples=500 is passed straight to the tuner (presumably the
# number of sampled configurations — confirm in HPOTuner.tune).
tuner.tune(n_samples=500)

In [None]:
# Peek at the first 15 rows of tuner.results (in the tuner's order): the sampled
# learning rates, context scaling and seed, plus mean accuracy/loss and the `done` flag.
tuner.results[["config.lrate_sgd", "config.lrate_hebb","config.ctx_scaling","config.seed","mean_acc","mean_loss","done"]].head(15)

In [None]:
# Summarise the finished HPO trials, best mean accuracy first.
df = tuner.results
df = df[
    [
        "mean_loss",
        "mean_acc",
        "config.lrate_sgd",
        "config.lrate_hebb",
        "config.ctx_scaling",
        "done",
    ]
]
# keep only trials whose `done` flag is True (missing flags count as not done)
df = df[df["done"].eq(True)]
df = df.drop(columns=["done"])
# discard trials with any missing metric or config value
df = df.dropna()
df = df.sort_values("mean_acc", ascending=False)

# reset_index returns a NEW frame, so its result must be assigned (the original
# bare call was a no-op). The old index is kept as an "index" column so each
# row can still be traced back to tuner.results.
df = df.reset_index()
df.head(15)

In [None]:
# Display the best configuration found by the tuner (its lrate_hebb, lrate_sgd,
# ctx_scaling and seed entries are read by the verification cell).
tuner.best_cfg

In [None]:
# with open("../results/raytune_oja_ctx_blocked_8episodes.pkl", "wb") as f:
#     pickle.dump(df, f)

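### Verify results
Re-train once with the best configuration found by the tuner and report the final accuracy and loss.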

In [None]:
import numpy as np
import random
import torch

# Re-train a single network with the best configuration found by the tuner, to
# check that the tuned result reproduces outside the tuning loop.

# obtain default params
args = parser.parse_args(args=[])

# set checkpoint directory
save_dir = Path("checkpoints") / "test_allhebb"

# get device (gpu/cpu)
args.device = get_device(args.cuda)[0]

# Mirror the settings the tuner was run with: the HPO args used n_episodes=8
# AND deterministic=True. Both must be set here; otherwise the verification run
# trains under a different configuration than the one that was tuned.
args.n_episodes = 8
args.deterministic = True

# override defaults with the best hyperparameters found by the tuner
best_cfg = tuner.best_cfg
args.lrate_hebb = best_cfg["lrate_hebb"]
args.lrate_sgd = best_cfg["lrate_sgd"]
args.ctx_scaling = best_cfg["ctx_scaling"]

# seed the numpy, python and torch RNGs with the best trial's seed
# NOTE(review): if the tuner also copies the seed onto args (e.g. args.seed),
# set it here as well — not visible from this notebook; confirm in HPOTuner
seed = best_cfg["seed"]
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

# create dataset (after seeding, in case make_dataset draws random numbers)
dataset = make_dataset(args)

# instantiate logger, model and optimiser:
logger = MetricLogger(save_dir)
model = Nnet(args)
optimiser = Optimiser(args)

# send model to device (GPU?)
model = model.to(args.device)

# train model
train_model(args, model, optimiser, dataset, logger)

In [None]:
# Report the re-trained configuration, then the final accuracy and loss from the logger.
print(
    f"config: lrate_sgd: {args.lrate_sgd:.4f}, "
    f"lrate_hebb: {args.lrate_hebb:.4f}, "
    f"context offset: {args.ctx_scaling}"
)
print(
    f"terminal accuracy: {logger.acc_total[-1]:.2f}, "
    f"loss: {logger.losses_total[-1]:.2f}"
)

In [None]:
# NOTE(review): rebinds df to the raw tuner.results, replacing the filtered and
# sorted summary built earlier in the notebook.
df = tuner.results
