In [1]:
%load_ext autoreload
%autoreload 2
%cd /home/vincent/iDQN
import sys
import argparse
import multiprocessing
import json
import jax
import haiku as hk
import numpy as np

from experiments.base.parser import addparse
from experiments.base.print import print_info


import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

# Identify the saved experiment run to inspect in this notebook.
experiment_name = "q15_15_N5000_b500_b12_f500_t5_lr4"
bellman_iterations_scope = 12
seed = 1

# Print the run banner (experiment name, algorithm, environment, scope, seed).
print_info(experiment_name, "iFQI", "Car-On-Hill", bellman_iterations_scope, seed, train=False)

# Read the experiment's hyper-parameters. The context manager guarantees the
# file handle is closed as soon as the JSON is parsed, instead of leaking it
# for the lifetime of the kernel.
with open(f"experiments/car_on_hill/figures/{experiment_name}/parameters.json") as parameters_file:
    p = json.load(parameters_file)  # p for parameters

from experiments.car_on_hill.utils import define_environment, define_multi_q
from idqn.networks.learnable_multi_head_q import FullyConnectedMultiQ
from idqn.utils.params import load_params

# Build the environment and keep the two state grids it returns; the third
# and fifth return values are not used in this notebook.
gamma = p["gamma"]
env, states_x, _, states_v, _ = define_environment(gamma, p["n_states_x"], p["n_states_v"])

# Instantiate the multi-head Q-function with a fixed PRNG key.
# NOTE(review): the first argument is presumably the number of heads
# (one more than the Bellman-iteration scope) — confirm against define_multi_q.
q_init_key = jax.random.PRNGKey(0)
q = define_multi_q(bellman_iterations_scope + 1, gamma, q_init_key, p["layers_dimension"])

def evaluate(
    iteration: int,
    idx_head: int,
    v_list: list,
    q_estimate_list: list,
    q: FullyConnectedMultiQ,
    params: hk.Params,
    horizon: int,
    states_x: np.ndarray,
    states_v: np.ndarray,
):
    """Evaluate one head of ``q`` on the state mesh and store the results in place.

    Two quantities are computed through the module-level ``env``:
    ``env.v_mesh_multi_head`` and ``env.q_multi_head_estimate_mesh``.
    Each result is written into slot ``iteration`` of the matching list.

    Args:
        iteration: index of the slot to overwrite in both output lists.
        idx_head: head index forwarded to the two ``env`` mesh methods.
        v_list: list receiving the ``v_mesh_multi_head`` result (mutated).
        q_estimate_list: list receiving the ``q_multi_head_estimate_mesh``
            result (mutated).
        q: multi-head Q-function forwarded to ``env``.
        params: parameters of ``q``.
        horizon: forwarded to ``env.v_mesh_multi_head`` only.
        states_x: first state grid forwarded to ``env``.
        states_v: second state grid forwarded to ``env``.

    Returns:
        None. The outputs are delivered by mutating the two lists.
    """
    v_mesh = env.v_mesh_multi_head(q, idx_head, params, horizon, states_x, states_v)
    v_list[iteration] = v_mesh

    q_estimate_mesh = env.q_multi_head_estimate_mesh(q, idx_head, params, states_x, states_v)
    q_estimate_list[iteration] = q_estimate_mesh

# Total number of Bellman iterations over all epochs.
n_bellman_iterations_total = p["n_epochs"] * p["n_bellman_iterations_per_epoch"]

# NaN-filled placeholders, one array per slot (total iterations + 1 slots),
# so that slots which are never evaluated stay recognisable as NaN.
iterated_v = list(np.full((n_bellman_iterations_total + 1, p["n_states_x"], p["n_states_v"]), np.nan))
iterated_q_estimate = list(
    np.full(
        (n_bellman_iterations_total + 1, p["n_states_x"], p["n_states_v"], env.n_actions),
        np.nan,
    )
)

# Integer number of times the scope fits into the total number of iterations.
n_forward_moves = n_bellman_iterations_total // bellman_iterations_scope

# Location of the saved parameters for this scope / seed, split over two
# f-string pieces only to keep the line short (the resulting path is unchanged).
params_path = (
    f"experiments/car_on_hill/figures/{experiment_name}/iFQI/"
    f"{bellman_iterations_scope}_P_{seed}_0-{bellman_iterations_scope}"
)
params = load_params(params_path)

# Fill slot 0 of both result lists using head 0 of the loaded parameters.
evaluate(0, 0, iterated_v, iterated_q_estimate, q, params, p["horizon"], states_x, states_v)

/home/vincent/iDQN
-------- q15_15_N5000_b500_b12_f500_t5_lr4 --------
Evaluating iFQI on Car-On-Hill with 12 Bellman iterations at a time and seed 1...


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [16]:
import jax.numpy as jnp

n_x, n_v = states_x.shape[0], states_v.shape[0]
n_boxes = n_x * n_v
states_x_mesh, states_v_mesh = jnp.meshgrid(states_x, states_v, indexing="ij")

# Flatten both meshes into columns and pair them up: one (x, v) row per grid box.
x_column = states_x_mesh.reshape((n_boxes, 1))
v_column = states_v_mesh.reshape((n_boxes, 1))
states = jnp.hstack((x_column, v_column))

# Dangerous reshape: folding the flat batch back into (n_x, n_v, ...) is only
# valid because the meshgrid above uses indexing="ij"; with the default "xy"
# ordering the two grid axes would be swapped.
q_values_grid = q(params, states).reshape((n_x, n_v, q.n_heads, env.n_actions))

# Reference pattern kept from the original cell:
#   .reshape((states_x.shape[0], states_v.shape[0], q.n_heads, self.n_actions))[:, :, idx_head]
# NOTE(review): 130 is a hard-coded index on the heads axis, while q is built
# in the first cell from bellman_iterations_scope + 1 (= 13) — verify that it
# is a valid head index and selects the head you intend.
q_values_grid[:, :, 130]

Array([[[-3.03799057e+00, -2.20843601e+00],
        [-2.64409924e+00, -1.91496980e+00],
        [-2.25020790e+00, -1.62150300e+00],
        [-1.85631657e+00, -1.32803667e+00],
        [-1.46242499e+00, -1.03457010e+00],
        [-1.06853390e+00, -7.41103530e-01],
        [-6.74642444e-01, -4.47636962e-01],
        [-2.81115651e-01, -1.43847585e-01],
        [ 5.27800322e-02,  1.19759500e-01],
        [ 1.31547987e-01,  1.88595653e-01],
        [ 1.37206674e-01,  1.86157465e-01],
        [ 1.42865479e-01,  1.83719426e-01],
        [ 1.41329706e-01,  1.79503381e-01],
        [ 1.34275556e-01,  1.73923403e-01],
        [ 1.27220988e-01,  1.68343365e-01],
        [ 1.20166540e-01,  1.62763447e-01],
        [ 1.27728581e-01,  2.10425511e-01]],

       [[-2.33197331e+00, -1.61576784e+00],
        [-1.93808174e+00, -1.32230139e+00],
        [-1.54419041e+00, -1.02883482e+00],
        [-1.15029907e+00, -7.35368252e-01],
        [-7.56407738e-01, -4.41901743e-01],
        [-3.62516522e-01, -1.4