In [None]:
# Install Poetry, clone CleanRL into /kaggle/working, and install its dependencies
# (plus the `atari` extra) with Poetry.
# NOTE(review): later cells run `pip` / `python` directly rather than `poetry run`,
# so the Poetry-managed environment may not be the one actually used — confirm.
!curl -sSL https://install.python-poetry.org | python3 -
%cd /kaggle/working/
!git clone https://github.com/vwxyzjn/cleanrl.git
%cd cleanrl
!/root/.local/bin/poetry install
!/root/.local/bin/poetry install -E atari

In [None]:
# Pin stable-baselines3 (for its Atari wrappers), Gymnasium with Atari extras, and ale-py
# to the versions the training and evaluation cells below were run with.
!pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1"  "ale-py==0.8.1"

In [None]:
# Rendering / video helpers: virtual display (xvfb + pyvirtualdisplay) and ffmpeg.
# stdout/stderr of the first two commands are discarded, so a failed install is silent here.
# NOTE(review): `python-opengl` may not exist on newer Ubuntu images (`python3-opengl`);
# `gym` and `colabgymrender` are not imported in any visible cell — confirm they are needed.
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install colabgymrender==1.0.2

In [None]:
# Start from an empty videos/ directory so recordings from earlier runs are not mixed in.
!rm -rf videos

In [None]:
# Train C51 (JAX) on KungFuMaster for 500k steps with a 10k replay buffer and seed 42,
# saving the model under runs/ and recording training videos.
!python cleanrl/c51_atari_jax.py --env-id ALE/KungFuMaster-v5 --buffer-size 10000 --save-model --capture-video --learning-start 1000  --total-timesteps 500000 --target-network-frequency 150 --seed 42

In [2]:
# Work from the CleanRL checkout so `cleanrl.*` imports and relative paths (videos/, file.zip) resolve.
%cd /kaggle/working/cleanrl

/kaggle/working/cleanrl


In [None]:
import random
from argparse import Namespace
from typing import Callable

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np

from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

def make_env(env_id, seed, idx, capture_video, run_name):
    """Build a zero-argument factory for one preprocessed Atari environment.

    The returned ``thunk`` is intended for ``gym.vector.SyncVectorEnv``. Only the
    sub-environment with ``idx == 0`` records video (when ``capture_video`` is
    truthy), and it records every episode into ``videos/<run_name>``.

    Args:
        env_id: Gymnasium environment id, e.g. ``"ALE/KungFuMaster-v5"``.
        seed: Seed applied to the wrapped environment's action space.
        idx: Index of this sub-environment inside the vector env.
        capture_video: Whether sub-environment 0 should record video.
        run_name: Folder name under ``videos/`` for the recordings.

    Returns:
        A callable with no arguments that creates and wraps the environment.
    """

    def thunk():
        if not (capture_video and idx == 0):
            env = gym.make(env_id)
        else:
            env = gym.wrappers.RecordVideo(
                gym.make(env_id, render_mode="rgb_array"),
                f"videos/{run_name}",
                episode_trigger=lambda episode_id: True,
            )

        # Wrapper order is significant; each entry wraps the result of the previous one.
        for wrap in (
            gym.wrappers.RecordEpisodeStatistics,
            lambda inner: NoopResetEnv(inner, noop_max=30),
            lambda inner: MaxAndSkipEnv(inner, skip=4),
            EpisodicLifeEnv,
        ):
            env = wrap(env)

        # Only add FireResetEnv for games whose action set includes FIRE.
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)

        # Reward clipping, then 84x84 grayscale observations stacked 4 deep.
        for wrap in (
            ClipRewardEnv,
            lambda inner: gym.wrappers.ResizeObservation(inner, (84, 84)),
            gym.wrappers.GrayScaleObservation,
            lambda inner: gym.wrappers.FrameStack(inner, 4),
        ):
            env = wrap(env)

        env.action_space.seed(seed)
        return env

    return thunk

def evaluate(
    model_path: str,
    make_env: Callable,
    env_id: str,
    eval_episodes: int,
    run_name: str,
    Model: nn.Module,
    epsilon: float = 0.05,
    capture_video: bool = True,
    seed=1,
    target_return=None,
):
    """Roll out a saved C51 checkpoint epsilon-greedily and collect episodic returns.

    Evaluation stops after ``eval_episodes`` completed episodes, or earlier as soon
    as one episode's return reaches ``target_return`` (when it is not ``None``).

    Args:
        model_path: Path to a checkpoint serialized with ``flax.serialization``; it must
            contain an ``"args"`` mapping (with ``n_atoms``, ``v_min``, ``v_max``) and
            ``"model_weights"``.
        make_env: Factory called as ``make_env(env_id, seed, idx, capture_video, run_name)``
            that returns a zero-argument environment constructor.
        env_id: Gymnasium environment id.
        eval_episodes: Maximum number of completed episodes to collect.
        run_name: Forwarded to ``make_env`` (e.g. to name the video folder).
        Model: Flax module class, constructed as ``Model(action_dim=..., n_atoms=...)``;
            its ``apply`` output is treated as per-action probabilities over the atoms.
        epsilon: Probability of taking a random action instead of the greedy one.
        capture_video: Forwarded to ``make_env``.
        seed: Seeds ``make_env``, the first ``reset`` and the epsilon-greedy RNG.
        target_return: Optional return threshold for stopping early; ``None`` disables it.

    Returns:
        List of ``info["episode"]["r"]`` values, one per completed episode, in order.
    """
    # Local RNG: exploration is reproducible and does not touch global ``random`` state.
    rng = random.Random(seed)
    envs = gym.vector.SyncVectorEnv([make_env(env_id, seed, 0, capture_video, run_name)])
    episodic_returns = []
    try:
        with open(model_path, "rb") as f:
            # With target ``None``, flax returns the raw restored state dict.
            model_data = flax.serialization.from_bytes(None, f.read())
        args = Namespace(**model_data["args"])
        model = Model(action_dim=envs.single_action_space.n, n_atoms=args.n_atoms)
        params = model_data["model_weights"]
        # Jit into a local instead of overwriting the module's ``apply`` attribute.
        apply_fn = jax.jit(model.apply)
        # Support of the return distribution; Q(s, a) = sum_i p_i(s, a) * atoms_i.
        atoms = jnp.asarray(np.linspace(args.v_min, args.v_max, num=args.n_atoms))

        obs, _ = envs.reset(seed=seed)
        reached_target = False
        while len(episodic_returns) < eval_episodes and not reached_target:
            if rng.random() < epsilon:
                actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
            else:
                pmfs = apply_fn(params, obs)
                q_vals = (pmfs * atoms).sum(axis=-1)
                actions = jax.device_get(q_vals.argmax(axis=-1))
            next_obs, _, _, _, infos = envs.step(actions)
            # Entries may be None or lack "episode"; only finished episodes are recorded.
            for info in infos.get("final_info", ()):
                if info is None or "episode" not in info:
                    continue
                episode_return = info["episode"]["r"]
                print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
                episodic_returns.append(episode_return)
                if target_return is not None and np.max(episode_return) >= target_return:
                    reached_target = True
            obs = next_obs
    finally:
        # Always release the environments; with RecordVideo this should also finalize the last clip.
        envs.close()

    return episodic_returns



from huggingface_hub import hf_hub_download  # NOTE(review): not used in any visible cell

from cleanrl.c51_atari_jax import QNetwork

ENV_ID = "ALE/KungFuMaster-v5"
# Checkpoint saved by one specific training run; the timestamp suffix in the run
# directory changes with every retrain, so update this path after re-training.
MODEL_PATH = (
    "/kaggle/working/cleanrl/runs/ALE/KungFuMaster-v5__c51_atari_jax__42__1688711491"
    "/c51_atari_jax.cleanrl_model"
)

episodic_returns = evaluate(
    MODEL_PATH,
    make_env,
    ENV_ID,
    eval_episodes=10,
    run_name="eval",
    Model=QNetwork,
    capture_video=True,
)
# Backward-compatible alias: later cells read the original (misspelled) name.
episodic_rerturn = episodic_returns

In [14]:
# Best episodic return among the evaluation episodes collected above.
max(episodic_rerturn)

array([11900.], dtype=float32)

In [18]:
# Remove the evaluation recordings.
# NOTE(review): this deletes the directory the zip cell archives; the execution counts
# show it ran after the zip cell — only run it once the archive has been downloaded.
!rm -rf /kaggle/working/cleanrl/videos/eval

In [11]:
# Rebuild the archive of evaluation videos. `-f` keeps the first run (no file.zip yet) from erroring.
# NOTE(review): `file.zip` below is relative to the CWD, which must be /kaggle/working/cleanrl
# for the `rm` to target the same file.
!rm -f /kaggle/working/cleanrl/file.zip
!zip -r file.zip /kaggle/working/cleanrl/videos/eval

  adding: kaggle/working/cleanrl/videos/eval/ (stored 0%)
  adding: kaggle/working/cleanrl/videos/eval/rl-video-episode-0.mp4 (deflated 11%)
  adding: kaggle/working/cleanrl/videos/eval/rl-video-episode-0.meta.json (deflated 13%)
  adding: kaggle/working/cleanrl/videos/eval/rl-video-episode-1.meta.json (deflated 10%)
  adding: kaggle/working/cleanrl/videos/eval/rl-video-episode-8.meta.json (deflated 8%)
  adding: kaggle/working/cleanrl/videos/eval/rl-video-episode-8.mp4 (deflated 10%)
  adding: kaggle/working/cleanrl/videos/eval/rl-video-episode-1.mp4 (deflated 10%)


In [6]:
from IPython.display import FileLink

# Show a clickable download link for the zipped evaluation videos.
archive_link = FileLink("file.zip")
archive_link