In [1]:
import numpy as np
from bokeh.io import export_svg, output_notebook
from bokeh.layouts import gridplot
from bokeh.palettes import Category10, Category20, Viridis
from bokeh.plotting import figure, output_file, save, show
from IPython.display import HTML, display
from tqdm import tqdm

# Stretch the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# Render bokeh `show(...)` output inline in the notebook.
output_notebook()

In [2]:
from neuromod.tasks import TaskSwitch, AffineCorrelatedGaussian
from neuromod.trainers import two_layer_training
from neuromod.networks import LinearNet
from neuromod.control import TaskSwitchLinearNetEq, TaskSwitchLinearNetControl
from neuromod.utils import plot_lines

In [3]:
# --- Run bookkeeping ---
run_name = "task_switch_test"
results_path = "../results"

# --- Training / control schedule ---
n_steps = 21000          # gradient steps per training run (passed to two_layer_training)
save_weights_every = 20  # weight-saving interval (passed to two_layer_training)
iter_control = 50        # number of control.train_step iterations

# Collects every array produced by this notebook.
results_dict = {}

# --- Dataset: two AffineCorrelatedGaussian tasks alternated by TaskSwitch ---
batch_size = 2048
dataset1_params = dict(mu_vec=(3.0, 1.0), sigma_vec=(1.0, 1.0),
                       dependence_parameter=0.8, batch_size=batch_size)
dataset2_params = dict(mu_vec=(-2.0, 2.0), sigma_vec=(1.0, 1.0),
                       dependence_parameter=0.2, batch_size=batch_size)
dataset_params = dict(dataset1_params=dataset1_params,
                      dataset2_params=dataset2_params,
                      change_tasks_every=1500)  # presumably training steps per task — see TaskSwitch

# --- Network hyper-parameters (W1_0/W2_0 = None lets LinearNet pick its own init) ---
# NOTE(review): the None-means-default-init reading is an assumption about LinearNet — confirm.
model_params = dict(learning_rate=5e-3,
                    hidden_dim=6,
                    intrinsic_noise=0.00,
                    reg_coef=0.01,
                    W1_0=None,
                    W2_0=None)

# --- Control hyper-parameters (keyword names are those TaskSwitchLinearNetControl expects) ---
control_params = dict(control_lower_bound=-0.5,
                      control_upper_bound=0.5,
                      gamma=0.99,
                      cost_coef=0.3,
                      reward_convertion=1.0,
                      init_g=None,
                      control_lr=1.0)

In [4]:
# Two AffineCorrelatedGaussian tasks that TaskSwitch alternates between
# every `change_tasks_every` (presumably training steps — see TaskSwitch).
task_classes = (AffineCorrelatedGaussian, AffineCorrelatedGaussian)
task_params = (dataset1_params, dataset2_params)
dataset = TaskSwitch(
    dataset_classes=task_classes,
    dataset_list_params=task_params,
    change_tasks_every=dataset_params["change_tasks_every"],
)

# The network's input/output sizes come from the dataset.
model_params.update(input_dim=dataset.input_dim, output_dim=dataset.output_dim)

In [5]:
# Baseline network from the configured hyper-parameters (W1_0/W2_0 are None here).
model = LinearNet(
    **model_params
)

In [6]:
# Baseline run: train the network with no control signal.
baseline_run = two_layer_training(model=model,
                                  dataset=dataset,
                                  n_steps=n_steps,
                                  save_weights_every=save_weights_every)
iters, loss, weights_iter, weights = baseline_run

results_dict.update(
    iters=iters,
    Loss_t_sim=loss,
    weights_sim=weights,
    weights_iters_sim=weights_iter,
)

100%|██████████| 21000/21000 [00:21<00:00, 972.30it/s]


In [7]:
# Solve the learning-dynamics equation from the first saved weights of the
# baseline simulation, then build an (un-optimised) controller on the same setup.
init_weights = [weights[0][0, ...], weights[1][0, ...]]
init_W1, init_W2 = init_weights

(input_corr, output_corr, input_output_corr,
 expected_y, expected_x) = dataset.get_correlation_matrix()

# One equation time point per recorded iteration, scaled by the learning rate.
# NOTE(review): assumes `iters` are consecutive steps starting at 0 — confirm.
time_span = model_params["learning_rate"] * np.arange(len(iters))
results_dict["time_span"] = time_span

equation_params = dict(
    in_cov=input_corr,
    out_cov=output_corr,
    in_out_cov=input_output_corr,
    init_weights=init_weights,
    n_steps=n_steps,
    reg_coef=model_params["reg_coef"],
    intrinsic_noise=model_params["intrinsic_noise"],
    learning_rate=model_params["learning_rate"],
    change_task_every=dataset_params["change_tasks_every"],  # solver's kwarg is singular "task"
    time_constant=1.0,
)

solver = TaskSwitchLinearNetEq(**equation_params)

# The controller takes its own hyper-parameters plus the full equation setup.
control_params = {**control_params, **equation_params}
control = TaskSwitchLinearNetControl(**control_params)

# Uncontrolled equation solution.
W1_t, W2_t = solver.get_weights(time_span, get_numpy=True)
Loss_t = solver.get_loss_function(W1_t, W2_t, get_numpy=True)
results_dict.update(W1_t_eq=W1_t, W2_t_eq=W2_t, Loss_t_eq=Loss_t)

# Controller trajectories under its initial control signal.
W1_t_control, W2_t_control = control.get_weights(time_span, get_numpy=True)
Loss_t_control = control.get_loss_function(W1_t_control, W2_t_control, get_numpy=True)

In [8]:
# Sanity check: the controller must have picked up the configured control learning
# rate. This replaces eyeballing the value; it is still displayed below.
assert control.control_lr == control_params["control_lr"], "control_lr was not forwarded to the controller"
control.control_lr

1.0

In [9]:
# Simulation vs. equation, plus the controller under its initial control signal.
palette = Category10[10]
curve_specs = (
    # (loss curve,      legend,         alpha, color)
    (loss,           "Simulation",   0.3,   palette[0]),
    (Loss_t,         "Equation",     1,     palette[0]),
    (Loss_t_control, "Init Control", 1,     palette[1]),
)
losses, legends, alphas, colors = zip(*curve_specs)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [10]:
# Keep the un-optimised controller's trajectories and control signal for comparison.
results_dict["W1_t_control_init"] = W1_t_control
results_dict["W2_t_control_init"] = W2_t_control
results_dict["Loss_t_control_init"] = Loss_t_control
# NOTE(review): this stores references to control.g1_tilda / control.g2_tilda. If
# train_step updates them in place, these "init" entries will hold the optimised
# values after the loop below — confirm, and store copies if so.
results_dict["control_signal_init"] = (control.g1_tilda, control.g2_tilda)

# Record the number of control steps with the results rather than adding it to
# `control_params`. That dict was already consumed when `control` was built, so a new
# key has no effect on `control`, and re-running the construction cell would forward
# it as an unexpected keyword argument.
results_dict["iters_control"] = iter_control

# Gradient steps on the control signal; train_step's return value is collected as
# the cumulated reward.
rewards = []
for _ in range(iter_control):
    R = control.train_step(get_numpy=True)
    print(R)
    rewards.append(R)
cumulated_reward = np.array(rewards).astype(float)
results_dict["cumulated_reward_opt"] = cumulated_reward

# Trajectories under the optimised control signal.
W1_t_opt, W2_t_opt = control.get_weights(time_span, get_numpy=True)
Loss_t_opt = control.get_loss_function(W1_t_opt, W2_t_opt, get_numpy=True)

results_dict["W1_t_control_opt"] = W1_t_opt
results_dict["W2_t_control_opt"] = W2_t_opt
results_dict["Loss_t_control_opt"] = Loss_t_opt

-22.79712
-20.530104
-19.865623
-19.63999
-19.496126
-19.398994
-19.286839
-19.128704
-19.058655
-19.011015
-18.969429
-18.931427
-18.89523
-18.859995
-18.826458
-18.794996
-18.765272
-18.73426
-18.696527
-18.644112
-18.58095
-18.521383
-18.47025
-18.426533
-18.38807
-18.353477
-18.32208
-18.293457
-18.266666
-18.241844
-18.218454
-18.19647
-18.175497
-18.155762
-18.136946
-18.118979
-18.101658
-18.085268
-18.069223
-18.054026
-18.039392
-18.025326
-18.011475
-17.998333
-17.985533
-17.973045
-17.961136
-17.949425
-17.937895
-17.926891


In [11]:
# Reward after each control gradient step. The x-axis is derived from the reward
# history itself, so it always matches its length even if `iter_control` is edited
# after the optimisation cell has run.
control_steps = np.arange(len(cumulated_reward))
opt = plot_lines(control_steps, (cumulated_reward,),
                 x_axis_label="gradient steps on control", y_axis_label="Cumulated reward")
show(opt)

In [12]:
# Equation-level effect of the optimised control, overlaid on the uncontrolled curves.
# Labels use the same "Simulation"/"Equation" naming as the initial-control plot.
# The simulated model is a LinearNet, so it is not labelled "Non-linear".
palette = Category10[10]
losses = (loss, Loss_t, Loss_t_opt)
colors = (palette[0], palette[0], palette[1])
legends = ("Simulation", "Equation", "Optimized Control (equation)")
alphas = (0.3, 1, 1)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [13]:
# Optimised control signal, to be applied during the controlled re-training.
# NOTE(review): these are the controller's own attribute objects, not copies — if
# they are later modified in place, `control_signal` changes with them.
control_signal = (control.g1_tilda, control.g2_tilda)
g1_tilda, g2_tilda = control_signal
# The same initial weights the equation and controller were set up with.
W1_0, W2_0 = init_weights
results_dict["control_signal"] = control_signal

In [15]:
# Fresh LinearNet initialised at the first saved weights of the baseline run, so the
# controlled run starts where the equation/controller started.
reset_model_params = {**model_params, "W1_0": W1_0, "W2_0": W2_0}
reset_model = LinearNet(**reset_model_params)

In [16]:
# Re-train the reset network, this time applying the optimised control signal.
# The iteration record gets its own name (`iters_opt`) so that the baseline `iters`,
# which `time_span` is built from, is not silently overwritten.
# NOTE(review): `dataset` is the same TaskSwitch object used for the baseline run. If it
# keeps internal state (e.g. a step counter for task switching), this run may not start
# on the same task phase as the baseline — confirm.
iters_opt, loss_OPT, weights_iter_OPT, weights_OPT = two_layer_training(
    model=reset_model,
    dataset=dataset,
    n_steps=n_steps,
    save_weights_every=save_weights_every,
    control_signal=control_signal,
)
results_dict["Loss_t_sim_OPT"] = loss_OPT
results_dict["weights_sim_OPT"] = weights_OPT
results_dict["weights_iters_sim_OPT"] = weights_iter_OPT
results_dict["iters_OPT"] = iters_opt

100%|██████████| 21000/21000 [00:28<00:00, 745.08it/s]


In [17]:
# Baseline vs. controlled learning, for both the equation and the simulation.
# Labels use the same "Simulation"/"Equation" naming as the initial-control plot.
# The simulated model is a LinearNet, so neither run is labelled "Non-linear".
palette = Category10[10]
losses = (loss, Loss_t, Loss_t_opt, loss_OPT)
colors = (palette[0], palette[0], palette[1], palette[1])
legends = ("Simulation", "Equation", "Optimized Control (equation)", "Optimized Control (simulation)")
alphas = (0.3, 1, 1, 0.3)

# Assumes both training runs record the same iterations (same n_steps), so one
# x-axis fits every curve.
s = plot_lines(iters, losses, legends, alphas, colors)
show(s)