In [None]:
%load_ext autoreload
%autoreload 2

import json
import jax

from idqn.environments.car_on_hill import CarOnHillEnv
from idqn.sample_collection.replay_buffer import ReplayBuffer


# Load the experiment hyper-parameters. The context manager closes the file
# handle as soon as the JSON has been parsed instead of leaving it to the
# garbage collector.
with open("parameters.json", encoding="utf-8") as parameters_file:
    p = json.load(parameters_file)

env = CarOnHillEnv(p["gamma"])

# Arguments are passed positionally in the order ReplayBuffer expects.
# NOTE(review): `(2,)` is presumably the observation shape (position, velocity)
# and the identity lambda presumably a reward transform -- confirm against
# ReplayBuffer's signature.
replay_buffer = ReplayBuffer(
    (2,),
    p["replay_buffer_size"],
    p["batch_size"],
    p["n_step_return"],
    p["gamma"],
    lambda x: x,
    stack_size=1,
    observation_dtype=float,
)

# Fill the buffer; the fixed PRNG key makes the collection reproducible.
env.collect_random_samples(
    jax.random.PRNGKey(0),
    replay_buffer,
    p["n_random_samples"],
    p["n_oriented_samples"],
    p["oriented_states"],
    p["horizon"],
)

# Read the cursor once so that all three slices share the same upper bound.
# NOTE(review): this reads the buffer's private `_store` and assumes that
# `cursor()` equals the number of stored transitions. If `cursor()` is a write
# index that wraps around at capacity, a full buffer would be truncated here --
# confirm against ReplayBuffer.
n_filled = replay_buffer.cursor()
replay_buffer_states_x = replay_buffer._store["observation"][:n_filled, 0]
replay_buffer_states_v = replay_buffer._store["observation"][:n_filled, 1]
replay_buffer_rewards = replay_buffer._store["reward"][:n_filled]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from experiments.car_on_hill.utils import count_samples, TwoDimesionsMesh


# Grid points: evenly spaced values spanning the symmetric position and
# velocity ranges exposed by `env`.
states_x = np.linspace(-env.max_position, env.max_position, p["n_states_x"])
states_v = np.linspace(-env.max_velocity, env.max_velocity, p["n_states_v"])

# Box edges: one more edge than grid points, shifted by half a box width so
# that every grid point above lies at the centre of its box.
boxes_x_size = (2 * env.max_position) / (p["n_states_x"] - 1)
states_x_boxes = (
    np.linspace(-env.max_position, env.max_position + boxes_x_size, p["n_states_x"] + 1) - boxes_x_size / 2
)
boxes_v_size = (2 * env.max_velocity) / (p["n_states_v"] - 1)
states_v_boxes = (
    np.linspace(-env.max_velocity, env.max_velocity + boxes_v_size, p["n_states_v"] + 1) - boxes_v_size / 2
)

# The middle return value of `count_samples` is not used in this notebook.
samples_count, _, rewards_count = count_samples(
    replay_buffer_states_x,
    replay_buffer_states_v,
    states_x_boxes,
    states_v_boxes,
    replay_buffer_rewards,
)

samples_visu_mesh = TwoDimesionsMesh(states_x, states_v, sleeping_time=0, axis_equal=False, zero_centered=False)

# Replace empty boxes by NaN (presumably so the mesh leaves them blank).
# `np.where` with `np.nan` works whether `count_samples` returns integer or
# float counts; assigning `None` into an integer array raises a TypeError.
samples_count = np.where(samples_count == 0, np.nan, samples_count)

samples_visu_mesh.set_values(samples_count)
samples_visu_mesh.show("", xlabel="x", ylabel="v", ticks_freq=2)
# NOTE(review): `experiment_name` is not defined in the cells shown here; define
# it before re-enabling the export below.
# _ = plt.savefig(f"figures/{experiment_name}/sample_repartition.pdf", bbox_inches='tight')

In [None]:
# Replace boxes with a zero reward count by NaN (presumably so the mesh leaves
# them blank). Rebinding through `np.where` avoids mutating the array created
# in the previous cell in place, works for integer as well as float counts
# (assigning `None` into an integer array raises a TypeError), and stays
# idempotent when the cell is re-run because NaN never compares equal to 0.
rewards_count = np.where(rewards_count == 0, np.nan, rewards_count)

samples_visu_mesh.set_values(rewards_count)
samples_visu_mesh.show("", xlabel="x", ylabel="v", ticks_freq=2)
# NOTE(review): `experiment_name` is not defined in the cells shown here; define
# it before re-enabling the export below.
# _ = plt.savefig(f"figures/{experiment_name}/reward_repartition.pdf", bbox_inches='tight')