In [1]:
from IPython.display import display, HTML
# display(HTML("<style>.container { width:100% !important; }</style>"))

import numpy as np
from bokeh.plotting import figure, show, output_file, save
from bokeh.layouts import gridplot
from bokeh.io import output_notebook
from bokeh.palettes import Viridis, Category10, Category20
from bokeh.io import export_svg
from tqdm import tqdm
# Configure Bokeh so that later show(...) calls render inline in the notebook.
output_notebook()

In [2]:
from metamod.tasks import TaskSwitch, AffineCorrelatedGaussian, CompositionOfTasks, SemanticTask, MNIST
from metamod.trainers import two_layer_training, two_layer_engage_training
from metamod.networks import LinearNet, LinearTaskEngNet
from metamod.control import TaskSwitchLinearNetEq, TaskSwitchLinearNetControl, LinearNetTaskEngEq, LinearNetTaskEngControl
from metamod.utils import plot_lines

## Baseline linear network

In [3]:
# Run-level bookkeeping. `results_dict` is filled by later cells;
# `run_name` and `results_path` are not read anywhere in the visible cells.
results_dict = dict()
results_path = "../results"
run_name = "composition_of_tasks"

In [4]:
# Keyword arguments for the MNIST task. They are kept in a named dict because
# later cells pass the same settings to the composition of tasks.
dataset_params = dict(
    batch_size=256,
    new_shape=(5, 5),
    subset=(0, 1),
)

dataset = MNIST(**dataset_params)



In [5]:
# Settings of the baseline linear network. Input/output sizes are taken from
# the dataset; no initial weights are supplied (W1_0 / W2_0 are None).
model_params = dict(
    learning_rate=5e-3,
    hidden_dim=20,
    intrinsic_noise=0.0,
    reg_coef=0.0,
    input_dim=dataset.input_dim,
    output_dim=dataset.output_dim,
    W1_0=None,
    W2_0=None,
)

model = LinearNet(**model_params)

In [6]:
# Schedule shared by the training runs in this notebook.
n_steps = 10000
save_weights_every = 20
iter_control = 100  # not read in this cell; reassigned before the control loop

# Train the baseline network and keep the loss curve plus the saved weights.
iters, baseline_loss, weights_iter, weights = two_layer_training(
    model=model,
    dataset=dataset,
    n_steps=n_steps,
    save_weights_every=save_weights_every,
)

# Store the baseline run under the generic "*_sim" keys.
results_dict.update({
    "iters": iters,
    "Loss_t_sim": baseline_loss,
    "weights_sim": weights,
    "weights_iters_sim": weights_iter,
})

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:08<00:00, 1189.43it/s]


In [7]:
# Loss curve of the baseline run on its own.
# Each row is (curve, legend, alpha, colour); zip(*rows) turns the rows into
# the parallel tuples that plot_lines takes.
palette = Category10[10]
curves = (
    (baseline_loss, "Baseline linear network", 1, palette[0]),
)
losses, legends, alphas, colors = zip(*curves)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

## Composition-of-tasks test: same task, different code path

In [8]:
# A composition made of a single MNIST task that reuses the baseline
# dataset settings.
composition_dataset_params = dict(
    dataset_classes=(MNIST,),
    dataset_list_params=(dataset_params,),
)

In [9]:
# Instantiate the composed task from the parameter dict defined in the previous cell.
composition_dataset = CompositionOfTasks(**composition_dataset_params)

In [10]:
# Task-engagement network for the composed task.
# Its initial weights are index 0 (along the first axis) of each saved weight
# array currently bound to `weights`.
# NOTE(review): `weights` is rebound by the training cell that follows, so
# re-running this cell after that one picks up a different initialisation —
# confirm that the baseline run's snapshot is the intended starting point.
comp_model_params = {"learning_rate": 5e-3,
                "hidden_dim": 20,
                "intrinsic_noise": 0.0,
                "reg_coef": 0.0,
                "input_dim": composition_dataset.input_dim,
                "output_dim": composition_dataset.output_dim,
                "W1_0": weights[0][0, ...],
                "W2_0": weights[1][0, ...],
                "task_output_index": composition_dataset.task_output_index}

# Engagement coefficients, all ones, with one column per composed task.
# The column count is read from the dataset (as the multi-task cell further
# down does) instead of being hard-coded to 1.
engage_coefficients = np.ones((n_steps, len(composition_dataset.datasets)))  # (t, phis)

comp_model = LinearTaskEngNet(**comp_model_params)

In [11]:
# Train the task-engagement network on the composed task.
# The unpacking rebinds `iters`, `weights_iter` and `weights`, which
# previously held the baseline run's values.
engage_training_kwargs = dict(
    model=comp_model,
    dataset=composition_dataset,
    n_steps=n_steps,
    save_weights_every=save_weights_every,
    engagement_coefficients=engage_coefficients,
)
iters, comp_loss, weights_iter, weights = two_layer_engage_training(**engage_training_kwargs)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 799.57it/s]


In [12]:
# Overlay the baseline loss and the composed-task loss.
# Each row is (curve, legend, alpha, colour); zip(*rows) yields the parallel
# tuples that plot_lines takes.
palette = Category10[10]
curves = (
    (baseline_loss, "Baseline linear network", 0.5, palette[0]),
    (comp_loss, "compositional task", 0.5, palette[1]),
)
losses, legends, alphas, colors = zip(*curves)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [13]:
# Overwrite the "*_sim" entries with the composed-task run.
results_dict.update({
    "iters": iters,
    "Loss_t_sim": comp_loss,
    "weights_sim": weights,
    "weights_iters_sim": weights_iter,
})

# Index 0 of each saved weight array serves as the initial condition below.
init_W1, init_W2 = weights[0][0, ...], weights[1][0, ...]
init_weights = [init_W1, init_W2]

# Dataset statistics consumed by the equation cell.
(input_corr, output_corr, input_output_corr,
 expected_y, expected_x) = composition_dataset.get_correlation_matrix()

# One time point per entry of `iters`, spaced by the learning rate.
time_span = comp_model_params["learning_rate"] * np.arange(len(iters))
results_dict["time_span"] = time_span

In [14]:
# Arguments for the learning-dynamics solver: dataset statistics, the initial
# weights, the hyper-parameters copied from the model settings, and the
# engagement schedule used in simulation.
equation_params = dict(
    in_cov=input_corr,
    out_cov=output_corr,
    in_out_cov=input_output_corr,
    expected_y=expected_y,
    expected_x=expected_x,
    init_weights=init_weights,
    n_steps=n_steps,
    reg_coef=comp_model_params["reg_coef"],
    intrinsic_noise=comp_model_params["intrinsic_noise"],
    learning_rate=comp_model_params["learning_rate"],
    time_constant=1.0,
    task_output_index=composition_dataset.task_output_index,
    task_input_index=composition_dataset.task_input_index,
    engagement_coef=engage_coefficients,
)

In [15]:
# Build the learning-dynamics equation object from the parameters assembled in the previous cell.
solver = LinearNetTaskEngEq(**equation_params)

In [16]:
# Ask the solver for both weight trajectories over `time_span`, then for the
# loss evaluated on them (NumPy outputs requested via get_numpy=True).
equation_weights = solver.get_weights(time_span, get_numpy=True)
W1_t, W2_t = equation_weights
Loss_t = solver.get_loss_function(W1_t, W2_t, get_numpy=True)

In [17]:
# Compare both simulations with the loss from the learning-dynamics equation.
# The equation curve reuses the composed-task colour and differs from it only
# by alpha.
palette = Category10[10]
curves = (
    (baseline_loss, "Baseline linear network", 0.5, palette[0]),
    (comp_loss, "compositional task", 0.5, palette[1]),
    (Loss_t, "Learning dynamics equation", 1.0, palette[1]),
)
losses, legends, alphas, colors = zip(*curves)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

## Trying different tasks

In [18]:
# Three MNIST task configurations. The first is a shallow copy of the baseline
# settings; the other two use the same batch size and shape with a different
# `subset`.
dataset_params1 = dict(dataset_params)
dataset_params2 = dict(batch_size=256, new_shape=(5, 5), subset=(1, 7))
dataset_params3 = dict(batch_size=256, new_shape=(5, 5), subset=(8, 9))

task_param_list = (dataset_params1, dataset_params2, dataset_params3)
composition_dataset_params = dict(
    dataset_classes=(MNIST,) * len(task_param_list),
    dataset_list_params=task_param_list,
)

In [19]:
# Rebuild the composed task, now from the three-task parameter dict defined in the previous cell.
composition_dataset = CompositionOfTasks(**composition_dataset_params)

In [20]:
# Task-engagement network sized for the current composition. No initial
# weights are supplied here (W1_0 / W2_0 are None).
comp_model_params = dict(
    learning_rate=5e-3,
    hidden_dim=20,
    intrinsic_noise=0.0,
    reg_coef=0.0,
    input_dim=composition_dataset.input_dim,
    output_dim=composition_dataset.output_dim,
    W1_0=None,
    W2_0=None,
    task_output_index=composition_dataset.task_output_index,
)

# Engagement schedule: all ones, one row per step and one column per task.
n_tasks = len(composition_dataset.datasets)
engage_coefficients = np.ones((n_steps, n_tasks))  # (t, phis)
# Disabled alternative schedule (linear ramps scaled by 0.5):
#engage_coefficients[:, 1] = np.linspace(0, 1, num=engage_coefficients.shape[0])*0.5
#engage_coefficients[:, 0] = np.linspace(1, 0, num=engage_coefficients.shape[0])*0.5

comp_model = LinearTaskEngNet(**comp_model_params)

In [21]:
# Train the task-engagement network on the multi-task composition.
# The unpacking rebinds `iters`, `comp_loss`, `weights_iter` and `weights`
# from the previous run.
engage_training_kwargs = dict(
    model=comp_model,
    dataset=composition_dataset,
    n_steps=n_steps,
    save_weights_every=save_weights_every,
    engagement_coefficients=engage_coefficients,
)
iters, comp_loss, weights_iter, weights = two_layer_engage_training(**engage_training_kwargs)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:21<00:00, 462.52it/s]


In [22]:
# Overlay the baseline loss and the multi-task composition loss.
# Each row is (curve, legend, alpha, colour); zip(*rows) yields the parallel
# tuples that plot_lines takes.
palette = Category10[10]
curves = (
    (baseline_loss, "Baseline linear network", 0.5, palette[0]),
    (comp_loss, "compositional task", 0.5, palette[1]),
)
losses, legends, alphas, colors = zip(*curves)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [23]:
# Display the per-task output and input index pairs of the composed dataset
# (presumably (start, stop) ranges into the concatenated vectors — confirm).
composition_dataset.task_output_index, composition_dataset.task_input_index

([(0, 2), (2, 4), (4, 6)], [(0, 26), (26, 52), (52, 78)])

In [24]:
# Overwrite the "*_sim" entries with the multi-task run.
results_dict.update({
    "iters": iters,
    "Loss_t_sim": comp_loss,
    "weights_sim": weights,
    "weights_iters_sim": weights_iter,
})

# Index 0 of each saved weight array serves as the initial condition below.
init_W1, init_W2 = weights[0][0, ...], weights[1][0, ...]
init_weights = [init_W1, init_W2]

# Dataset statistics consumed by the equation and control cells.
(input_corr, output_corr, input_output_corr,
 expected_y, expected_x) = composition_dataset.get_correlation_matrix()

# One time point per entry of `iters`, spaced by the learning rate.
time_span = comp_model_params["learning_rate"] * np.arange(len(iters))
results_dict["time_span"] = time_span

In [25]:
# Arguments shared by the learning-dynamics solver and the controller:
# dataset statistics, the initial weights, the hyper-parameters copied from
# the model settings, and the engagement schedule.
equation_params = dict(
    in_cov=input_corr,
    out_cov=output_corr,
    in_out_cov=input_output_corr,
    expected_y=expected_y,
    expected_x=expected_x,
    init_weights=init_weights,
    n_steps=n_steps,
    reg_coef=comp_model_params["reg_coef"],
    intrinsic_noise=comp_model_params["intrinsic_noise"],
    learning_rate=comp_model_params["learning_rate"],
    time_constant=1.0,
    task_output_index=composition_dataset.task_output_index,
    task_input_index=composition_dataset.task_input_index,
    engagement_coef=engage_coefficients,
)

In [26]:
# Build the learning-dynamics equation object for the multi-task composition.
solver = LinearNetTaskEngEq(**equation_params)

In [27]:
# Controller-specific settings. The next cell merges them with the equation
# parameters before the controller is constructed.
control_params = dict(
    control_lower_bound=0.0,
    control_upper_bound=1.0,
    gamma=0.99,
    cost_coef=0.3,
    reward_convertion=1.0,
    init_g=None,
    control_lr=1.0,
)

In [28]:
# Merge controller and equation settings; on a key clash the equation value wins.
control_params = dict(control_params, **equation_params)
control = LinearNetTaskEngControl(**control_params)

# Trajectories and loss from the plain learning-dynamics solver.
W1_t, W2_t = solver.get_weights(time_span, get_numpy=True)
Loss_t = solver.get_loss_function(W1_t, W2_t, get_numpy=True)
results_dict.update({
    "W1_t_eq": W1_t,
    "W2_t_eq": W2_t,
    "Loss_t_eq": Loss_t,
})

# The same quantities from the controller, before any control training step.
W1_t_control, W2_t_control = control.get_weights(time_span, get_numpy=True)
Loss_t_control = control.get_loss_function(W1_t_control, W2_t_control, get_numpy=True)

In [29]:
# Simulation vs. learning-dynamics equation vs. controller before training.
# Each row is (curve, legend, alpha, colour); zip(*rows) yields the parallel
# tuples that plot_lines takes.
palette = Category10[10]
curves = (
    (comp_loss, "compositional task", 0.5, palette[1]),
    (Loss_t, "Learning dynamics equation", 1.0, palette[2]),
    (Loss_t_control, "Control equation", 1.0, palette[3]),
)
losses, legends, alphas, colors = zip(*curves)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [30]:
# Inspect the shape of the controller's engagement coefficients.
control.engagement_coef.shape

torch.Size([10000, 3])

In [30]:
# Record the controller's trajectories from before any training step.
results_dict.update({
    "W1_t_control_init": W1_t_control,
    "W2_t_control_init": W2_t_control,
    "Loss_t_control_init": Loss_t_control,
    # NOTE(review): this stores the controller's tensor object itself, not a
    # copy. If train_step updates it in place, the stored entry changes with
    # it — confirm whether a detached snapshot was intended.
    "engagement_coef": control.engagement_coef,
})

# Number of control training steps; also recorded in the parameter dict.
iter_control = 10
control_params["iters_control"] = iter_control

# Run the control training steps and collect the value each one returns.
cumulated_reward = []
for step in tqdm(range(iter_control)):
    cumulated_reward.append(control.train_step(get_numpy=True))
cumulated_reward = np.array(cumulated_reward).astype(float)
results_dict["cumulated_reward_opt"] = cumulated_reward

# Trajectories and loss from the controller after training.
W1_t_opt, W2_t_opt = control.get_weights(time_span, get_numpy=True)
Loss_t_opt = control.get_loss_function(W1_t_opt, W2_t_opt, get_numpy=True)
results_dict.update({
    "W1_t_control_opt": W1_t_opt,
    "W2_t_control_opt": W2_t_opt,
    "Loss_t_control_opt": Loss_t_opt,
})

100%|████████████████████████████████████████████████████████████████████| 10/10 [03:52<00:00, 23.20s/it]


In [31]:
# Value returned by each control training step, against the step index.
control_steps = np.arange(iter_control)
opt = plot_lines(
    control_steps,
    (cumulated_reward,),
    x_axis_label="gradient steps on control",
    y_axis_label="Cumulated reward",
)
show(opt)

In [32]:
# Re-evaluate the controller's trajectories and loss, then store them.
# This repeats the evaluation at the end of the control-training cell, so it
# only changes the stored values if the controller changed in between.
optimized_weights = control.get_weights(time_span, get_numpy=True)
W1_t_opt, W2_t_opt = optimized_weights
Loss_t_opt = control.get_loss_function(W1_t_opt, W2_t_opt, get_numpy=True)

results_dict.update({
    "W1_t_control_opt": W1_t_opt,
    "W2_t_control_opt": W2_t_opt,
    "Loss_t_control_opt": Loss_t_opt,
})

In [33]:
# Inspect the engagement coefficients currently held by the controller.
control.engagement_coef

tensor([[0.9962, 0.9873, 0.9881],
        [0.9961, 0.9873, 0.9881],
        [0.9961, 0.9874, 0.9881],
        ...,
        [0.9820, 0.9820, 0.9820],
        [0.9820, 0.9820, 0.9820],
        [0.9820, 0.9820, 0.9820]], device='cuda:0', requires_grad=True)

In [34]:
# Simulation and equation share a colour (told apart by alpha); the
# optimised-control loss gets its own colour.
palette = Category10[10]
curves = (
    (comp_loss, "Simulation", 0.3, palette[0]),
    (Loss_t, "Equation", 1, palette[0]),
    (Loss_t_opt, "Optimized Control", 1, palette[1]),
)
losses, legends, alphas, colors = zip(*curves)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

In [35]:
# Plot the engagement coefficient of every task against `iters`.
# The original cell plotted only the first two columns and never passed the
# legends/alphas/colors it built to plot_lines. Here the task count is read
# from the coefficient tensor and every column gets its own line, legend
# entry and colour.
phi_matrix = control.engagement_coef.detach().cpu().numpy()  # move off the GPU / autograd graph once
n_tasks = phi_matrix.shape[1]

phis = tuple(phi_matrix[:, k] for k in range(n_tasks))
colors = tuple(Category10[10][k % 10] for k in range(n_tasks))  # wrap if there are more than 10 tasks
legends = tuple(f"Task{k + 1}" for k in range(n_tasks))
alphas = (1,) * n_tasks

s = plot_lines(iters, phis, legends, alphas, colors)
show(s)