In [1]:
from IPython.display import display, HTML
# display(HTML("<style>.container { width:100% !important; }</style>"))

import numpy as np
from bokeh.plotting import figure, show, output_file, save
from bokeh.layouts import gridplot
from bokeh.io import output_notebook
from bokeh.palettes import Viridis, Category10, Category20
from bokeh.io import export_svg
from tqdm import tqdm
# Configure Bokeh so that subsequent show() calls render inline in the notebook.
output_notebook()

In [2]:
from metamod.tasks import TaskModulation, MNIST
from metamod.trainers import task_mod_training
from metamod.networks import LinearNet
from metamod.control import LinearNetTaskModEq, LinearNetTaskModControl
from metamod.utils import plot_lines

In [3]:
# Run identifier and results location (neither is read by the cells visible
# here), plus the dictionary that later cells fill with their outputs.
results_dict = dict()
results_path = "../results"
run_name = "task_mod"

In [4]:
# One configuration dict per MNIST task; the three differ only in "subset".
# NOTE(review): presumably "subset" is the pair of digit classes kept and
# "new_shape" the resized image shape -- confirm against metamod.tasks.MNIST.
dataset_params1, dataset_params2, dataset_params3 = (
    {"batch_size": 256, "new_shape": (5, 5), "subset": digit_pair}
    for digit_pair in ((0, 1), (7, 1), (8, 9))
)

# Keyword arguments for the dataset wrapper: one MNIST class per task,
# paired position-by-position with its configuration dict.
dataset_class = TaskModulation
dataset_params = dict(
    dataset_classes=(MNIST,) * 3,
    dataset_list_params=(dataset_params1, dataset_params2, dataset_params3),
)

# Keyword arguments for LinearNet; "input_dim" and "output_dim" are added
# once the dataset has been instantiated.
# NOTE(review): presumably W1_0/W2_0 = None lets the network pick its own
# initial weights -- confirm in metamod.networks.LinearNet.
model_params = dict(
    learning_rate=1e-3,
    hidden_dim=40,
    intrinsic_noise=0.0,
    reg_coef=0.0,
    W1_0=None,
    W2_0=None,
)

# Length of the simulated training run and the weight-snapshot period.
n_steps = 5000
save_weights_every = 20

# NOTE(review): `control_lr` is not read anywhere in the visible cells (the
# control parameters hard-code "control_lr": 0.02), and `iter_control` is
# reassigned to 3 before the control loop runs -- confirm which values are
# intended.
control_lr = 1.0
iter_control = 10

In [5]:
# Instantiate the multi-task dataset, then size the network's input and
# output layers from the dimensions it reports.
dataset = dataset_class(**dataset_params)
model_params.update(
    input_dim=dataset.input_dim,
    output_dim=dataset.output_dim,
)



In [6]:
# Baseline engagement schedule: every coefficient is 1 at every step.
# Rows index training steps, columns index the datasets -- (t, phis).
engage_coefficients = np.ones(shape=(n_steps, len(dataset.datasets)))
# Alternative linear ramps, kept disabled:
#engage_coefficients[:, 1] = np.linspace(0, 1, num=engage_coefficients.shape[0])*0.5
#engage_coefficients[:, 0] = np.linspace(1, 0, num=engage_coefficients.shape[0])*0.5

# Network trained in the baseline simulation.
model = LinearNet(**model_params)

In [7]:
# Baseline simulation: train `model` with the constant engagement schedule.
iters, loss, weights_iter, weights = task_mod_training(
    model=model,
    dataset=dataset,
    n_steps=n_steps,
    save_weights_every=save_weights_every,
    engagement_coefficients=engage_coefficients,
)

# Record the simulated trajectory under the "*_sim" keys.
results_dict.update({
    "iters": iters,
    "Loss_t_sim": loss,
    "weights_sim": weights,
    "weights_iters_sim": weights_iter,
})

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:05<00:00, 985.17it/s]


In [8]:
# First saved snapshot of each weight matrix; used as the starting point for
# the solver and control model so they begin where the simulation began.
init_W1, init_W2 = (weights[layer][0, ...] for layer in (0, 1))
init_weights = [init_W1, init_W2]

# Dataset statistics consumed by the equation parameters.
(input_corr, output_corr, input_output_corr,
 expected_y, expected_x) = dataset.get_correlation_matrix()

# One time point per recorded iteration, spaced by the learning rate.
time_span = np.arange(len(iters)) * model_params["learning_rate"]
results_dict["time_span"] = time_span

In [9]:
# Arguments shared by the equation solver and, with extra entries added,
# the control model.
equation_params = dict(
    in_cov=input_corr,
    out_cov=output_corr,
    in_out_cov=input_output_corr,
    expected_y=expected_y,
    expected_x=expected_x,
    init_weights=init_weights,
    n_steps=n_steps,
    reg_coef=model_params["reg_coef"],
    intrinsic_noise=model_params["intrinsic_noise"],
    learning_rate=model_params["learning_rate"],
    time_constant=1.0,
)

In [10]:
# Equation-based model of the learning dynamics, built from the dataset
# statistics and the simulation's initial weights; queried below for the
# "Approximation" weight and loss trajectories.
solver = LinearNetTaskModEq(**equation_params)

In [11]:
# Control model arguments: everything the solver receives plus the
# control-specific settings below.
# NOTE(review): "control_lr" is hard-coded here rather than taken from the
# `control_lr` variable defined with the other hyper-parameters (1.0) --
# confirm which value is intended.
control_params = dict(equation_params)
control_params.update(
    control_lower_bound=0.0,
    control_upper_bound=2.0,
    gamma=0.99,
    cost_coef=0.1,
    reward_convertion=1.0,
    init_g=None,
    control_lr=0.02,
)

In [12]:
# Control model whose `engagement_coef` is optimised by train_step() in the
# cells below. It is built from `control_params` as they stand at this point.
control = LinearNetTaskModControl(**control_params)

In [13]:
# Weight and loss trajectories predicted by the solver over `time_span`.
W1_t, W2_t = solver.get_weights(time_span, get_numpy=True)
Loss_t = solver.get_loss_function(W1_t, W2_t, get_numpy=True)

# Stored under the "*_eq" keys.
results_dict.update({
    "W1_t_eq": W1_t,
    "W2_t_eq": W2_t,
    "Loss_t_eq": Loss_t,
})

In [14]:
# Trajectories under the control model before any optimisation step.
W1_t_control, W2_t_control = control.get_weights(time_span, get_numpy=True)
Loss_t_control = control.get_loss_function(
    W1_t_control, W2_t_control, get_numpy=True
)

# Stored under the "*_control_init" keys, together with a detached copy of
# the starting engagement coefficients.
results_dict.update({
    "W1_t_control_init": W1_t_control,
    "W2_t_control_init": W2_t_control,
    "Loss_t_control_init": Loss_t_control,
    "init_engagement_coef": control.engagement_coef.detach(),
})

In [15]:
# Loss curves: simulated training (faded), solver prediction, and the
# control model before optimisation. The first two share a colour.
palette = Category10[10]
legends = ("Real Non-linear", "Approximation", "Init Control")
losses = (loss, Loss_t, Loss_t_control)
alphas = (0.3, 1, 1)
colors = (palette[0], palette[0], palette[1])

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [16]:
# Number of gradient steps applied to the control signal, and the list that
# receives the reward returned by each step.
# NOTE(review): `control` was already constructed from `control_params`, so
# adding "iters_control" here cannot affect it; this also replaces the earlier
# `iter_control = 10` -- confirm both are intended.
cumulated_reward = list()
iter_control = 3
control_params["iters_control"] = iter_control

In [17]:
# Run `iter_control` optimisation steps on the control signal and keep the
# value returned by each step.
#
# The rewards are collected in a list created inside this cell, so
# `cumulated_reward` is only ever bound to the final float array. Previously
# the cell appended to `cumulated_reward` and then rebound that same name to
# an ndarray, so running the cell a second time failed on `.append`.
reward_history = [
    control.train_step(get_numpy=True) for _ in tqdm(range(iter_control))
]
cumulated_reward = np.array(reward_history).astype(float)
results_dict["cumulated_reward_opt"] = cumulated_reward

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:45<00:00, 15.19s/it]


In [18]:
# Reward obtained at each gradient step on the control signal.
reward_steps = np.arange(iter_control)
opt = plot_lines(
    reward_steps,
    (cumulated_reward,),
    x_axis_label="gradient steps on control",
    y_axis_label="Cumulated reward",
)
show(opt)

In [19]:
# Trajectories under the control model after the optimisation steps.
W1_t_opt, W2_t_opt = control.get_weights(time_span, get_numpy=True)
Loss_t_opt = control.get_loss_function(W1_t_opt, W2_t_opt, get_numpy=True)

# Stored under the "*_control_opt" keys.
results_dict.update({
    "W1_t_control_opt": W1_t_opt,
    "W2_t_control_opt": W2_t_opt,
    "Loss_t_control_opt": Loss_t_opt,
})

In [20]:
# Loss curves: simulated training (faded), solver prediction, and the
# control model after optimisation. The first two share a colour.
palette = Category10[10]
legends = ("Real Non-linear", "Approximation", "Approximated Optimized Control")
losses = (loss, Loss_t, Loss_t_opt)
alphas = (0.3, 1, 1)
colors = (palette[0], palette[0], palette[1])

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [21]:
# Inspect the shape of the control model's engagement coefficients.
# NOTE(review): the recorded output is [5000, 3], which matches
# (n_steps, number of datasets) -- confirm the axis order.
control.engagement_coef.shape

torch.Size([5000, 3])

In [22]:
# Plot each column of the engagement-coefficient tensor as its own line.
engagement = control.engagement_coef.detach().cpu().numpy()
task_ids = range(engagement.shape[1])

phis = [engagement[:, k] for k in task_ids]
# NOTE(review): `colors` is built but not passed to plot_lines below.
colors = [Category10[10][k] for k in task_ids]
legends = [f"C{k}" for k in task_ids]

s = plot_lines(iters, phis, labels=legends)
show(s)

In [23]:
# Build a fresh network that starts from the same initial weights given to
# the solver and control model, leaving `model_params` itself untouched.
W1_0, W2_0 = control_params["init_weights"]
reset_model_params = {**model_params, "W1_0": W1_0, "W2_0": W2_0}

reset_model = LinearNet(**reset_model_params)

In [24]:
# Re-run the simulation on `reset_model`, this time driven by the engagement
# coefficients held by the control model, converted to a NumPy array.
#
# Only the training call belongs in this cell. The text that had been fused
# onto its closing parenthesis was a second copy of the previous cell; it made
# the cell a SyntaxError and would have replaced the just-trained
# `reset_model` with an untrained one.
iters, loss_OPT, weights_iter_OPT, weights_OPT = task_mod_training(
    model=reset_model,
    dataset=dataset,
    n_steps=n_steps,
    save_weights_every=save_weights_every,
    engagement_coefficients=control.engagement_coef.detach().cpu().numpy(),
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1078.91it/s]


In [25]:
# Final comparison. Colour 0 is the baseline pair (simulation faded, solver
# solid); colour 1 is the optimised pair (control model solid, re-trained
# simulation faded).
palette = Category10[10]
legends = ("Real Non-linear", "Approximation", "Approximated Optimized Control", "Optimized Non-linear")
losses = (loss, Loss_t, Loss_t_opt, loss_OPT)
alphas = (0.3, 1, 1, 0.3)
colors = (palette[0], palette[0], palette[1], palette[1])

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)