# PBO optimal on the chain walk environment

## Define parameters

In [6]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import jax
import os
import json

# Load the experiment configuration. The context manager closes the file
# handle as soon as the JSON has been parsed instead of leaving it open.
with open("parameters.json") as parameters_file:
    parameters = json.load(parameters_file)

# Environment settings read from the configuration.
n_states = parameters["n_states"]
n_actions = parameters["n_actions"]
# The spelling "sucess" matches the key stored in parameters.json.
sucess_probability = parameters["sucess_probability"]
gamma = parameters["gamma"]
env_seed = parameters["env_seed"]

# Visualisation of errors and performances
max_bellman_iterations = parameters["max_bellman_iterations"]
# All-zero learning-rate schedule; presumably a placeholder because nothing
# is trained in this notebook -- TODO confirm.
dummy_learning_rate = {"first": 0, "last": 0, "duration": 0}
# Validation runs 20 iterations past max_bellman_iterations.
max_bellman_iterations_validation = max_bellman_iterations + 20

# PRNG keys: the environment key is built from env_seed, and the two "dummy"
# network keys are copies of it.
env_key = jax.random.PRNGKey(env_seed)
dummy_q_network_key = env_key.copy()
dummy_pbo_network_key = env_key.copy()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Define environment

In [7]:
import numpy as np
from pbo.environment.chain_walk import ChainWalkEnv


# Integer index grids covering every state and every action.
states, actions = np.arange(n_states), np.arange(n_actions)

# The same grids extended by one entry and shifted by -0.5, i.e. the
# half-integer points surrounding each index (presumably bin edges used for
# plotting -- TODO confirm).
states_boxes = -0.5 + np.arange(n_states + 1)
actions_boxes = -0.5 + np.arange(n_actions + 1)

# Build the chain walk environment from the settings loaded in the
# configuration cell.
env = ChainWalkEnv(
    env_key,
    n_states,
    sucess_probability,
    gamma,
)

## Optimal PBO

In [8]:
from tqdm.notebook import tqdm
import jax.numpy as jnp

from pbo.sample_collection.dataloader import SampleDataLoader
from pbo.networks.learnable_q import TableQ
from pbo.networks.learnable_pbo import MaxLinearPBO


# Tabular Q-function built with zero_initializer=True; its parameters are
# converted to a weight vector that serves as the starting point of the
# iteration below.
q = TableQ(
    1, n_states, 1, n_actions, gamma, dummy_q_network_key, None, None, None, zero_initializer=True
)
validation_initial_weight = q.to_weights(q.params)

# PBO whose linear layer is overwritten by hand with quantities taken from
# the environment: w = gamma * transition_proba^T and b = R^T (presumably so
# that one application reproduces the optimal Bellman update -- TODO confirm).
pbo_optimal = MaxLinearPBO(
    q, max_bellman_iterations, dummy_pbo_network_key, dummy_learning_rate, n_actions, 0, 0
)
linear_layer = pbo_optimal.params["MaxLinearPBONet/linear"]
linear_layer["w"] = gamma * env.transition_proba.T
linear_layer["b"] = env.R.T

# One slot per iteration, including iteration 0 (the initial weights).
n_validation_steps = max_bellman_iterations_validation + 1
q_functions = np.zeros((n_validation_steps, n_states, n_actions))
bellman_iteration_functions = np.zeros_like(q_functions)
v_functions = np.zeros((n_validation_steps, n_states))


# Each pass records the current Q-table, the result of
# env.apply_bellman_operator on it, and env.value_function of its greedy
# policy (argmax over axis 1), then advances the weights by one application
# of pbo_optimal.
batch_iterated_weights = validation_initial_weight.reshape(1, -1)
for bellman_iteration in range(n_validation_steps):
    q_i = q.discretize(batch_iterated_weights, states, actions)[0]
    policy_q = q_i.argmax(axis=1)

    q_functions[bellman_iteration] = q_i
    bellman_iteration_functions[bellman_iteration] = env.apply_bellman_operator(q_i)
    v_functions[bellman_iteration] = env.value_function(policy_q)
    # Show the greedy policy reached at this iteration.
    print(policy_q)

    batch_iterated_weights = pbo_optimal(pbo_optimal.params, batch_iterated_weights)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0]
[0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1

## Save data

In [9]:
# Directory that receives the arrays computed by the optimal PBO cell.
output_dir = "figures/data/PBO_optimal/"
# exist_ok=True creates the directory tree when it is missing and does
# nothing when it already exists, which removes the separate (and racy)
# existence check.
os.makedirs(output_dir, exist_ok=True)

# One .npy file per array, each file name prefixed with max_bellman_iterations.
np.save(f"{output_dir}{max_bellman_iterations}_Q.npy", q_functions)
np.save(f"{output_dir}{max_bellman_iterations}_BI.npy", bellman_iteration_functions)
np.save(f"{output_dir}{max_bellman_iterations}_V.npy", v_functions)