In [None]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp

from idqn.networks.architectures.base import FeatureNet
from idqn.utils.pickle import load_pickled_data, save_pickled_data
from flax.core import FrozenDict

# Checkpoint whose first linear layer is transplanted into a fresh FeatureNet.
# Alternative checkpoint (uses the "params"/"Dense_*" key layout instead):
# old_params = load_pickled_data("figures/N50000_t500_b100_lr1_5_D1/iFQI_bound/1_Q_s11_online_params")
old_params = load_pickled_data("old_figures/q35_N50000_b100_b20_f4_t1_lr4_5_t/iFQI/20_P_1_0-20")
old_first_layer = old_params["FullyConnectedNet/~/head_20_linear_0"]

# Initialise the network only to obtain a parameter tree with the expected
# structure; the Dense_0 entries are overwritten right below.
# Alternative width: feature_net = FeatureNet([35])
feature_net = FeatureNet([50])
init_key = jax.random.PRNGKey(0)
dummy_state = jnp.array([0.0, 0.0])
params_features = feature_net.init(init_key, dummy_state)

# Work on a mutable copy, then swap in the old bias ("b") and kernel ("w").
unfrozen_params_features = params_features.unfreeze()
dense_0 = unfrozen_params_features["params"]["Dense_0"]
dense_0["bias"] = old_first_layer["b"]
dense_0["kernel"] = old_first_layer["w"]
# With the alternative checkpoint above, use instead:
# dense_0["bias"] = old_params["params"]["Dense_0"]["bias"][0]
# dense_0["kernel"] = old_params["params"]["Dense_0"]["kernel"][0]
params_features = FrozenDict(unfrozen_params_features)

# Persist the transplanted feature parameters for the next cell.
save_pickled_data("figures/data/features", params_features)

## Test the params without the bias on the linear layer

In [None]:
from idqn.environments.car_on_hill import CarOnHillEnv
from idqn.networks.architectures.ifqi import CarOnHilliFQI

# The same value is handed to both the environment and the Q-network
# (presumably the discount factor -- TODO confirm against the constructors).
gamma = 0.95
env = CarOnHillEnv(gamma)

# Positional arguments are kept in their original order; the path is the one
# the feature parameters were pickled to in the first cell.
q = CarOnHilliFQI(
    2,
    "figures/data/features",
    (2,),
    env.n_actions,
    gamma,
    jax.random.PRNGKey(0),
    0,
    0,
)

# Transplant the old last-layer kernel: add a leading axis and repeat it twice
# along that axis, then write it over Dense_0's kernel. No bias is copied here.
old_last_kernel = old_params["FullyConnectedNet/~/head_20_linear_last"]["w"]
stacked_kernel = jnp.repeat(old_last_kernel[None], 2, axis=0)

unfrozen_params = q.params.unfreeze()
unfrozen_params["params"]["Dense_0"]["kernel"] = stacked_kernel
# With the alternative ("params"/"Dense_*") checkpoint layout, use instead:
# unfrozen_params["params"]["Dense_0"]["kernel"] = old_params["params"]["Dense_1"]["kernel"]
# unfrozen_params["params"]["Dense_0"]["bias"] = old_params["params"]["Dense_1"]["bias"]
q.params = FrozenDict(unfrozen_params)

In [None]:
import numpy as np

from experiments.car_on_hill.utils import TwoDimesionsMesh


# 17 x 17 grid covering the full position/velocity range of the environment.
states_x = np.linspace(-env.max_position, env.max_position, 17)
states_v = np.linspace(-env.max_velocity, env.max_velocity, 17)

# Keep only index 0 along each parameter's leading axis and re-add a leading
# axis of size 1. Computed once and reused for both calls below.
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias
# is deprecated and no longer available in recent JAX releases.
first_head_params = jax.tree_util.tree_map(lambda param: param[0][None], q.params)

q_diff_values = env.diff_q_estimate_mesh(q, first_head_params, states_x, states_v)
performance = env.evaluate(q, first_head_params, 100, np.array([-0.5, 0]))

q_visu_mesh = TwoDimesionsMesh(states_x, states_v, axis_equal=False, zero_centered=True)

title = r"$\pi^K_k$" + " for K = 20 at k = 20\n"
title += f"V([-0.5, 0]) = {np.around(performance, 2)}"

# Map the mesh to +1 where the value is strictly positive and -1 elsewhere,
# so the plot shows only the sign (presumably the greedy action -- confirm).
q_visu_mesh.set_values((2 * (q_diff_values > 0) - 1).astype(float))
q_visu_mesh.show(title, xlabel="x", ylabel="v", ticks_freq=3)