In [232]:
%load_ext autoreload
%autoreload 2

%cd /home/theo/PBO_project/PBO/
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

import json
import jax
import numpy as np
from experiments.lqr.utils import define_environment

experiment_name = "test"
max_bellman_iterations = 4

p = json.load(open(f"experiments/lqr/figures/{experiment_name}/parameters.json"))

env = define_environment(jax.random.PRNGKey(p["env_seed"]), p["max_discrete_state"])
g_mesh, i_mesh = np.meshgrid(np.linspace(-4 + env.optimal_weights[0], 4 + env.optimal_weights[0], 3), np.linspace(-4 + env.optimal_weights[1], 4 + env.optimal_weights[1], 4))
weights = np.stack((g_mesh, i_mesh), axis=-1).reshape((-1, 2))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/theo/PBO_project/PBO
Transition: s' = As + Ba
Transition: s' = -0.45554542541503906s + 0.5418910980224609a
Reward: Qs² + Ra² + 2 Ssa
Reward: -0.7250176668167114s² + -0.9326448440551758a² + -0.6272382736206055sa


## Compute PBOs iteration

In [233]:
from experiments.lqr.utils import define_q_vector_field
from pbo.networks.learnable_pbo import LinearPBO
from pbo.utils.params import load_params

q = define_q_vector_field(p["n_actions_on_max"], p["max_action_on_max"], env.optimal_weights[2], jax.random.PRNGKey(0))

pbo = LinearPBO(
    q=q,
    max_bellman_iterations=0,
    add_infinity=True,
    network_key=jax.random.PRNGKey(0),
    learning_rate={"first": 0, "last": 0, "duration": 0},
    initial_weight_std=p["initial_weight_std"],
)
pbo.params = load_params(
    f"experiments/lqr/figures/{experiment_name}/PBO_linear/{max_bellman_iterations}_P_0"
)

pbo_iterated_weights = pbo(pbo.params, weights)

FileNotFoundError: [Errno 2] No such file or directory: 'experiments/lqr/figures/test/PBO_linear/4_P_0'

## Compute optimal Bellman iteration

In [260]:
from experiments.lqr.utils import define_data_loader_samples

data_loader_samples = define_data_loader_samples(
    p["n_discrete_states"] * p["n_discrete_actions"], experiment_name, p["batch_size_samples"], jax.random.PRNGKey(0)
)
Z = np.zeros((p["n_discrete_states"] * p["n_discrete_actions"], 2))
Z[:, 0] = -np.array(data_loader_samples.replay_buffer.states.flatten())
Z[:, 1] = -2 * np.array(
    data_loader_samples.replay_buffer.states.flatten() * data_loader_samples.replay_buffer.actions.flatten()
)
invert_Z_Z = np.linalg.inv(Z.T @ Z)

J = np.zeros((p["n_discrete_states"] * p["n_discrete_actions"], 3))
J[:, 0] = np.array(data_loader_samples.replay_buffer.states.flatten()) ** 2
J[:, 1] = 2 * np.array(
    data_loader_samples.replay_buffer.states.flatten() * data_loader_samples.replay_buffer.actions.flatten()
)
J[:, 2] = np.array(data_loader_samples.replay_buffer.actions.flatten()) ** 2

D = np.array(
    [
        [env.A**2, -env.A**2 / env.optimal_weights[2]],
        [env.A * env.B, - env.A * env.B / env.optimal_weights[2]],
        [env.B**2, -env.B**2 / env.optimal_weights[2]],
    ]
)

E = np.array([env.Q, env.S, env.R - env.optimal_weights[2]])
projection = - invert_Z_Z @ Z.T @ J

In [261]:
projection

array([[ 2.77555756e-17, -2.77555756e-17, -4.16333634e-17],
       [ 3.46944695e-18,  1.00000000e+00,  0.00000000e+00]])

In [258]:
Z.T @ J

array([[ 0.00000000e+00,  0.00000000e+00, -1.13686838e-13],
       [ 0.00000000e+00, -1.98246411e+04,  0.00000000e+00]])

In [256]:
non_linear_weights = np.zeros_like(weights)
non_linear_weights[:, 0] = weights[:, 0]
non_linear_weights[:, 1] = weights[:, 1] ** 2

bellman_iterated_weights = np.zeros_like(weights)

for idx_weight, weight in enumerate(weights):
    bellman_iterated_weights[idx_weight] = projection @ (D @ non_linear_weights[idx_weight] + E)
    

print(bellman_iterated_weights)

[[-0.91310928 -0.08987599]
 [-0.91310927 -0.089876  ]
 [        nan         nan]
 [-0.91310928 -0.08987599]
 [-0.91310928 -0.08987599]
 [-0.91310928 -0.08987599]
 [-0.91310928 -0.08987599]
 [-0.91310928 -0.08987599]
 [-0.91310928 -0.08987599]
 [-0.91310928 -0.08987599]
 [-0.91310928 -0.08987599]
 [        nan         nan]]
