In [None]:
import itertools
import os
from pathlib import Path

import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint as ocp

from coop_rl import networks
from coop_rl.agents import dqn
from coop_rl.utils import HandlerEnvAtari

In [None]:
# Seed the PRNG once and immediately derive the key used for network init.
# (New-style typed-key alternative: jax.random.key(0).)
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

# Checkpoint written by a previous training run, located under the user's home directory.
checkpointdir = os.path.join(
    Path.home(),
    "coop-rl_results/jkebbemx/chkpt_step_1920000",
)
orbax_checkpointer = ocp.StandardCheckpointer()

In [None]:
# --- Environment -----------------------------------------------------------
env_name = "ALE/Breakout-v5"
stack_size = 4  # number of stacked frames; must be >= 1 (1 means no stacking)

# Query the environment for the shapes the network has to be built for.
obs_shape, observation_dtype, num_actions = HandlerEnvAtari.check_env(env_name, stack_size)

# --- Network ---------------------------------------------------------------
network = networks.NatureDQNNetwork
args_network = dict(num_actions=num_actions)

# --- Optimizer -------------------------------------------------------------
# NOTE(review): presumably these must match the training run so the restored
# optimizer state has the same tree structure - confirm against the trainer config.
optimizer = optax.adam
args_optimizer = dict(learning_rate=0.001, eps=3.125e-4)

In [None]:
# Build a freshly initialised train state; it serves only as a structural
# template (tree layout, shapes, dtypes) for restoring the checkpoint.
state = dqn.create_train_state(
    init_rng,
    network,
    args_network,
    optimizer,
    args_optimizer,
    obs_shape,
)

# Replace every leaf with its shape/dtype description to obtain the abstract
# target tree handed to the checkpointer.
abstract_my_tree = jax.tree.map(ocp.utils.to_shape_dtype_struct, state)

In [None]:
# Overwrite the template state with the values stored in the checkpoint,
# using the abstract tree to describe the expected structure.
restore_args = ocp.args.StandardRestore(abstract_my_tree)
state = orbax_checkpointer.restore(checkpointdir, args=restore_args)

In [None]:
# Sanity check: show the shape of every restored parameter array.
jax.tree.map(lambda leaf: leaf.shape, state.params)

In [None]:
# Play one greedy episode with the restored Q-network and render it.
#
# The environment is created and closed within this single cell. The
# try/finally guarantees that close() runs even when the rollout raises or the
# kernel is interrupted, so the emulator/render window is not leaked.
# Original author's note: for Atari environments in VS Code this does not help.

# Reuse env_name from the configuration cell so the environment being played is
# the same one the network's action count and observation shape were taken from.
env = HandlerEnvAtari(env_name, stack_size=stack_size, render_mode="human")

try:
    print(f"Action space shape: {env.action_space.shape}.")
    print(f"Observation space shape: {env.observation_space.shape}.")
    print(f"Reward range: {env.reward_range}.")

    rewards = 0  # undiscounted episode return
    actions = []  # every action taken, in order
    observation, info = env.reset()
    for step in itertools.count(start=1, step=1):
        # The network is applied to a batch, so add a leading batch axis.
        q_values = state.apply_fn({"params": state.params}, x=observation[None, :]).q_values
        # Greedy action; convert the JAX scalar to a plain Python int once so
        # the environment and the log both receive the same host-side value.
        action = int(jnp.argmax(q_values))
        observation, reward, terminated, truncated, info = env.step(action)
        rewards += reward
        actions.append(action)

        if terminated or truncated:
            print(f"Total steps: {step}.")
            print(f"Rewards: {rewards}")
            print(f"Actions number: {len(actions)}")
            print(f"Actions: {actions}")
            break
finally:
    env.close()