In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import numpy as np
from bokeh.plotting import figure, show, output_file, save
from bokeh.layouts import gridplot
from bokeh.io import output_notebook
from bokeh.palettes import Viridis, Category10, Category20
from bokeh.io import export_svg
from tqdm import tqdm
output_notebook()

In [2]:
from metamod.tasks import TaskSwitch, AffineCorrelatedGaussian
from metamod.trainers import two_layer_training
from metamod.networks import LinearNet
from metamod.control import TaskSwitchLinearNetEq, TaskSwitchLinearNetControl
from metamod.utils import plot_lines

In [3]:
n_steps = 20000
save_weights_every = 20
iter_control = 50

results_dict = {}

# Init dataset
batch_size = 2048
dataset1_params = {"mu_vec": (3.0, 1.0), "sigma_vec": (1.0, 1.0), "dependence_parameter": 0.8, "batch_size": batch_size}
dataset2_params = {"mu_vec": (-2.0, 2.0), "sigma_vec": (1.0, 1.0), "dependence_parameter": 0.2, "batch_size": batch_size}
dataset_params = {"dataset1_params": dataset1_params, 
                  "dataset2_params": dataset2_params, 
                  "change_tasks_every": 2000}

model_params = {"learning_rate": 5e-3,
                "hidden_dim": 4,
                "intrinsic_noise": 0.00,
                "reg_coef": 0.0,
                "W1_0": None,
                "W2_0": None}

control_params = {"control_lower_bound": -0.5,
                  "control_upper_bound": 0.5,
                  "gamma": 0.99,
                  "cost_coef": 0.3,
                  "reward_convertion": 1.0,
                  "init_g": None,
                  "control_lr": 1.0}

In [4]:
dataset = TaskSwitch(dataset_classes=(AffineCorrelatedGaussian, AffineCorrelatedGaussian),
                     dataset_list_params=(dataset1_params, dataset2_params),
                     change_tasks_every=dataset_params["change_tasks_every"])

task1_solution = dataset.datasets[0].get_linear_regression_solution()
task2_solution = dataset.datasets[1].get_linear_regression_solution()
cov_matrix_task1 = dataset.datasets[0].get_correlation_matrix()
cov_matrix_task2 = dataset.datasets[1].get_correlation_matrix()

model_params["input_dim"] = dataset.input_dim
model_params["output_dim"] = dataset.output_dim

In [5]:
task1_solution
task2_solution

array([[-0.37006237,  0.06237006,  0.        ],
       [-0.06237006,  0.37006237,  0.        ]])

In [6]:
def get_loss_function(sol, out_cov, in_out_cov, in_cov):
    loss = 0.5*(np.trace(out_cov) - np.trace(2*in_out_cov @ sol) + np.trace(in_cov @ sol.T @ sol))
    return loss

In [7]:
best_loss_task1 = get_loss_function(task1_solution, cov_matrix_task1[1], cov_matrix_task1[2], cov_matrix_task1[0])
best_loss_task2 = get_loss_function(task2_solution, cov_matrix_task2[1], cov_matrix_task2[2], cov_matrix_task2[0])
print(best_loss_task1, best_loss_task2)

0.25059665871121706 0.18503118503118499


In [8]:
model = LinearNet(**model_params)

In [9]:
iters, loss, weights_iter, weights = two_layer_training(model=model, dataset=dataset, n_steps=n_steps,
                                                        save_weights_every=save_weights_every)

results_dict["iters"] = iters
results_dict["Loss_t_sim"] = loss
results_dict["weights_sim"] = weights
results_dict["weights_iters_sim"] = weights_iter

 17%|█████████████████████████████▏                                                                                                                                          | 3478/20000 [00:04<00:20, 807.23it/s]

KeyboardInterrupt



In [None]:
# Solving equation
init_W1 = weights[0][0, ...]
init_W2 = weights[1][0, ...]

init_weights = [init_W1, init_W2]
input_corr, output_corr, input_output_corr, expected_y, expected_x = dataset.get_correlation_matrix()

time_span = np.arange(0, len(iters)) * model_params["learning_rate"]
results_dict["time_span"] = time_span

equation_params = {"in_cov": input_corr,
                   "out_cov": output_corr,
                   "in_out_cov": input_output_corr,
                   "init_weights": init_weights,
                   "n_steps": n_steps,
                   "reg_coef": model_params["reg_coef"],
                   "intrinsic_noise": model_params["intrinsic_noise"],
                   "learning_rate": model_params["learning_rate"],
                   "change_task_every": dataset_params["change_tasks_every"],
                   "time_constant": 1.0}

solver = TaskSwitchLinearNetEq(**equation_params)

control_params = {**control_params, **equation_params}
control = TaskSwitchLinearNetControl(**control_params)

W1_t, W2_t = solver.get_weights(time_span, get_numpy=True)
Loss_t = solver.get_loss_function(W1_t, W2_t, get_numpy=True)

results_dict["W1_t_eq"] = W1_t
results_dict["W2_t_eq"] = W2_t
results_dict["Loss_t_eq"] = Loss_t

W1_t_control, W2_t_control = control.get_weights(time_span, get_numpy=True)
Loss_t_control = control.get_loss_function(W1_t_control, W2_t_control, get_numpy=True)

In [None]:
best_loss1 = np.ones(iters.shape)*best_loss_task1
best_loss2 = np.ones(iters.shape)*best_loss_task2

losses = (loss, Loss_t, Loss_t_control, best_loss1, best_loss2)
colors = (Category10[10][0], Category10[10][0], Category10[10][1], Category10[10][2], Category10[10][3])
legends = ("Simulation", "Equation", "Init Control", "best task1", "best task2")
alphas = (0.3, 1, 1, 0.5, 0.5)

s = plot_lines(iters, losses, legends, alphas, colors)
show(s)

### Define the weights using task 1 solution, but with random Q

In [None]:
hidden_units = model_params["hidden_dim"]

In [None]:
u, s, vh = np.linalg.svd(task1_solution, full_matrices=True)

In [None]:
s = np.concatenate([np.diag(s), np.zeros((len(s), 1))], axis=1)
s_prime = np.concatenate([s, np.zeros((1, s.shape[1]))], axis=0)

In [None]:
task1_solution, u @ s @ vh, 

In [None]:
u.shape, s.shape, vh.shape, task1_solution.shape

In [None]:
# Q_matrix = np.random.uniform(low=0, high=1.0, size=(hidden_units, s.shape[1]))
Q_matrix = np.random.normal(size=(hidden_units, s.shape[1]))

In [None]:
Q_inverse = np.linalg.pinv(Q_matrix)

In [None]:
Q_matrix.shape, Q_inverse.shape

In [None]:
Q_inverse @ Q_matrix

In [None]:
W2_task_init = u @ np.sqrt(s) @ Q_inverse

In [None]:
W1_task_init = Q_matrix @ np.sqrt(s_prime).T @ vh

In [None]:
task1_solution, W2_task_init @ W1_task_init

In [None]:
n_steps = 5000
reset_model_params = {"learning_rate": model_params["learning_rate"],
                "hidden_dim": model_params["hidden_dim"],
                "intrinsic_noise": model_params["intrinsic_noise"],
                "reg_coef": model_params["reg_coef"],
                "W1_0": W1_task_init,
                "W2_0": W2_task_init}
reset_model_params["input_dim"] = dataset.input_dim
reset_model_params["output_dim"] = dataset.output_dim

reset_model = LinearNet(**reset_model_params)

In [None]:
reset_dataset_params = {"dataset1_params": dataset1_params, 
                        "dataset2_params": dataset2_params, 
                        "change_tasks_every": 500}
time_span = np.arange(0, n_steps) * model_params["learning_rate"]
reset_dataset = TaskSwitch(dataset_classes=(AffineCorrelatedGaussian, AffineCorrelatedGaussian),
                           dataset_list_params=(dataset1_params, dataset2_params),
                           change_tasks_every=reset_dataset_params["change_tasks_every"])

In [None]:
reset_iters, reset_loss, reset_weights_iter, reset_weights = two_layer_training(model=reset_model, dataset=reset_dataset, n_steps=n_steps,
                                                                                save_weights_every=save_weights_every)

In [None]:
init_weights = [W1_task_init, W2_task_init]
reset_equation_params = {"in_cov": input_corr,
                         "out_cov": output_corr,
                         "in_out_cov": input_output_corr,
                         "init_weights": init_weights,
                         "n_steps": n_steps,
                         "reg_coef": reset_model_params["reg_coef"],
                         "intrinsic_noise": reset_model_params["intrinsic_noise"],
                         "learning_rate": reset_model_params["learning_rate"],
                         "change_task_every": reset_dataset_params["change_tasks_every"],
                         "time_constant": 1.0}

solver = TaskSwitchLinearNetEq(**reset_equation_params)

In [None]:
reset_W1_t, reset_W2_t = solver.get_weights(time_span, get_numpy=True)
reset_Loss_t = solver.get_loss_function(reset_W1_t, reset_W2_t, get_numpy=True)

In [None]:
best_loss1 = np.ones(reset_iters.shape)*best_loss_task1
best_loss2 = np.ones(reset_iters.shape)*best_loss_task2

losses = (reset_loss, reset_Loss_t, best_loss1, best_loss2)
colors = (Category10[10][0], Category10[10][0], Category10[10][2], Category10[10][3])
legends = ("Simulation", "Equation", "best task1", "best task2")
alphas = (0.3, 1, 0.5, 0.5)

s = plot_lines(reset_iters, losses, legends, alphas, colors)
show(s)

In [None]:
def draw_trajectories(n_t=20):
    
    sim_losses = []
    eq_losses = []
    best_loss1_list = []
    best_loss2_list = []
    iter_list = []
    
    for i in range(n_t):
        u, s, vh = np.linalg.svd(task1_solution, full_matrices=True)
        s = np.concatenate([np.diag(s), np.zeros((len(s), 1))], axis=1)
        s_prime = np.concatenate([s, np.zeros((1, s.shape[1]))], axis=0)
        
        Q_matrix = np.random.normal(size=(hidden_units, s.shape[1]))
        Q_inverse = np.linalg.pinv(Q_matrix)
        W2_task_init = u @ np.sqrt(s) @ Q_inverse
        W1_task_init = Q_matrix @ np.sqrt(s_prime).T @ vh
        n_steps = 5000
        reset_model_params = {"learning_rate": model_params["learning_rate"],
                        "hidden_dim": model_params["hidden_dim"],
                        "intrinsic_noise": model_params["intrinsic_noise"],
                        "reg_coef": model_params["reg_coef"],
                        "W1_0": W1_task_init,
                        "W2_0": W2_task_init}
        reset_model_params["input_dim"] = dataset.input_dim
        reset_model_params["output_dim"] = dataset.output_dim

        reset_model = LinearNet(**reset_model_params)
        
        reset_dataset_params = {"dataset1_params": dataset1_params, 
                                "dataset2_params": dataset2_params, 
                                "change_tasks_every": 500}
        time_span = np.arange(0, n_steps) * model_params["learning_rate"]
        reset_dataset = TaskSwitch(dataset_classes=(AffineCorrelatedGaussian, AffineCorrelatedGaussian),
                                   dataset_list_params=(dataset1_params, dataset2_params),
                                   change_tasks_every=reset_dataset_params["change_tasks_every"])
        reset_iters, reset_loss, reset_weights_iter, reset_weights = two_layer_training(model=reset_model, dataset=reset_dataset, n_steps=n_steps,
                                                                                        save_weights_every=save_weights_every)
        init_weights = [W1_task_init, W2_task_init]
        reset_equation_params = {"in_cov": input_corr,
                                 "out_cov": output_corr,
                                 "in_out_cov": input_output_corr,
                                 "init_weights": init_weights,
                                 "n_steps": n_steps,
                                 "reg_coef": reset_model_params["reg_coef"],
                                 "intrinsic_noise": reset_model_params["intrinsic_noise"],
                                 "learning_rate": reset_model_params["learning_rate"],
                                 "change_task_every": reset_dataset_params["change_tasks_every"],
                                 "time_constant": 1.0}

        solver = TaskSwitchLinearNetEq(**reset_equation_params)
        
        reset_W1_t, reset_W2_t = solver.get_weights(time_span, get_numpy=True)
        reset_Loss_t = solver.get_loss_function(reset_W1_t, reset_W2_t, get_numpy=True)
        
        sim_losses.append(reset_loss)
        eq_losses.append(reset_Loss_t)
        best_loss1_list.append(best_loss1)
        best_loss2_list.append(best_loss2)
        iter_list.append(reset_iters)
        
    return sim_losses, eq_losses, best_loss1_list, best_loss2_list, iter_list

In [None]:
sim_losses, eq_losses, best_loss1_list, best_loss2_list, iter_list = draw_trajectories()

In [None]:
s = figure(x_axis_label="iters", y_axis_label="Loss", width=800, height=500)
for i in range(len(sim_losses)):
    s.line(iter_list[i], sim_losses[i], alpha=0.3)
    s.line(iter_list[i], eq_losses[i], alpha=1.0, color=Category10[10][1])
s.line(iter_list[0], best_loss1, color=Category10[10][2])
s.line(iter_list[0], best_loss2, color=Category10[10][3])
show(s)