In [1]:
# Inject a CSS rule that stretches the notebook's `.container` element to full width.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Numerical / utility imports.
from torchdiffeq import odeint
import numpy as np
from tqdm import tqdm

# Bokeh plotting imports.
from bokeh.plotting import figure, show, output_file, save
from bokeh.layouts import gridplot
from bokeh.io import output_notebook
from bokeh.palettes import Viridis, Category10, Category20
from bokeh.io import export_svg
# Route Bokeh output to the notebook so `show(...)` renders figures inline.
output_notebook()

In [2]:
from neuromod.control import NonLinearEq, NonLinearControl
from neuromod.tasks import AffineCorrelatedGaussian
from neuromod.trainers import two_layer_training
from neuromod.networks import NonLinearNet

In [3]:
# Task configuration; every entry is forwarded as a keyword argument to
# AffineCorrelatedGaussian below.
dataset_params = dict(
    mu_vec=(3.0, 1.0),
    batch_size=1024,
    dependence_parameter=0.8,
    sigma_vec=(1.0, 1.0),
)

dataset = AffineCorrelatedGaussian(**dataset_params)

In [4]:
# Hyperparameters of the simulated network. Input/output sizes are taken from
# the dataset; W1_0 / W2_0 are None, i.e. no explicit initial weights are passed.
model_params = dict(
    learning_rate=1e-3,
    hidden_dim=4,
    intrinsic_noise=0.05,
    reg_coef=0.0,
    input_dim=dataset.input_dim,
    output_dim=dataset.output_dim,
    W1_0=None,
    W2_0=None,
)

model = NonLinearNet(**model_params)

In [5]:
# Number of training iterations and the interval at which weights are recorded.
n_steps = 6000
save_weights_every = 20

# Baseline (uncontrolled) training run of the simulated network.
iters, loss, weights_iter, weights = two_layer_training(
    model=model,
    dataset=dataset,
    n_steps=n_steps,
    save_weights_every=save_weights_every,
)

100%|██████████| 6000/6000 [00:05<00:00, 1012.34it/s]


In [6]:
# Sanity check: shape of the recorded first-layer weights (leading axis = saved snapshots).
weights[0].shape

(300, 4, 3)

In [7]:
# The first recorded snapshot of each layer is used as the initial condition
# for the equation-based solver and controller.
init_W1, init_W2 = weights[0][0, ...], weights[1][0, ...]
init_weights = [init_W1, init_W2]

# Statistics reported by the dataset, unpacked in the order it returns them.
(input_corr,
 output_corr,
 input_output_corr,
 expected_y,
 expected_x) = dataset.get_correlation_matrix()

# One time point per training iteration, spaced by the learning rate.
time_span = model_params["learning_rate"] * np.arange(len(iters))

In [8]:
# Arguments for the equation-based solver: task statistics, the initial weights
# taken from the simulation, and the same hyperparameters used to train the network.
equation_params = dict(
    in_cov=input_corr,
    out_cov=output_corr,
    in_out_cov=input_output_corr,
    expected_y=expected_y,
    expected_x=expected_x,
    init_weights=init_weights,
    n_steps=n_steps,
    reg_coef=model_params["reg_coef"],
    intrinsic_noise=model_params["intrinsic_noise"],
    learning_rate=model_params["learning_rate"],
    time_constant=1.0,
)

In [9]:
# Build the equation-based solver from the parameters assembled above.
solver = NonLinearEq(**equation_params)

In [10]:
# Controller configuration: a copy of the solver parameters extended with the
# control-specific settings (equation_params itself is left untouched).
control_params = dict(
    equation_params,
    control_lower_bound=-0.5,
    control_upper_bound=0.5,
    gamma=0.99,
    cost_coef=0.3,
    reward_convertion=1.0,
    init_g=None,
    control_lr=10.0,
)

In [11]:
# Build the controller from the solver parameters plus the control settings.
control = NonLinearControl(**control_params)

In [12]:
# Aliases for the recorded weight trajectories of the two layers.
sim_weights1, sim_weights2 = weights[0], weights[1]
print(sim_weights1.shape, sim_weights2.shape)

(300, 4, 3) (300, 2, 4)


In [13]:
# Weight trajectories from the equation-based solver, evaluated on `time_span`,
# and the loss the solver assigns to those weights.
W1_t, W2_t = solver.get_weights(time_span, get_numpy=True)
Loss_t = solver.get_loss_function(W1_t, W2_t, get_numpy=True)
# Alternative (disabled): evaluate the same loss on the simulated weights instead:
# Loss_t = solver.get_loss_function(sim_weights1, sim_weights2, get_numpy=True)

In [14]:
# Same quantities from the controller, queried before any control training step has been run.
W1_t_control, W2_t_control = control.get_weights(time_span, get_numpy=True)
Loss_t_control = control.get_loss_function(W1_t_control, W2_t_control, get_numpy=True)

In [15]:
# Loss comparison: simulated training loss (faint line) against the losses computed
# from the solver's weights and from the controller's weights prior to optimisation.
s = figure(x_axis_label="iters", y_axis_label="Loss", width=800, height=500)
s.line(iters, loss, line_width=2,  alpha=0.3, legend_label="Real Non-linear")
s.line(iters, Loss_t, line_width=3, color=Category10[10][0], legend_label="Approximation")
s.line(iters, Loss_t_control, line_width=3, color=Category10[10][1], legend_label="Init Control")
show(s)

In [16]:
# Collapse every trailing axis so each trajectory becomes (time points, weight entries).
# Simulated trajectories first, then the solver's.
flat_W1_t, flat_W2_t = (
    np.reshape(traj, (traj.shape[0], -1)) for traj in (weights[0], weights[1])
)
flat_eq_W1_t, flat_eq_W2_t = (
    np.reshape(traj, (traj.shape[0], -1)) for traj in (W1_t, W2_t)
)

In [17]:
def plot_weight_ev(flat_W_t, flat_eq_W_t, sim_iters, eq_iters, title=""):
    """Overlay simulated and equation-derived weight trajectories in one Bokeh figure.

    Each column of the two input arrays is drawn as a pair of lines sharing a
    colour: the simulated trajectory as a thick, dashed, semi-transparent line
    and the equation trajectory as a thinner solid line. At most the first 20
    columns are drawn. Only the first pair carries legend labels, so the legend
    shows exactly two entries ("Simulation" and "First order").

    Args:
        flat_W_t: 2-D array of simulated weights; rows pair with `sim_iters`.
        flat_eq_W_t: 2-D array of equation weights; rows pair with `eq_iters`.
        sim_iters: x-coordinates for the simulated trajectories.
        eq_iters: x-coordinates for the equation trajectories.
        title: Title shown above the figure.

    Returns:
        The populated Bokeh figure, with its legend placed at the bottom right.
    """
    weight_plot = figure(x_axis_label="iters", y_axis_label="Weights", title=title)
    palette = Category20[20]
    n_curves = min(flat_W_t.shape[-1], 20)
    for idx in range(n_curves):
        colour = palette[idx]
        # Label only the first pair so the legend is not repeated for every weight.
        sim_legend = {"legend_label": "Simulation"} if idx == 0 else {}
        eq_legend = {"legend_label": "First order"} if idx == 0 else {}
        weight_plot.line(sim_iters, flat_W_t[:, idx], line_width=6, line_dash=(4, 4),
                         alpha=0.5, color=colour, **sim_legend)
        weight_plot.line(eq_iters, flat_eq_W_t[:, idx], line_width=3, color=colour,
                         **eq_legend)
    weight_plot.legend.location = "bottom_right"
    # Optional: set weight_plot.output_backend = "svg" here for SVG export.
    return weight_plot

In [18]:
# One panel per layer: simulated weights are plotted against `weights_iter`,
# solver weights against `iters`; the panels are shown side by side.
weight_plot1 = plot_weight_ev(flat_W1_t, flat_eq_W1_t, sim_iters=weights_iter, eq_iters=iters, title="W1")
weight_plot2 = plot_weight_ev(flat_W2_t, flat_eq_W2_t, sim_iters=weights_iter, eq_iters=iters, title="W2")
grid = gridplot([weight_plot1, weight_plot2], ncols=2, width=600, height=500)
show(grid)

## Optimizing control signal

In [19]:
# Number of gradient steps to take on the control signal.
iter_control = 10
# Collects the value returned by each control training step.
cumulated_reward = []

In [20]:
# Take `iter_control` gradient steps on the control signal, recording the value
# returned by each step.
#
# The rewards are gathered in a fresh local list and only then bound to
# `cumulated_reward`. Appending to `cumulated_reward` directly and rebinding it
# to an ndarray afterwards made this cell fail on re-execution, because an
# ndarray has no `.append`.
step_rewards = []
for i in tqdm(range(iter_control)):
    # Step size comes from the controller configuration instead of a repeated literal.
    step_rewards.append(control.train_step(get_numpy=True, lr=control_params["control_lr"]))
cumulated_reward = np.array(step_rewards).astype(float)

100%|██████████| 10/10 [02:01<00:00, 12.19s/it]


In [21]:
# Reward recorded after each gradient step on the control signal.
opt = figure(x_axis_label="gradient steps on control", y_axis_label="Cumulated reward", width=800, height=500)
# The x-axis is derived from the data itself so both columns always have equal length.
opt.line(np.arange(len(cumulated_reward)), cumulated_reward, line_width=2)
show(opt)

In [22]:
# Controller trajectories and loss after the control training steps above.
W1_t_opt, W2_t_opt = control.get_weights(time_span, get_numpy=True)
Loss_t_opt = control.get_loss_function(W1_t_opt, W2_t_opt, get_numpy=True)

In [23]:
# Loss comparison after control optimisation: simulated baseline (faint line),
# solver loss, and the controller's loss with the trained control signal.
s = figure(x_axis_label="iters", y_axis_label="Loss", width=800, height=500)
s.line(iters, loss, line_width=2,  alpha=0.3, legend_label="Real Non-linear")
s.line(iters, Loss_t, line_width=3, color=Category10[10][0], legend_label="Approximation")
s.line(iters, Loss_t_opt, line_width=3, color=Category10[10][1], legend_label="Approximated Optimized Control")
show(s)

In [24]:
# Control signals for both layers, read off the controller and bundled in the
# form passed to the training routine.
control_signal = (control.g1_tilda, control.g2_tilda)
g1_tilda, g2_tilda = control_signal

# Initial weights the controller was configured with; reused to restart the simulation.
W1_0, W2_0 = control_params["init_weights"]

In [25]:
# Sanity check: shapes of the two control signals and of the first-layer initial weights.
g1_tilda.shape, g2_tilda.shape, W1_0.shape

(torch.Size([6000, 4, 3]), torch.Size([6000, 2, 4]), (4, 3))

In [26]:
# Rebuild the network with the same hyperparameters as the baseline model, but
# starting from the explicit initial weights used by the controller.
# Deriving the dict from the existing `model_params` (rather than retyping every
# value) keeps the two configurations from drifting apart; only the initial
# weights are overridden.
model_params = {**model_params, "W1_0": W1_0, "W2_0": W2_0}

reset_model = NonLinearNet(**model_params)

In [27]:
# Re-run the simulation from the same initial weights, this time passing the
# optimised control signal to the training routine.
# NOTE(review): this rebinds `iters`, `weights_iter` and `weights` from the
# baseline run; re-executing earlier cells that read them afterwards will see
# the controlled run instead. Only `loss` keeps its baseline value.
iters, loss_OPT, weights_iter, weights = two_layer_training(model=reset_model, 
                                                            dataset=dataset, 
                                                            n_steps=n_steps, 
                                                            save_weights_every=save_weights_every,
                                                            control_signal=control_signal)

100%|██████████| 6000/6000 [00:07<00:00, 841.56it/s]


In [28]:
# Final comparison. Faint lines are simulated runs and solid lines are the
# equation-based counterparts; the controlled pair shares one colour.
s = figure(x_axis_label="iters", y_axis_label="Loss", width=800, height=500)
s.line(iters, loss, line_width=2,  alpha=0.3, legend_label="Real Non-linear")
s.line(iters, Loss_t, line_width=3, color=Category10[10][0], legend_label="Approximation")
s.line(iters, Loss_t_opt, line_width=3, color=Category10[10][1], legend_label="Approximated Optimized Control")
s.line(iters, loss_OPT, line_width=3, color=Category10[10][1], alpha=0.3, legend_label="Optimized Non-linear")
show(s)