[RLlib-contrib] Alpha Zero. (#36736)
avnishn committed Oct 4, 2023
1 parent 0aaf579 commit 331c5b7
Showing 12 changed files with 1,114 additions and 13 deletions.
43 changes: 30 additions & 13 deletions .buildkite/pipeline.ml.yml
@@ -484,11 +484,26 @@
- pytest rllib_contrib/alpha_star/tests/
- python rllib_contrib/alpha_star/examples/multi-agent-cartpole-alpha-star.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: AlphaZero Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/alpha_zero && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/alpha_zero/tests/
- python rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: APEX DDPG Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/apex_ddpg && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/apex_ddpg/tests/
- python rllib_contrib/apex_ddpg/examples/apex_ddpg_pendulum_v1.py --run-as-test
@@ -500,7 +515,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/apex_dqn && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/apex_dqn/tests/
- python rllib_contrib/apex_dqn/examples/apex_dqn_cartpole_v1.py --run-as-test
@@ -512,7 +527,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/bandit && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/bandit/tests/
- python rllib_contrib/bandit/examples/bandit_linucb_interest_evolution_recsim.py --run-as-test
@@ -524,7 +539,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/ddpg && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/ddpg/tests/
- python rllib_contrib/ddpg/examples/ddpg_pendulum_v1.py --run-as-test
@@ -536,7 +551,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/leela_chess_zero && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/leela_chess_zero/tests/
- python rllib_contrib/leela_chess_zero/examples/leela_chess_zero_connect_4.py --run-as-test
@@ -545,6 +560,7 @@
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT

# Install mujoco necessary for the testing environments
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
@@ -555,7 +571,8 @@
- mv mujoco210-linux-x86_64.tar.gz /root/.mujoco/.
- (cd /root/.mujoco && tar -xf /root/.mujoco/mujoco210-linux-x86_64.tar.gz)
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin

- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/maml/tests/test_maml.py

@@ -573,7 +590,7 @@
- echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin' >> /root/.bashrc
- source /root/.bashrc

- (cd rllib_contrib/mbmpo && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/mbmpo/tests/
- python rllib_contrib/mbmpo/examples/mbmpo_cartpole_v1_model_based.py --run-as-test
@@ -585,7 +602,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/pg && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/pg/tests/
- python rllib_contrib/pg/examples/pg_cartpole_v1.py --run-as-test
@@ -597,7 +614,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/qmix && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/qmix/tests/
- python rllib_contrib/qmix/examples/qmix_two_step_game.py --run-as-test
@@ -609,13 +626,13 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/r2d2 && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/r2d2/tests/
- python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: SimpleQ Tests"
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/simple_q && pip install -r requirements.txt && pip install -e ".[development]")
@@ -630,7 +647,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/slate_q && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/slate_q/tests/
- python rllib_contrib/slate_q/examples/recommender_system_with_recsim_and_slateq.py --run-as-test
@@ -642,7 +659,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/td3 && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/td3/tests/
- python rllib_contrib/td3/examples/td3_pendulum_v1.py --run-as-test
16 changes: 16 additions & 0 deletions rllib_contrib/alpha_zero/README.md
@@ -0,0 +1,16 @@
# Alpha Zero

[Alpha Zero](https://arxiv.org/abs/1712.01815) is a general reinforcement learning approach that achieved superhuman performance in the games of chess, shogi, and Go through tabula rasa learning from games of self-play, surpassing previous state-of-the-art programs that relied on handcrafted evaluation functions and domain-specific adaptations.

## Installation

```bash
conda create -n rllib-alpha-zero python=3.10
conda activate rllib-alpha-zero
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[AlphaZero Example](examples/alpha_zero_cartpole_sparse_rewards.py)
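
A minimal training sketch, adapted from the CartPole sparse-rewards example added alongside this package (the hyperparameter values are copied from that example and are illustrative, not tuned recommendations):

```python
import ray
from ray import air, tune
from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards

from rllib_alpha_zero.alpha_zero import AlphaZero, AlphaZeroConfig
from rllib_alpha_zero.alpha_zero.custom_torch_models import DenseModel

ray.init()

# Configure AlphaZero on the sparse-reward CartPole environment.
config = (
    AlphaZeroConfig()
    .framework("torch")
    .environment(CartPoleSparseRewards)
    .rollouts(num_rollout_workers=6, rollout_fragment_length=50)
    .training(
        train_batch_size=500,
        sgd_minibatch_size=64,
        lr=1e-4,
        num_sgd_iter=1,
        mcts_config={
            "puct_coefficient": 1.5,
            "num_simulations": 100,
            "temperature": 1.0,
            "dirichlet_epsilon": 0.20,
            "dirichlet_noise": 0.03,
            "argmax_tree_policy": False,
            "add_dirichlet_noise": True,
        },
        ranked_rewards={"enable": True},
        model={"custom_model": DenseModel},
    )
)

# Stop once the mean episode reward reaches 30 or 100k timesteps elapse.
tune.Tuner(
    AlphaZero,
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop={
            "sampler_results/episode_reward_mean": 30.0,
            "timesteps_total": 100000,
        },
    ),
).fit()
```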
73 changes: 73 additions & 0 deletions rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py
@@ -0,0 +1,73 @@
import argparse

from rllib_alpha_zero.alpha_zero import AlphaZero, AlphaZeroConfig
from rllib_alpha_zero.alpha_zero.custom_torch_models import DenseModel

import ray
from ray import air, tune
from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
"""Create CLI parser and return parsed arguments"""
parser = argparse.ArgumentParser()
parser.add_argument("--run-as-test", action="store_true", default=False)
args = parser.parse_args()
print(f"Running with following CLI args: {args}")
return args


if __name__ == "__main__":
args = get_cli_args()

ray.init()

config = (
AlphaZeroConfig()
.rollouts(
num_rollout_workers=6,
rollout_fragment_length=50,
)
.framework("torch")
.environment(CartPoleSparseRewards)
.training(
train_batch_size=500,
sgd_minibatch_size=64,
lr=1e-4,
num_sgd_iter=1,
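# Monte Carlo tree search settings: exploration constant (PUCT), number of
# simulations per move, temperature, and Dirichlet exploration noise.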
mcts_config={
"puct_coefficient": 1.5,
"num_simulations": 100,
"temperature": 1.0,
"dirichlet_epsilon": 0.20,
"dirichlet_noise": 0.03,
"argmax_tree_policy": False,
"add_dirichlet_noise": True,
},
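# Ranked rewards (R2) compare each episode's return against recent returns,
# giving a binary win/loss-style signal for single-player self-play.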
ranked_rewards={
"enable": True,
},
model={
"custom_model": DenseModel,
},
)
)

stop_reward = 30.0

tuner = tune.Tuner(
AlphaZero,
param_space=config.to_dict(),
run_config=air.RunConfig(
stop={
"sampler_results/episode_reward_mean": stop_reward,
"timesteps_total": 100000,
},
failure_config=air.FailureConfig(fail_fast="raise"),
),
)
results = tuner.fit()

if args.run_as_test:
check_learning_achieved(results, stop_reward)
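When invoked with `--run-as-test`, the script calls `check_learning_achieved` to assert that the stop reward of 30.0 was actually reached; the CI pipeline above runs it as `python rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py --run-as-test`.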
18 changes: 18 additions & 0 deletions rllib_contrib/alpha_zero/pyproject.toml
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-alpha-zero"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "torch==1.12.0"]
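Note that the distribution name is `rllib-alpha-zero`, while the importable package (found under `src/`) is `rllib_alpha_zero`, which is why the example above imports from `rllib_alpha_zero.alpha_zero`.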
1 change: 1 addition & 0 deletions rllib_contrib/alpha_zero/requirements.txt
@@ -0,0 +1 @@
torch==1.12.0
17 changes: 17 additions & 0 deletions rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/__init__.py
@@ -0,0 +1,17 @@
from rllib_alpha_zero.alpha_zero.alpha_zero import (
AlphaZero,
AlphaZeroConfig,
AlphaZeroDefaultCallbacks,
)
from rllib_alpha_zero.alpha_zero.alpha_zero_policy import AlphaZeroPolicy

from ray.tune.registry import register_trainable

__all__ = [
"AlphaZeroConfig",
"AlphaZero",
"AlphaZeroDefaultCallbacks",
"AlphaZeroPolicy",
]

register_trainable("rllib-contrib-alpha-zero", AlphaZero)
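
The `register_trainable` call registers the algorithm under the name `"rllib-contrib-alpha-zero"`, so Tune users can pass that string to `tune.Tuner(...)` instead of importing the `AlphaZero` class directly; importing `rllib_alpha_zero.alpha_zero` is still required so that the registration side effect runs.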
