Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib-contrib] Alpha Zero #36736

Merged
merged 9 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions .buildkite/pipeline.ml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,26 @@
- pytest rllib_contrib/alpha_star/tests/
- python rllib_contrib/alpha_star/examples/multi-agent-cartpole-alpha-star.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: AlphaZero Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
  - (cd rllib_contrib/alpha_zero && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/alpha_zero/tests/
  - python rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: APEX DDPG Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/apex_ddpg && pip install -r requirements.txt && pip install -e ".[development"])
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/apex_ddpg && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/apex_ddpg/tests/
- python rllib_contrib/apex_ddpg/examples/apex_ddpg_pendulum_v1.py --run-as-test
Expand All @@ -500,7 +515,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/apex_dqn && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/apex_dqn && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/apex_dqn/tests/
- python rllib_contrib/apex_dqn/examples/apex_dqn_cartpole_v1.py --run-as-test
Expand All @@ -512,7 +527,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/bandit && pip install -r requirements.txt && pip install -e .)
- (cd rllib_contrib/bandit && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/bandit/tests/
- python rllib_contrib/bandit/examples/bandit_linucb_interest_evolution_recsim.py --run-as-test
Expand All @@ -524,7 +539,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/ddpg && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/ddpg && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/ddpg/tests/
- python rllib_contrib/ddpg/examples/ddpg_pendulum_v1.py --run-as-test
Expand All @@ -536,7 +551,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/leela_chess_zero && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/leela_chess_zero && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/leela_chess_zero/tests/
- python rllib_contrib/leela_chess_zero/examples/leela_chess_zero_connect_4.py --run-as-test
Expand All @@ -545,6 +560,7 @@
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT

# Install mujoco necessary for the testing environments
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
Expand All @@ -555,7 +571,8 @@
- mv mujoco210-linux-x86_64.tar.gz /root/.mujoco/.
- (cd /root/.mujoco && tar -xf /root/.mujoco/mujoco210-linux-x86_64.tar.gz)
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin
- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e ".[development"])

- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/maml/tests/test_maml.py

Expand All @@ -573,7 +590,7 @@
- echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin' >> /root/.bashrc
- source /root/.bashrc

- (cd rllib_contrib/mbmpo && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/mbmpo && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/mbmpo/tests/
- python rllib_contrib/mbmpo/examples/mbmpo_cartpole_v1_model_based.py --run-as-test
Expand All @@ -585,7 +602,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/pg && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/pg && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/pg/tests/
- python rllib_contrib/pg/examples/pg_cartpole_v1.py --run-as-test
Expand All @@ -597,7 +614,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/qmix && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/qmix && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/qmix/tests/
- python rllib_contrib/qmix/examples/qmix_two_step_game.py --run-as-test
Expand All @@ -609,13 +626,13 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/r2d2 && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/r2d2 && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/r2d2/tests/
- python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: SimpleQ Tests"
- (cd rllib_contrib/simple_q && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/simple_q && pip install -r requirements.txt && pip install -e ".[development]")
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
Expand All @@ -630,7 +647,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/slate_q && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/slate_q && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/slate_q/tests/
- python rllib_contrib/slate_q/examples/recommender_system_with_recsim_and_slateq.py --run-as-test
Expand All @@ -642,7 +659,7 @@
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/td3 && pip install -r requirements.txt && pip install -e ".[development"])
- (cd rllib_contrib/td3 && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/td3/tests/
- python rllib_contrib/td3/examples/td3_pendulum_v1.py --run-as-test
16 changes: 16 additions & 0 deletions rllib_contrib/alpha_zero/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Alpha Zero

[Alpha Zero](https://arxiv.org/abs/1712.01815) is a general reinforcement learning approach that achieved superhuman performance in the games of chess, shogi, and Go through tabula rasa learning from games of self-play, surpassing previous state-of-the-art programs that relied on handcrafted evaluation functions and domain-specific adaptations.

## Installation

```
conda create -n rllib-alpha-zero python=3.10
conda activate rllib-alpha-zero
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[AlphaZero Example](examples/alpha_zero_cartpole_sparse_rewards.py)
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse

from rllib_alpha_zero.alpha_zero import AlphaZero, AlphaZeroConfig
from rllib_alpha_zero.alpha_zero.custom_torch_models import DenseModel

import ray
from ray import air, tune
from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
    """Build the command-line parser, parse ``sys.argv``, and return the result.

    Recognized flags:
        --run-as-test: when set, the script asserts that the target reward
            was reached instead of merely running the experiment.
    """
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--run-as-test", default=False, action="store_true")
    parsed = cli_parser.parse_args()
    # Echo the parsed namespace so CI logs show exactly how the run was invoked.
    print(f"Running with following CLI args: {parsed}")
    return parsed


if __name__ == "__main__":
    args = get_cli_args()

    ray.init()

    # MCTS search hyperparameters for the AlphaZero tree policy.
    mcts_settings = {
        "puct_coefficient": 1.5,
        "num_simulations": 100,
        "temperature": 1.0,
        "dirichlet_epsilon": 0.20,
        "dirichlet_noise": 0.03,
        "argmax_tree_policy": False,
        "add_dirichlet_noise": True,
    }

    # AlphaZero on sparse-reward CartPole: 6 rollout workers feeding a
    # torch dense model, trained with small SGD minibatches.
    config = (
        AlphaZeroConfig()
        .rollouts(num_rollout_workers=6, rollout_fragment_length=50)
        .framework("torch")
        .environment(CartPoleSparseRewards)
        .training(
            train_batch_size=500,
            sgd_minibatch_size=64,
            lr=1e-4,
            num_sgd_iter=1,
            mcts_config=mcts_settings,
            ranked_rewards={"enable": True},
            model={"custom_model": DenseModel},
        )
    )

    # Stop once the mean episode reward reaches this value, or after the
    # timestep budget is exhausted — whichever comes first.
    stop_reward = 30.0
    stop_criteria = {
        "sampler_results/episode_reward_mean": stop_reward,
        "timesteps_total": 100000,
    }

    tuner = tune.Tuner(
        AlphaZero,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop_criteria,
            # Surface worker failures immediately instead of retrying.
            failure_config=air.FailureConfig(fail_fast="raise"),
        ),
    )
    results = tuner.fit()

    if args.run_as_test:
        check_learning_achieved(results, stop_reward)
18 changes: 18 additions & 0 deletions rllib_contrib/alpha_zero/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-alpha-zero"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "torch==1.12.0"]
1 change: 1 addition & 0 deletions rllib_contrib/alpha_zero/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
torch==1.12.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Public API of the AlphaZero RLlib-contrib package.

Re-exports the algorithm class, its configuration object, the default
callbacks, and the policy, and registers the algorithm with Ray Tune so it
can be referenced by the string name "rllib-contrib-alpha-zero".
"""
from rllib_alpha_zero.alpha_zero.alpha_zero import (
    AlphaZero,
    AlphaZeroConfig,
    AlphaZeroDefaultCallbacks,
)
from rllib_alpha_zero.alpha_zero.alpha_zero_policy import AlphaZeroPolicy

from ray.tune.registry import register_trainable

# Names exported via `from rllib_alpha_zero.alpha_zero import *`.
__all__ = [
    "AlphaZeroConfig",
    "AlphaZero",
    "AlphaZeroDefaultCallbacks",
    "AlphaZeroPolicy",
]

# Make the algorithm resolvable by name in tune.Tuner(...) / tune.run(...).
register_trainable("rllib-contrib-alpha-zero", AlphaZero)
Loading
Loading