# CleanRL's Huggingface Integration Demo



[<img src="https://img.shields.io/badge/license-MIT-blue">](https://github.com/vwxyzjn/cleanrl)
[![tests](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml/badge.svg)](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml)
[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://docs.cleanrl.dev/)
[<img src="https://img.shields.io/discord/767863440248143916?label=discord">](https://discord.gg/D6RCjA6sVT)
[<img src="https://img.shields.io/youtube/channel/views/UCDdC6BIFRI0jvcwuhi3aI6w?style=social">](https://www.youtube.com/channel/UCDdC6BIFRI0jvcwuhi3aI6w/videos)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-Huggingface-F8D521">](https://huggingface.co/cleanrl)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vwxyzjn/cleanrl/blob/master/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb)


CleanRL is a Deep Reinforcement Learning library that provides high-quality single-file implementations with research-friendly features. It now has 🧪 experimental support for saving and loading models from 🤗 HuggingFace's [Model Hub](https://huggingface.co/models). This notebook is a preliminary demo.


* 💾 [GitHub Repo](https://github.com/vwxyzjn/cleanrl)
* 📜 [Documentation](https://docs.cleanrl.dev/)
* 🤗 [HuggingFace Model Hub](https://huggingface.co/cleanrl)
* 🔗 [Open RL Benchmark reports](https://wandb.ai/openrlbenchmark/openrlbenchmark/reportlist)



## Get Started

CleanRL can be installed via `pip`. Suppose we are interested in pulling the model for [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py); we can install the algorithm-variant-specific dependencies as follows:

In [1]:
!pip install --upgrade "cleanrl[dqn-atari-jax]" # CAVEAT: the extra key is `dqn-atari-jax` with dashes instead of `dqn_atari_jax` with underscores

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting cleanrl[dqn-atari-jax]
  Downloading cleanrl-1.1.2-py3-none-any.whl (16.9 MB)
Collecting pygame==2.1.0
  Downloading pygame-2.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
Collecting huggingface-hub<0.12.0,>=0.11.1
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
Collecting wandb<0.14.0,>=0.13.6
  Downloading wandb-0.13.7-py2.py3-none-any.whl (1.9 MB)
Collecting stable-baselines3==1.2.0
  Downloading stable_baselines3-1.2.0-py3-none-any.whl (161 kB)
Collecting tensorboard<3.0.0,>=2.10.0
  Downloading tensorboard-2.1
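The extra key is simply the script name with underscores replaced by dashes, which is also how the install loop later in this notebook builds it. A tiny helper makes that mapping explicit (`cleanrl_extra` is a hypothetical name, not part of CleanRL):

```python
def cleanrl_extra(exp_name: str) -> str:
    """Return the pip requirement string for a CleanRL algorithm variant.

    Script names use underscores (e.g. `dqn_atari_jax.py`), while the
    pip extras use dashes (e.g. `cleanrl[dqn-atari-jax]`).
    """
    return f"cleanrl[{exp_name.replace('_', '-')}]"


print(cleanrl_extra("dqn_atari_jax"))  # cleanrl[dqn-atari-jax]
```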

## Enjoy Utility

We provide a simple way to load a model: the "enjoy" utility automatically pulls the model from 🤗 HuggingFace and runs it for a few episodes. It can also record a rendered video via the `--capture_video` flag. See more in our [📜 Documentation](https://docs.cleanrl.dev/get-started/zoo/).

In [2]:
!python -m cleanrl_utils.enjoy --exp-name dqn_atari_jax --env-id BreakoutNoFrameskip-v4 --eval-episodes 2 --capture_video

loading saved models from cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1...
Downloading: 100% 6.75M/6.75M [00:00<00:00, 62.6MB/s]
A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]
eval_episode=0, episodic_return=400.0
eval_episode=1, episodic_return=128.0
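The log line `loading saved models from cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1` shows the naming scheme for the Hub repositories: entity, environment ID, experiment name, and seed. The sketch below reproduces that scheme as it is spelled out in the "Enjoy!" loop later in this notebook, together with the weights filename passed to `hf_hub_download`; the helper names are hypothetical.

```python
def hf_repo_id(env_id: str, exp_name: str, seed: int = 1, hf_entity: str = "cleanrl") -> str:
    # e.g. cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1
    return f"{hf_entity}/{env_id}-{exp_name}-seed{seed}"


def model_filename(exp_name: str) -> str:
    # the weights file inside each repo is named after the experiment,
    # e.g. dqn_atari_jax.cleanrl_model
    return f"{exp_name}.cleanrl_model"


repo_id = hf_repo_id("BreakoutNoFrameskip-v4", "dqn_atari_jax")
filename = model_filename("dqn_atari_jax")
print(repo_id, filename)
# With network access, the file could then be fetched with:
# from huggingface_hub import hf_hub_download
# model_path = hf_hub_download(repo_id=repo_id, filename=filename)
```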


In [3]:
from IPython.display import Video
Video('videos/eval/rl-video-episode-0.mp4', embed=True)

## Diving Deeper

What happened above is a thin wrapper around [cleanrl_utils/evals/dqn_jax_eval.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/evals/dqn_jax_eval.py), which is short and gives you finer-grained control over, and direct access to, the model. Its content is roughly as follows: it downloads the model from https://huggingface.co/cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1 and runs an evaluation pass.
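The action rule inside `evaluate` is epsilon-greedy with `epsilon=0.05`: on each step a single coin flip decides whether every environment acts uniformly at random or greedily with respect to the Q-values. A NumPy-only sketch of that rule (the function name `epsilon_greedy` is hypothetical, not a CleanRL API):

```python
import numpy as np


def epsilon_greedy(q_values: np.ndarray, epsilon: float, rng: np.random.Generator) -> np.ndarray:
    """Pick one action per environment from a (num_envs, num_actions) array of Q-values.

    One coin flip per call: with probability `epsilon` all environments get a
    uniformly random action, otherwise each takes its highest-Q action.
    """
    num_envs, num_actions = q_values.shape
    if rng.random() < epsilon:
        return rng.integers(0, num_actions, size=num_envs)
    return q_values.argmax(axis=-1)


rng = np.random.default_rng(0)
q = np.array([[0.1, 0.9, 0.3], [1.2, -0.5, 0.0]])
print(epsilon_greedy(q, epsilon=0.0, rng=rng))  # greedy: [1 0]
```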

In [4]:
import random
from typing import Callable

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import numpy as np


def evaluate(
    model_path: str,
    make_env: Callable,
    env_id: str,
    eval_episodes: int,
    run_name: str,
    Model: nn.Module,
    epsilon: float = 0.05,
    capture_video: bool = True,
    seed=1,
):
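    envs = gym.vector.SyncVectorEnv(
        [make_env(env_id, 0, 0, capture_video, run_name)],
        autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
    )
    obs, _ = envs.reset()
    model = Model(action_dim=envs.single_action_space.n)
    q_key = jax.random.PRNGKey(seed)
    # build a parameter template, then overwrite it with the serialized weights
    params = model.init(q_key, obs)
    with open(model_path, "rb") as f:
        params = flax.serialization.from_bytes(params, f.read())
    apply_fn = jax.jit(model.apply)

    episodic_returns = []
    while len(episodic_returns) < eval_episodes:
        # epsilon-greedy: occasionally act randomly, otherwise take the greedy action
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = apply_fn(params, obs)
            actions = jax.device_get(q_values.argmax(axis=-1))
        next_obs, _, _, _, infos = envs.step(actions)
        # gymnasium vector envs return `infos` as a dict of per-env arrays, not a list of dicts.
        # With SAME_STEP autoreset, stats of a finished episode are under infos["final_info"],
        # and the boolean "_episode" mask marks which sub-envs actually finished.
        final_info = infos.get("final_info", {})
        if "episode" in final_info:
            for i in np.flatnonzero(final_info["_episode"]):
                episodic_return = float(final_info["episode"]["r"][i])
                print(f"eval_episode={len(episodic_returns)}, episodic_return={episodic_return}")
                episodic_returns.append(episodic_return)
        obs = next_obs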

    return episodic_returns


from huggingface_hub import hf_hub_download

from cleanrl.dqn_atari_jax import QNetwork, make_env

model_path = hf_hub_download(repo_id="cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1", filename="dqn_atari_jax.cleanrl_model")
evaluate(
    model_path,
    make_env,
    "BreakoutNoFrameskip-v4",
    eval_episodes=2,
    run_name="eval",
    Model=QNetwork,
    capture_video=False,
)



eval_episode=0, episodic_return=340.0
eval_episode=1, episodic_return=399.0


[340.0, 399.0]
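The returns come from episode statistics that, we assume, are recorded by gymnasium's `RecordEpisodeStatistics` wrapper inside `make_env`. In vectorized environments such statistics arrive as arrays with one entry per sub-environment plus a boolean mask (the `_episode` key) marking which entries are valid. A minimal, standalone sketch of that masking step, using a hand-built, hypothetical info dictionary instead of a real environment:

```python
import numpy as np


def finished_episode_returns(episode_info: dict, mask: np.ndarray) -> list[float]:
    """Collect returns of the sub-environments whose episode just ended.

    Assumes the gymnasium vector-env convention: `episode_info["r"]` holds one
    entry per sub-environment and the boolean `mask` marks the valid entries.
    """
    return [float(episode_info["r"][i]) for i in np.flatnonzero(mask)]


# Hypothetical step info for 3 sub-environments where only env 1 finished.
episode_info = {"r": np.array([0.0, 12.0, 0.0]), "l": np.array([0, 57, 0])}
mask = np.array([False, True, False])
print(finished_episode_returns(episode_info, mask))  # [12.0]
```

Reading values only where the mask is `True` matters because the unmasked slots hold placeholder zeros, which would otherwise be mistaken for real returns.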

## More Examples

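The same approach works for the other DQN variants. The cell below maps each experiment name to its `QNetwork`, `make_env`, and evaluation function, and lists the environments for which we evaluate the published seed-1 models.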

In [5]:
from dataclasses import dataclass

from huggingface_hub import hf_hub_download

from pip import main as pipmain

@dataclass
class Args:
    exp_name: str = "dqn_atari_jax"
    seed: int = 1
    hf_entity: str = "cleanrl"
    hf_repository: str = ""
    env_id: str = "BreakoutNoFrameskip-v4"


def dqn():
    import cleanrl.dqn
    import cleanrl_utils.evals.dqn_eval
    return cleanrl.dqn.QNetwork, cleanrl.dqn.make_env, cleanrl_utils.evals.dqn_eval.evaluate

def dqn_atari():
    import cleanrl.dqn_atari
    import cleanrl_utils.evals.dqn_eval
    return cleanrl.dqn_atari.QNetwork, cleanrl.dqn_atari.make_env, cleanrl_utils.evals.dqn_eval.evaluate

def dqn_jax():
    import cleanrl.dqn_jax
    import cleanrl_utils.evals.dqn_jax_eval
    return cleanrl.dqn_jax.QNetwork, cleanrl.dqn_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate

def dqn_atari_jax():
    import cleanrl.dqn_atari_jax
    import cleanrl_utils.evals.dqn_jax_eval
    return cleanrl.dqn_atari_jax.QNetwork, cleanrl.dqn_atari_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate

MODELS = {
    "dqn": dqn,
    "dqn_atari": dqn_atari,
    "dqn_jax": dqn_jax,
    "dqn_atari_jax": dqn_atari_jax,
}



exp_names = ["dqn", "dqn_jax", "dqn_atari_jax", "dqn_atari"]
env_idss = [
    [
        "CartPole-v1",
        "Acrobot-v1",
        "MountainCar-v0",
    ],
    [
        "CartPole-v1",
        "Acrobot-v1",
        "MountainCar-v0",
    ],
    [
        "BreakoutNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4"
    ],
    [
        "BreakoutNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4"
    ]
  ]
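The four `dqn*()` functions differ only in module names, and their imports sit inside the functions so that only the selected variant's dependencies (PyTorch or JAX) are loaded. A table-driven sketch of the same lazy loading with `importlib.import_module` (the names `EVAL_MODULES` and `load_variant` are hypothetical):

```python
import importlib

# Evaluation module per variant, matching the dqn()/dqn_atari()/dqn_jax()/dqn_atari_jax()
# functions above: PyTorch variants use dqn_eval, JAX variants use dqn_jax_eval.
EVAL_MODULES = {
    "dqn": "cleanrl_utils.evals.dqn_eval",
    "dqn_atari": "cleanrl_utils.evals.dqn_eval",
    "dqn_jax": "cleanrl_utils.evals.dqn_jax_eval",
    "dqn_atari_jax": "cleanrl_utils.evals.dqn_jax_eval",
}


def load_variant(exp_name: str):
    """Lazily import (QNetwork, make_env, evaluate) for one variant.

    Nothing is imported until a variant is requested, so only that
    variant's dependencies need to be installed.
    """
    algo = importlib.import_module(f"cleanrl.{exp_name}")
    evals = importlib.import_module(EVAL_MODULES[exp_name])
    return algo.QNetwork, algo.make_env, evals.evaluate
```

Usage mirrors `MODELS[...]()`: `Model, make_env, evaluate = load_variant("dqn_atari_jax")`. Adding a variant then means adding one dictionary entry rather than a new function.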


### Install dependencies for each variant

In [6]:
for exp_name in exp_names:
    # install dependencies for the algorithm variant; pip extras use dashes instead of underscores
    pip_args = ["install", "--upgrade", f'cleanrl[{exp_name.replace("_", "-")}]', "--quiet"]
    pipmain(pip_args)
    print("====", pip_args)



==== ['install', '--upgrade', 'cleanrl[dqn]', '--quiet']




==== ['install', '--upgrade', 'cleanrl[dqn-jax]', '--quiet']




==== ['install', '--upgrade', 'cleanrl[dqn-atari-jax]', '--quiet']




==== ['install', '--upgrade', 'cleanrl[dqn-atari]', '--quiet']
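pip recommends invoking it as `python -m pip` rather than importing its `main` function from within Python. A sketch of the same install loop done that way, via `subprocess` and the current interpreter (`pip_install_command` and `install_variant` are hypothetical helper names):

```python
import subprocess
import sys


def pip_install_command(exp_name: str) -> list[str]:
    # `sys.executable -m pip` targets the pip of the running interpreter/kernel
    extra = exp_name.replace("_", "-")
    return [sys.executable, "-m", "pip", "install", "--upgrade", f"cleanrl[{extra}]", "--quiet"]


def install_variant(exp_name: str) -> None:
    # raises CalledProcessError if pip exits with a non-zero status
    subprocess.check_call(pip_install_command(exp_name))


# for exp_name in ["dqn", "dqn_jax", "dqn_atari_jax", "dqn_atari"]:
#     install_variant(exp_name)
```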


### Enjoy!

In [7]:
for exp_name, env_ids in zip(exp_names, env_idss):
    for env_id in env_ids:
        args = Args(
            exp_name=exp_name,
            seed=1,
            hf_entity="cleanrl",
            hf_repository="",
            env_id=env_id,
        )
        Model, make_env, evaluate = MODELS[args.exp_name]()
        args.hf_repository = f"{args.hf_entity}/{args.env_id}-{args.exp_name}-seed{args.seed}"
        print("loading models from", args.hf_repository)
        model_path = hf_hub_download(repo_id=args.hf_repository, filename=f"{args.exp_name}.cleanrl_model")
        evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=2,
            run_name="eval",
            Model=Model,
            capture_video=False,
        )


loading models from cleanrl/CartPole-v1-dqn-seed1
eval_episode=0, episodic_return=500.0
eval_episode=1, episodic_return=500.0
loading models from cleanrl/Acrobot-v1-dqn-seed1
eval_episode=0, episodic_return=-125.0
eval_episode=1, episodic_return=-82.0
loading models from cleanrl/MountainCar-v0-dqn-seed1
eval_episode=0, episodic_return=-200.0
eval_episode=1, episodic_return=-200.0
loading models from cleanrl/CartPole-v1-dqn_jax-seed1
eval_episode=0, episodic_return=500.0
eval_episode=1, episodic_return=500.0
loading models from cleanrl/Acrobot-v1-dqn_jax-seed1
eval_episode=0, episodic_return=-94.0
eval_episode=1, episodic_return=-81.0
loading models from cleanrl/MountainCar-v0-dqn_jax-seed1
eval_episode=0, episodic_return=-161.0
eval_episode=1, episodic_return=-150.0
loading models from cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1
eval_episode=0, episodic_return=385.0
eval_episode=1, episodic_return=283.0
loading models from cleanrl/PongNoFrameskip-v4-dqn_atari_jax-seed1
eval_episode=0, episodic_return=18.0
eval_episode=1, episodic_return=19.0
loading models from cleanrl/BeamRiderNoFrameskip-v4-dqn_atari_jax-seed1
eval_episode=0, episodic_return=9228.0
eval_episode=1, episodic_return=5514.0
loading models from cleanrl/BreakoutNoFrameskip-v4-dqn_atari-seed1
eval_episode=0, episodic_return=79.0
eval_episode=1, episodic_return=288.0
loading models from cleanrl/PongNoFrameskip-v4-dqn_atari-seed1
eval_episode=0, episodic_return=16.0
eval_episode=1, episodic_return=19.0
loading models from cleanrl/BeamRiderNoFrameskip-v4-dqn_atari-seed1
eval_episode=0, episodic_return=12354.0
eval_episode=1, episodic_return=4740.0
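To compare runs at a glance, the printed log can be parsed and averaged per repository. The sketch below uses only the standard library and a few lines copied from the output above; `mean_returns` is a hypothetical helper, and with only two evaluation episodes per model these averages are rough indicators rather than benchmark numbers.

```python
import re
from statistics import mean

LOG = """\
loading models from cleanrl/CartPole-v1-dqn-seed1
eval_episode=0, episodic_return=500.0
eval_episode=1, episodic_return=500.0
loading models from cleanrl/Acrobot-v1-dqn-seed1
eval_episode=0, episodic_return=-125.0
eval_episode=1, episodic_return=-82.0
"""


def mean_returns(log: str) -> dict[str, float]:
    """Average episodic_return per Hugging Face repository in a printed evaluation log."""
    returns = {}
    repo = None
    for line in log.splitlines():
        if line.startswith("loading models from "):
            repo = line.removeprefix("loading models from ").strip()
            returns[repo] = []
        elif (m := re.search(r"episodic_return=(-?[\d.]+)", line)) and repo is not None:
            returns[repo].append(float(m.group(1)))
    return {name: mean(values) for name, values in returns.items() if values}


print(mean_returns(LOG))
# {'cleanrl/CartPole-v1-dqn-seed1': 500.0, 'cleanrl/Acrobot-v1-dqn-seed1': -103.5}
```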
