In [None]:
%load_ext autoreload
%autoreload 2

import json
import math
import jax

p = json.load(open("figures/try_optimal/parameters.json"))# p for parameters

from idqn.environments.car_on_hill import CarOnHillEnv
from idqn.sample_collection.replay_buffer import ReplayBuffer
from idqn.networks.architectures.ifqi import CarOnHilliFQI

q_key = jax.random.PRNGKey(0)

env = CarOnHillEnv(p["gamma"])

replay_buffer = ReplayBuffer(
    (2,),
    p["replay_buffer_size"],
    p["batch_size"],
    p["n_step_return"],
    p["gamma"],
    lambda x: x,
    stack_size=1,
    observation_dtype=float,
)

q = CarOnHilliFQI(
    2,
    "figures/data/features",
    (2,),
    env.n_actions,
    math.pow(p["gamma"], p["n_step_return"]),
    q_key,
    p["learning_rate"],
    p["optimizer_eps"],
)

env.collect_random_samples(
    jax.random.PRNGKey(0),
    replay_buffer,
    p["n_random_samples"],
    p["n_oriented_samples"],
    p["oriented_states"],
    p["horizon"],
)

dataset = replay_buffer.get_all_valid_samples()

In [None]:
from idqn.sample_collection import IDX_RB
from flax.core import FrozenDict
import numpy as np
import jax.numpy as jnp

def _fit_action_weights(features, targets):
    """Solve the least-squares problem min_w ||features.T @ w - targets||^2.

    Args:
        features: array of shape (n_features, n_samples).
        targets: array of shape (n_samples, 1).

    Returns:
        numpy array of shape (n_features, 1).

    Uses ``np.linalg.lstsq`` instead of ``inv(F @ F.T) @ F @ y``. Forming the
    explicit inverse of the Gram matrix squares its condition number, and it
    raises ``LinAlgError`` or returns garbage when ``F @ F.T`` is singular
    (collinear features, or an action with fewer samples than features).
    ``lstsq`` returns the minimum-norm solution in that case.
    """
    weights, *_ = np.linalg.lstsq(np.asarray(features).T, np.asarray(targets), rcond=None)
    return weights


def update_bellman_iteration(q, samples=None):
    """Refit the ``Dense_0`` kernel of ``q`` by per-action least squares on its own targets.

    Targets are computed with the current ``q.params``. Features come from the
    fixed ``q.network.feature_net``. One weight column is then fitted for
    action 0, and one for all other actions pooled together.

    Args:
        q: iFQI network; ``q.params`` is reassigned to a new FrozenDict (the old
            FrozenDict object is not mutated).
        samples: replay-buffer samples to regress on. Defaults to the notebook-level
            ``dataset`` so existing ``update_bellman_iteration(q)`` calls keep working.

    Returns:
        None.
    """
    if samples is None:
        samples = dataset

    # shape (n_samples, 1)
    targets = jax.vmap(q.compute_target, in_axes=(None, 0))(q.params, samples)
    # shape (n_features, n_samples)
    features = jax.vmap(
        lambda sample: q.network.feature_net.apply(
            q.network.feature_net.params, jnp.squeeze(sample[IDX_RB["state"]])
        ),
        out_axes=1,
    )(samples)

    is_action_0 = samples[IDX_RB["action"]] == 0
    # shape (n_features, 1) each; every non-zero action is pooled into the second column
    params_action_0 = _fit_action_weights(features[:, is_action_0], targets[is_action_0])
    params_action_1 = _fit_action_weights(features[:, ~is_action_0], targets[~is_action_0])

    # shape (n_features, 2): one column of weights per action
    new_params = jnp.hstack((params_action_0, params_action_1))

    unfrozen_params = q.params.unfreeze()
    # shape (2, n_features, 2): the same weights copied along the leading axis
    # (presumably one copy per head; CarOnHilliFQI is constructed with 2 above — confirm)
    unfrozen_params["params"]["Dense_0"]["kernel"] = jnp.repeat(new_params[None], 2, axis=0)
    q.params = FrozenDict(unfrozen_params)

In [None]:
import numpy as np

from experiments.car_on_hill.utils import TwoDimesionsMesh

# Number of Bellman iterations to run and visualise.
N_BELLMAN_ITERATIONS = 20
# Resolution of the (x, v) visualisation grid along each axis.
N_GRID_POINTS = 17
# Third positional argument of env.evaluate — presumably the rollout horizon; confirm.
N_EVALUATION_STEPS = 100
# State at which the policy value is reported; its repr also appears in the plot title.
EVALUATION_STATE = [-0.5, 0]

# The grid does not depend on the iteration, so build it once instead of every pass.
states_x = np.linspace(-env.max_position, env.max_position, N_GRID_POINTS)
states_v = np.linspace(-env.max_velocity, env.max_velocity, N_GRID_POINTS)

for bellman_iteration in range(N_BELLMAN_ITERATIONS):
    # Index 0 of every parameter leaf's leading axis, kept as a size-1 axis.
    # Computed once and shared by both calls below. `jax.tree_map` is deprecated
    # (removed in JAX 0.6); `jax.tree_util.tree_map` is the stable spelling.
    first_params = jax.tree_util.tree_map(lambda param: param[0][None], q.params)

    q_diff_values = env.diff_q_estimate_mesh(q, first_params, states_x, states_v)
    performance = env.evaluate(q, first_params, N_EVALUATION_STEPS, np.array(EVALUATION_STATE))

    q_visu_mesh = TwoDimesionsMesh(states_x, states_v, axis_equal=False, zero_centered=True)

    # ", " separates the symbol from the text; without it the title read "$\pi^*_k$k = ...".
    title = r"$\pi^*_k$" + f", k = {bellman_iteration}\n"
    title += f"V({EVALUATION_STATE}) = {np.around(performance, 2)}"

    # +1 where the Q-value difference is positive, -1 elsewhere.
    q_visu_mesh.set_values((2 * (q_diff_values > 0) - 1).astype(float))
    q_visu_mesh.show(title, xlabel="x", ylabel="v", ticks_freq=3)

    # update_bellman_iteration reassigns q.params, so old_params keeps the previous weights.
    old_params = q.params
    update_bellman_iteration(q)

    # Both losses use old_params as the second argument (presumably the target params);
    # a positive gain means the update reduced the loss.
    old_loss = q.loss_on_batch(old_params, old_params, dataset, None)
    new_loss = q.loss_on_batch(q.params, old_params, dataset, None)
    print(f"k = {bellman_iteration}, td gain: {old_loss - new_loss}")