In [None]:
import itertools
import os
from pathlib import Path

import jax
import numpy as np

from coop_rl.configs.atari_mdqn import get_config
from coop_rl.environment import HandlerEnvAtari
from coop_rl.workers.collectors import get_select_action_fn

In [2]:
# Build the experiment configuration from `coop_rl.configs.atari_mdqn`; every later
# cell reads its environment, collector and checkpoint settings from `conf`.
conf = get_config()

In [None]:
# Ask the collector's environment class for the observation shape/dtype and the
# number of actions for this game + frame-stack size, then store them on `conf`.
# NOTE(review): presumably `state_recover` in the next cell needs these to rebuild
# the network with matching input/output sizes — confirm against the config module.
obs_shape, obs_dtype, n_actions = conf.args_collector.env.check_env(
    conf.args_env.env_name,
    conf.args_env.stack_size,
)
conf.observation_shape = obs_shape
conf.observation_dtype = obs_dtype
conf.num_actions = n_actions

In [4]:
# Restore trained parameters from a checkpoint and build a jitted-style action selector.
#
# The checkpoint path is resolved relative to the current user's home directory.
# (Previously an absolute "/home/sia/..." string was passed as the second argument
# to os.path.join, which makes os.path.join discard Path.home() entirely, so the
# notebook only worked on one machine.)
# NOTE(review): run id and train step identify one specific training run — edit as needed.
CHECKPOINT_DIR = os.path.join(
    Path.home(), "coop-rl_results", "s4gg8ou_", "chkpt_train_step_4580000"
)
SEED = 73  # fixed PRNG seed so evaluation rollouts are reproducible

# Kept as a str (os.path.join returns str), matching the type used previously.
conf.args_state_recover.checkpointdir = CHECKPOINT_DIR

rng = jax.random.PRNGKey(SEED)
# Consume a fresh subkey for restoring the state; `rng` is carried on to action selection.
rng, _rng = jax.random.split(rng)
flax_state = conf.state_recover(_rng, **conf.args_state_recover)
select_action = get_select_action_fn(flax_state.apply_fn)

In [None]:
# Roll out one episode with the restored policy, rendering to a window.
#
# The environment is created and closed in this same cell so the render window is
# released when the cell finishes; the try/finally guarantees env.close() also runs
# if the rollout raises or is interrupted (e.g. KeyboardInterrupt while watching).
# NOTE: for Atari environments under VS Code, keeping make/close in one cell has
# been observed not to be sufficient on its own.

env = HandlerEnvAtari(conf.args_env.env_name, stack_size=conf.args_env.stack_size, render_mode="human")

try:
    print(f"Action space shape: {env.action_space.shape}.")
    print(f"Observation space shape: {env.observation_space.shape}.")

    rewards = 0
    actions = []
    observation, info = env.reset()
    for step in itertools.count(start=1, step=1):
        # select_action is called with a leading batch axis of size 1.
        batched_observation = observation[None, :]
        rng, action_jnp = select_action(rng, flax_state.params, batched_observation)
        # Drop the batch axis so the environment receives a single action.
        action_np = np.asarray(action_jnp, dtype=np.int32).squeeze()
        observation, reward, terminated, truncated, info = env.step(action_np)
        rewards += reward
        actions.append(int(action_np))

        if terminated or truncated:
            print(f"Total steps: {step}.")
            print(f"Rewards: {rewards}")
            print(f"Actions number: {len(actions)}")
            print(f"Actions: {actions}")
            break
finally:
    env.close()