# PBO optimal on LQR

## Define parameters

In [5]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import jax
import os
import json

# Load the experiment configuration. The context manager closes the file
# handle as soon as the JSON has been parsed (a bare `open(...)` passed to
# `json.load` is never closed explicitly).
with open("parameters.json") as parameters_file:
    parameters = json.load(parameters_file)
env_seed = parameters["env_seed"]

# Upper bound handed to the environment as `max_init_state`.
dummy_max_discrete_state = parameters["max_discrete_state"]

# Q function: settings forwarded to the Q-network constructor.
dummy_action_range_on_max = parameters["action_range_on_max"]
dummy_n_actions_on_max = parameters["n_actions_on_max"]

# Visualisation of errors and performances: the roll-out is evaluated for
# 5 iterations beyond the configured number of Bellman iterations.
max_bellman_iterations = parameters["max_bellman_iterations"]
max_bellman_iterations_validation = max_bellman_iterations + 5
# All-zero schedule: the PBO parameters are overwritten with the optimal
# values later in this notebook.
# NOTE(review): presumably no training happens here - confirm.
dummy_learning_rate = {"first": 0, "last": 0, "duration": 0}

# Keys: the Q-network and PBO keys are plain copies of the environment key.
env_key = jax.random.PRNGKey(env_seed)
dummy_q_network_key = env_key.copy()
dummy_pbo_network_key = env_key.copy()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Define environment

In [6]:
import numpy as np

from pbo.environment.linear_quadratic import LinearQuadraticEnv

# Build the linear-quadratic environment from the seeded key; the cell output
# below shows the transition and reward coefficients of the resulting system.
env = LinearQuadraticEnv(
    env_key,
    max_init_state=dummy_max_discrete_state,
)

Transition: s' = As + Ba
Transition: s' = -0.44256162643432617s + 0.37337517738342285a
Reward: Qs² + Ra² + 2 Ssa
Reward: -0.41404926776885986s² + -73.91294860839844a² + 1.6393332481384277sa


## Optimal PBO

In [7]:
from tqdm.notebook import tqdm

from pbo.networks.learnable_q import LQRQ
from pbo.networks.learnable_pbo import CustomLinearPBO

# Q-function for the 1-D state / 1-D action LQR problem, created with
# zero-initialised weights (zero_initializer=True).
q = LQRQ(
    state_dim=1,
    action_dim=1,
    n_actions_on_max=dummy_n_actions_on_max,
    action_range_on_max=dummy_action_range_on_max,
    network_key=dummy_q_network_key,
    random_weights_range=None,
    random_weights_key=None,
    learning_rate=None,
    zero_initializer=True,
)

# Linear PBO whose slope and bias are overwritten with the values exposed by
# the environment as optimal.
pbo_optimal = CustomLinearPBO(
    q,
    max_bellman_iterations,
    False,
    dummy_pbo_network_key,
    dummy_learning_rate,
)
pbo_optimal.params["CustomLinearPBONet"]["slope"] = env.optimal_slope.reshape((1, 3))
pbo_optimal.params["CustomLinearPBONet"]["bias"] = env.optimal_bias.reshape((1, 3))

# Starting point of the roll-out: the Q-network's current weights.
validation_initial_weight = q.to_weights(q.params)

# One row per Bellman iteration; row 0 holds the initial weights.
weights = np.zeros((max_bellman_iterations_validation + 1, q.weights_dimension))

# The PBO is applied to a batch, so the single weight vector gets a leading
# batch axis of size 1.
batch_iterated_weights = validation_initial_weight.reshape((1, -1))
for bellman_iteration in range(len(weights)):
    # Record (and show) the current iterate before advancing it.
    weights[bellman_iteration, :] = batch_iterated_weights[0]
    print(weights[bellman_iteration])

    batch_iterated_weights = pbo_optimal(pbo_optimal.params, batch_iterated_weights)

# Reference values to compare the last printed rows against.
print("Optimal weights", env.optimal_weights, sep="\n")

[0. 0. 0.]
[ -0.41404927   0.81966662 -73.91294861]
[ -0.49336496   0.88658273 -73.96940613]
[ -0.5085988   0.8994351 -73.9802475]
[ -0.51152205   0.9019013  -73.98233032]
[ -0.51208293   0.90237451 -73.98272705]
[ -0.51219052   0.90246528 -73.98280334]
[ -0.51221114   0.90248269 -73.9828186 ]
[ -0.51221514   0.90248603 -73.9828186 ]
[ -0.51221591   0.90248668 -73.9828186 ]
[ -0.51221603   0.9024868  -73.9828186 ]
[ -0.51221603   0.9024868  -73.9828186 ]
[ -0.51221603   0.9024868  -73.9828186 ]
Optimal weights
[ -0.5122161   0.9024868 -73.98282  ]


## Save data

In [8]:
# Persist the iterated weights (one row per Bellman iteration). The output
# directory is created first so the save also works on a fresh checkout,
# where np.save would otherwise raise FileNotFoundError.
output_directory = "figures/data/PBO_optimal"
os.makedirs(output_directory, exist_ok=True)
np.save(f"{output_directory}/{max_bellman_iterations}_W.npy", weights)