From c555874485d29b62a96a5af7f1f8635bd623a69b Mon Sep 17 00:00:00 2001 From: Avnish Date: Thu, 22 Jun 2023 16:18:37 -0700 Subject: [PATCH 1/3] [RLlib-contrib] Alpha Zero Signed-off-by: Avnish --- .buildkite/pipeline.ml.yml | 8 + rllib_contrib/alpha_zero/README.md | 16 + rllib_contrib/alpha_zero/pyproject.toml | 18 + rllib_contrib/alpha_zero/requirements.txt | 1 + .../rllib_alpha_zero/alpha_zero/__init__.py | 17 + .../rllib_alpha_zero/alpha_zero/alpha_zero.py | 406 ++++++++++++++++++ .../alpha_zero/alpha_zero_policy.py | 158 +++++++ .../alpha_zero/custom_torch_models.py | 116 +++++ .../src/rllib_alpha_zero/alpha_zero/mcts.py | 157 +++++++ .../alpha_zero/ranked_rewards.py | 78 ++++ .../alpha_zero/tests/test_alpha_zero.py | 44 ++ 11 files changed, 1019 insertions(+) create mode 100644 rllib_contrib/alpha_zero/README.md create mode 100644 rllib_contrib/alpha_zero/pyproject.toml create mode 100644 rllib_contrib/alpha_zero/requirements.txt create mode 100644 rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/__init__.py create mode 100644 rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero.py create mode 100644 rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero_policy.py create mode 100644 rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/custom_torch_models.py create mode 100644 rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/mcts.py create mode 100644 rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/ranked_rewards.py create mode 100644 rllib_contrib/alpha_zero/tests/test_alpha_zero.py diff --git a/.buildkite/pipeline.ml.yml b/.buildkite/pipeline.ml.yml index 788a8ced6207f..724681318b205 100644 --- a/.buildkite/pipeline.ml.yml +++ b/.buildkite/pipeline.ml.yml @@ -542,3 +542,11 @@ - (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e .) - ./ci/env/env_info.sh - pytest rllib_contrib/maml/tests/test_maml.py + +- label: ":exploding_death_star: RLlib Contrib: AlphaZero Tests" + conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"] + commands: + - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT + - (cd rllib_contrib/alpha_zero && pip install -r requirements.txt && pip install -e .) + - ./ci/env/env_info.sh + - pytest rllib_contrib/alpha_zero/tests/ \ No newline at end of file diff --git a/rllib_contrib/alpha_zero/README.md b/rllib_contrib/alpha_zero/README.md new file mode 100644 index 0000000000000..ee5e3f9675cda --- /dev/null +++ b/rllib_contrib/alpha_zero/README.md @@ -0,0 +1,16 @@ +# Alpha Zero + +[Alpha Zero](https://arxiv.org/abs/1712.01815) is a general reinforcement learning approach that achieved superhuman performance in the games of chess, shogi, and Go through tabula rasa learning from games of self-play, surpassing previous state-of-the-art programs that relied on handcrafted evaluation functions and domain-specific adaptations. 
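+
+In this package, AlphaZero combines MCTS-based action selection with a learned policy/value model trained from replayed self-play data. Below is a minimal configuration sketch, mirroring the compilation test in `tests/test_alpha_zero.py`; the `CartPoleSparseRewards` example environment and the bundled `DenseModel` stand in for your own environment and model.
+
+```
+from rllib_alpha_zero.alpha_zero import AlphaZeroConfig
+from rllib_alpha_zero.alpha_zero.custom_torch_models import DenseModel
+
+from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards
+
+# Configure AlphaZero on the sparse-rewards CartPole example environment.
+config = (
+    AlphaZeroConfig()
+    .framework("torch")
+    .environment(env=CartPoleSparseRewards)
+    .training(model={"custom_model": DenseModel})
+)
+
+# Build the algorithm and run one training iteration.
+algo = config.build()
+print(algo.train())
+```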
+ +## Installation + +``` +conda create -n rllib-alpha-zero python=3.10 +conda activate rllib-alpha-zero +pip install -r requirements.txt +pip install -e '.[development]' +``` + +## Usage + +[AlphaZero Example]() \ No newline at end of file diff --git a/rllib_contrib/alpha_zero/pyproject.toml b/rllib_contrib/alpha_zero/pyproject.toml new file mode 100644 index 0000000000000..7af25b100a427 --- /dev/null +++ b/rllib_contrib/alpha_zero/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[project] +name = "rllib-alpha-zero" +authors = [{name = "Anyscale Inc."}] +version = "0.1.0" +description = "" +readme = "README.md" +requires-python = ">=3.7, <3.11" +dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"] + +[project.optional-dependencies] +development = ["pytest>=7.2.2", "pre-commit==2.21.0", "torch==1.12.0"] diff --git a/rllib_contrib/alpha_zero/requirements.txt b/rllib_contrib/alpha_zero/requirements.txt new file mode 100644 index 0000000000000..e237d2a817495 --- /dev/null +++ b/rllib_contrib/alpha_zero/requirements.txt @@ -0,0 +1 @@ +torch==1.12.0 diff --git a/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/__init__.py b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/__init__.py new file mode 100644 index 0000000000000..2a4d9cd92e08b --- /dev/null +++ b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/__init__.py @@ -0,0 +1,17 @@ +from rllib_alpha_zero.alpha_zero.alpha_zero import ( + AlphaZero, + AlphaZeroConfig, + AlphaZeroDefaultCallbacks, +) +from rllib_alpha_zero.alpha_zero.alpha_zero_policy import AlphaZeroPolicy + +from ray.tune.registry import register_trainable + +__all__ = [ + "AlphaZeroConfig", + "AlphaZero", + "AlphaZeroDefaultCallbacks", + "AlphaZeroPolicy", +] + +register_trainable("rllib-contrib-alpha-zero", AlphaZero) diff --git a/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero.py b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero.py new file mode 100644 index 0000000000000..46835247e3433 --- /dev/null +++ b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero.py @@ -0,0 +1,406 @@ +import logging +from typing import List, Optional, Type, Union + +from rllib_alpha_zero.alpha_zero.alpha_zero_policy import AlphaZeroPolicy +from rllib_alpha_zero.alpha_zero.mcts import MCTS +from rllib_alpha_zero.alpha_zero.ranked_rewards import get_r2_env_wrapper + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.rllib.execution.rollout_ops import synchronous_parallel_sample +from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import restore_original_dimensions +from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import concat_samples +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import ( + NUM_AGENT_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED, + SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, +) +from ray.rllib.utils.replay_buffers.utils import validate_buffer_config +from 
ray.rllib.utils.typing import ResultDict
+
+torch, nn = try_import_torch()
+
+logger = logging.getLogger(__name__)
+
+
+class AlphaZeroDefaultCallbacks(DefaultCallbacks):
+    """AlphaZero callbacks.
+
+    If you use custom callbacks, you must extend this class and call super()
+    for on_episode_start.
+    """
+
+    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
+        # Save environment's state when an episode starts.
+        env = base_env.get_sub_environments()[0]
+        state = env.get_state()
+        episode.user_data["initial_state"] = state
+
+
+class AlphaZeroConfig(AlgorithmConfig):
+    """Defines a configuration class from which an AlphaZero Algorithm can be built.
+
+    Example:
+        >>> from rllib_alpha_zero.alpha_zero import AlphaZeroConfig
+        >>> config = AlphaZeroConfig()  # doctest: +SKIP
+        >>> config = config.training(sgd_minibatch_size=256)  # doctest: +SKIP
+        >>> config = config.resources(num_gpus=0)  # doctest: +SKIP
+        >>> config = config.rollouts(num_rollout_workers=4)  # doctest: +SKIP
+        >>> print(config.to_dict())  # doctest: +SKIP
+        >>> # Build an Algorithm object from the config and run 1 training iteration.
+        >>> algo = config.build(env="CartPole-v1")  # doctest: +SKIP
+        >>> algo.train()  # doctest: +SKIP
+
+    Example:
+        >>> from rllib_alpha_zero.alpha_zero import AlphaZeroConfig
+        >>> from ray import air
+        >>> from ray import tune
+        >>> config = AlphaZeroConfig()
+        >>> # Print out some default values.
+        >>> print(config.shuffle_sequences)  # doctest: +SKIP
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))  # doctest: +SKIP
+        >>> # Set the config object's env.
+        >>> config.environment(env="CartPole-v1")  # doctest: +SKIP
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.Tuner(  # doctest: +SKIP
+        ...     "AlphaZero",
+        ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
+        ...     param_space=config.to_dict(),
+        ... ).fit()
+    """
+
+    def __init__(self, algo_class=None):
+        """Initializes an AlphaZeroConfig instance."""
+        super().__init__(algo_class=algo_class or AlphaZero)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+        # AlphaZero specific config settings:
+        self.sgd_minibatch_size = 128
+        self.shuffle_sequences = True
+        self.num_sgd_iter = 30
+        self.replay_buffer_config = {
+            "type": "ReplayBuffer",
+            # Size of the replay buffer in batches (not timesteps!).
+            "capacity": 1000,
+            # Choosing `fragments` here makes it so that the buffer stores entire
+            # batches, instead of sequences, episodes or timesteps.
+            "storage_unit": "fragments",
+        }
+        # Number of timesteps to collect from rollout workers before we start
+        # sampling from replay buffers for learning. Whether we count this in agent
+        # steps or environment steps depends on config.multi_agent(count_steps_by=..).
+        self.num_steps_sampled_before_learning_starts = 1000
+        self.lr_schedule = None
+        self.vf_share_layers = False
+        self.mcts_config = {
+            "puct_coefficient": 1.0,
+            "num_simulations": 30,
+            "temperature": 1.5,
+            "dirichlet_epsilon": 0.25,
+            "dirichlet_noise": 0.03,
+            "argmax_tree_policy": False,
+            "add_dirichlet_noise": True,
+        }
+        self.ranked_rewards = {
+            "enable": True,
+            "percentile": 75,
+            "buffer_max_length": 1000,
+            # add rewards obtained from random policy to
+            # "warm start" the buffer
+            "initialize_buffer": True,
+            "num_init_rewards": 100,
+        }
+
+        # Override some of AlgorithmConfig's default values with AlphaZero-specific
+        # values.
+        self.framework_str = "torch"
+        self.callbacks_class = AlphaZeroDefaultCallbacks
+        self.lr = 5e-5
+        self.num_rollout_workers = 2
+        self.rollout_fragment_length = 200
+        self.train_batch_size = 4000
+        self.batch_mode = "complete_episodes"
+        # Extra configuration that disables exploration.
+        self.evaluation(evaluation_config={
+            "mcts_config": {
+                "argmax_tree_policy": True,
+                "add_dirichlet_noise": False,
+            },
+        })
+        self.exploration_config = {
+            # The Exploration class to use. In the simplest case, this is the name
+            # (str) of any class present in the `rllib.utils.exploration` package.
+            # You can also provide the python class directly or the full location
+            # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
+            # EpsilonGreedy").
+            "type": "StochasticSampling",
+            # Add constructor kwargs here (if any).
+        }
+        # __sphinx_doc_end__
+        # fmt: on
+
+        self.buffer_size = DEPRECATED_VALUE
+
+    @override(AlgorithmConfig)
+    def training(
+        self,
+        *,
+        sgd_minibatch_size: Optional[int] = NotProvided,
+        shuffle_sequences: Optional[bool] = NotProvided,
+        num_sgd_iter: Optional[int] = NotProvided,
+        replay_buffer_config: Optional[dict] = NotProvided,
+        lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
+        vf_share_layers: Optional[bool] = NotProvided,
+        mcts_config: Optional[dict] = NotProvided,
+        ranked_rewards: Optional[dict] = NotProvided,
+        num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
+        **kwargs,
+    ) -> "AlphaZeroConfig":
+        """Sets the training related configuration.
+
+        Args:
+            sgd_minibatch_size: Total SGD batch size across all devices for SGD.
+            shuffle_sequences: Whether to shuffle sequences in the batch when training
+                (recommended).
+            num_sgd_iter: Number of SGD iterations in each outer loop.
+            replay_buffer_config: Replay buffer config.
+                Examples:
+                {
+                "_enable_replay_buffer_api": True,
+                "type": "MultiAgentReplayBuffer",
+                "learning_starts": 1000,
+                "capacity": 50000,
+                "replay_sequence_length": 1,
+                }
+                - OR -
+                {
+                "_enable_replay_buffer_api": True,
+                "type": "MultiAgentPrioritizedReplayBuffer",
+                "capacity": 50000,
+                "prioritized_replay_alpha": 0.6,
+                "prioritized_replay_beta": 0.4,
+                "prioritized_replay_eps": 1e-6,
+                "replay_sequence_length": 1,
+                }
+                - Where -
+                prioritized_replay_alpha: Alpha parameter controls the degree of
+                prioritization in the buffer. In other words, when a buffer sample has
+                a higher temporal-difference error, with how much more probability
+                should it be drawn to update the parametrized Q-network. 0.0
+                corresponds to uniform probability. Setting this much above 1.0 may
+                quickly make the sampling distribution heavily “pointy”, with low
+                entropy.
+                prioritized_replay_beta: Beta parameter controls the degree of
+                importance sampling, which suppresses the influence of gradient updates
+                from samples that have a higher probability of being sampled via the
+                alpha parameter and the temporal-difference error.
+                prioritized_replay_eps: Epsilon parameter sets the baseline probability
+                for sampling so that when the temporal-difference error of a sample is
+                zero, there is still a chance of drawing the sample.
+            lr_schedule: Learning rate schedule. In the format of
+                [[timestep, lr-value], [timestep, lr-value], ...]
+                Intermediary timesteps will be assigned to interpolated learning rate
+                values. A schedule should normally start from timestep 0.
+            vf_share_layers: Share layers for value function. If you set this to True,
+                it's important to tune vf_loss_coeff.
+            mcts_config: MCTS specific settings.
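+                Supported keys (with their defaults set in
+                ``AlphaZeroConfig.__init__``): ``puct_coefficient``,
+                ``num_simulations``, ``temperature``, ``dirichlet_epsilon``,
+                ``dirichlet_noise``, ``argmax_tree_policy``, and
+                ``add_dirichlet_noise``.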
+ ranked_rewards: Settings for the ranked reward (r2) algorithm + from: https://arxiv.org/pdf/1807.01672.pdf + num_steps_sampled_before_learning_starts: Number of timesteps to collect + from rollout workers before we start sampling from replay buffers for + learning. Whether we count this in agent steps or environment steps + depends on config.multi_agent(count_steps_by=..). + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + + if sgd_minibatch_size is not NotProvided: + self.sgd_minibatch_size = sgd_minibatch_size + if shuffle_sequences is not NotProvided: + self.shuffle_sequences = shuffle_sequences + if num_sgd_iter is not NotProvided: + self.num_sgd_iter = num_sgd_iter + if replay_buffer_config is not NotProvided: + self.replay_buffer_config = replay_buffer_config + if lr_schedule is not NotProvided: + self.lr_schedule = lr_schedule + if vf_share_layers is not NotProvided: + self.vf_share_layers = vf_share_layers + if mcts_config is not NotProvided: + self.mcts_config = mcts_config + if ranked_rewards is not NotProvided: + self.ranked_rewards.update(ranked_rewards) + if num_steps_sampled_before_learning_starts is not NotProvided: + self.num_steps_sampled_before_learning_starts = ( + num_steps_sampled_before_learning_starts + ) + + return self + + @override(AlgorithmConfig) + def update_from_dict(self, config_dict) -> "AlphaZeroConfig": + config_dict = config_dict.copy() + + if "ranked_rewards" in config_dict: + value = config_dict.pop("ranked_rewards") + self.training(ranked_rewards=value) + + return super().update_from_dict(config_dict) + + @override(AlgorithmConfig) + def validate(self) -> None: + """Checks and updates the config based on settings.""" + # Call super's validation method. 
+        super().validate()
+        validate_buffer_config(self)
+
+
+def alpha_zero_loss(policy, model, dist_class, train_batch):
+    # get the unflattened inputs
+    input_dict = restore_original_dimensions(
+        train_batch["obs"], policy.observation_space, "torch"
+    )
+    # forward pass in model
+    model_out = model.forward(input_dict, None, [1])
+    logits, _ = model_out
+    values = model.value_function()
+    logits, values = torch.squeeze(logits), torch.squeeze(values)
+    priors = nn.Softmax(dim=-1)(logits)
+    # compute actor and critic losses
+    policy_loss = torch.mean(
+        -torch.sum(train_batch["mcts_policies"] * torch.log(priors), dim=-1)
+    )
+    value_loss = torch.mean(torch.pow(values - train_batch["value_label"], 2))
+    # compute total loss
+    total_loss = (policy_loss + value_loss) / 2
+    return total_loss, policy_loss, value_loss
+
+
+class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
+    def __init__(self, obs_space, action_space, config):
+        model = ModelCatalog.get_model_v2(
+            obs_space, action_space, action_space.n, config["model"], "torch"
+        )
+        _, env_creator = Algorithm._get_env_id_and_creator(config["env"], config)
+        if config["ranked_rewards"]["enable"]:
+            # if r2 is enabled, the env is wrapped to include a rewards buffer
+            # used to normalize rewards
+            env_cls = get_r2_env_wrapper(env_creator, config["ranked_rewards"])
+
+            # the wrapped env is used only in the mcts, not in the
+            # rollout workers
+            def _env_creator():
+                return env_cls(config["env_config"])
+
+        else:
+
+            def _env_creator():
+                return env_creator(config["env_config"])
+
+        def mcts_creator():
+            return MCTS(model, config["mcts_config"])
+
+        super().__init__(
+            obs_space,
+            action_space,
+            config,
+            model,
+            alpha_zero_loss,
+            TorchCategorical,
+            mcts_creator,
+            _env_creator,
+        )
+
+
+class AlphaZero(Algorithm):
+    @classmethod
+    @override(Algorithm)
+    def get_default_config(cls) -> AlgorithmConfig:
+        return AlphaZeroConfig()
+
+    @classmethod
+    @override(Algorithm)
+    def get_default_policy_class(
+        cls, config: AlgorithmConfig
+    ) -> Optional[Type[Policy]]:
+        return AlphaZeroPolicyWrapperClass
+
+    @override(Algorithm)
+    def training_step(self) -> ResultDict:
+        """Runs one AlphaZero training iteration (sample, store/replay, learn).
+
+        Returns:
+            The results dict from executing the training iteration.
+        """
+
+        # Sample n MultiAgentBatches from n workers.
+        with self._timers[SAMPLE_TIMER]:
+            new_sample_batches = synchronous_parallel_sample(
+                worker_set=self.workers, concat=False
+            )
+
+        for batch in new_sample_batches:
+            # Update sampling step counters.
+            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
+            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
+            # Store new samples in the replay buffer
+            if self.local_replay_buffer is not None:
+                self.local_replay_buffer.add(batch)
+
+        if self.local_replay_buffer is not None:
+            # Only start sampling from the buffer for learning once
+            # `num_steps_sampled_before_learning_starts` timesteps have been collected.
+            cur_ts = self._counters[
+                NUM_AGENT_STEPS_SAMPLED
+                if self.config.count_steps_by == "agent_steps"
+                else NUM_ENV_STEPS_SAMPLED
+            ]
+
+            if cur_ts > self.config.num_steps_sampled_before_learning_starts:
+                train_batch = self.local_replay_buffer.sample(
+                    self.config.train_batch_size
+                )
+            else:
+                train_batch = None
+        else:
+            train_batch = concat_samples(new_sample_batches)
+
+        # Learn on the training batch.
+ # Use simple optimizer (only for multi-agent or tf-eager; all other + # cases should use the multi-GPU optimizer, even if only using 1 GPU) + train_results = {} + if train_batch is not None: + if self.config.get("simple_optimizer") is True: + train_results = train_one_step(self, train_batch) + else: + train_results = multi_gpu_train_one_step(self, train_batch) + + # TODO: Move training steps counter update outside of `train_one_step()` method. + # # Update train step counters. + # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps() + # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() + + # Update weights and global_vars - after learning on the local worker - on all + # remote workers. + global_vars = { + "timestep": self._counters[NUM_ENV_STEPS_SAMPLED], + } + with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + self.workers.sync_weights(global_vars=global_vars) + + # Return all collected metrics for the iteration. + return train_results diff --git a/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero_policy.py b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero_policy.py new file mode 100644 index 0000000000000..a0f4c9cd63b7c --- /dev/null +++ b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/alpha_zero_policy.py @@ -0,0 +1,158 @@ +import numpy as np +from rllib_alpha_zero.alpha_zero.mcts import Node, RootParentNode + +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY + +torch, _ = try_import_torch() + + +class AlphaZeroPolicy(TorchPolicy): + def __init__( + self, + observation_space, + action_space, + config, + model, + loss, + action_distribution_class, + mcts_creator, + env_creator, + **kwargs + ): + super().__init__( + observation_space, + action_space, + config, + model=model, + loss=loss, + action_distribution_class=action_distribution_class, + ) + # we maintain an env copy in the policy that is used during mcts + # simulations + self.env_creator = env_creator + self.mcts = mcts_creator() + self.env = self.env_creator() + self.env.reset() + self.obs_space = observation_space + + @override(TorchPolicy) + def compute_actions( + self, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs + ): + + input_dict = {"obs": obs_batch} + if prev_action_batch is not None: + input_dict["prev_actions"] = prev_action_batch + if prev_reward_batch is not None: + input_dict["prev_rewards"] = prev_reward_batch + + return self.compute_actions_from_input_dict( + input_dict=input_dict, + episodes=episodes, + state_batches=state_batches, + ) + + @override(Policy) + def compute_actions_from_input_dict( + self, input_dict, explore=None, timestep=None, episodes=None, **kwargs + ): + with torch.no_grad(): + actions = [] + for i, episode in enumerate(episodes): + if episode.length == 0: + # if first time step of episode, get initial env state + env_state = episode.user_data["initial_state"] + # verify if env has been wrapped for ranked rewards + if self.env.__class__.__name__ == "RankedRewardsEnvWrapper": + # r2 env state contains also the rewards buffer state + env_state = {"env_state": env_state, "buffer_state": None} + # create tree root node + obs = self.env.set_state(env_state) + tree_node = Node( + state=env_state, + obs=obs, + 
reward=0, + done=False, + action=None, + parent=RootParentNode(env=self.env), + mcts=self.mcts, + ) + else: + # otherwise get last root node from previous time step + tree_node = episode.user_data["tree_node"] + + # run monte carlo simulations to compute the actions + # and record the tree + mcts_policy, action, tree_node = self.mcts.compute_action(tree_node) + # record action + actions.append(action) + # store new node + episode.user_data["tree_node"] = tree_node + + # store mcts policies vectors and current tree root node + if episode.length == 0: + episode.user_data["mcts_policies"] = [mcts_policy] + else: + episode.user_data["mcts_policies"].append(mcts_policy) + + return ( + np.array(actions), + [], + self.extra_action_out( + input_dict, kwargs.get("state_batches", []), self.model, None + ), + ) + + @override(Policy) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + # add mcts policies to sample batch + sample_batch["mcts_policies"] = np.array(episode.user_data["mcts_policies"])[ + sample_batch["t"] + ] + # final episode reward corresponds to the value (if not discounted) + # for all transitions in episode + final_reward = sample_batch["rewards"][-1] + # if r2 is enabled, then add the reward to the buffer and normalize it + if self.env.__class__.__name__ == "RankedRewardsEnvWrapper": + self.env.r2_buffer.add_reward(final_reward) + final_reward = self.env.r2_buffer.normalize(final_reward) + sample_batch["value_label"] = final_reward * np.ones_like(sample_batch["t"]) + return sample_batch + + @override(TorchPolicy) + def learn_on_batch(self, postprocessed_batch): + train_batch = self._lazy_tensor_dict(postprocessed_batch) + + loss_out, policy_loss, value_loss = self._loss( + self, self.model, self.dist_class, train_batch + ) + self._optimizers[0].zero_grad() + loss_out.backward() + + grad_process_info = self.extra_grad_process(self._optimizers[0], loss_out) + self._optimizers[0].step() + + grad_info = self.extra_grad_info(train_batch) + grad_info.update(grad_process_info) + grad_info.update( + { + "total_loss": loss_out.detach().cpu().numpy(), + "policy_loss": policy_loss.detach().cpu().numpy(), + "value_loss": value_loss.detach().cpu().numpy(), + } + ) + + return {LEARNER_STATS_KEY: grad_info} diff --git a/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/custom_torch_models.py b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/custom_torch_models.py new file mode 100644 index 0000000000000..9fc7d1037b69c --- /dev/null +++ b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/custom_torch_models.py @@ -0,0 +1,116 @@ +from abc import ABC + +import numpy as np + +from ray.rllib.models.modelv2 import restore_original_dimensions +from ray.rllib.models.preprocessors import get_preprocessor +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() + + +def convert_to_tensor(arr): + tensor = torch.from_numpy(np.asarray(arr)) + if tensor.dtype == torch.double: + tensor = tensor.float() + return tensor + + +class ActorCriticModel(TorchModelV2, nn.Module, ABC): + def __init__(self, obs_space, action_space, num_outputs, model_config, name): + TorchModelV2.__init__( + self, obs_space, action_space, num_outputs, model_config, name + ) + nn.Module.__init__(self) + + self.preprocessor = get_preprocessor(obs_space.original_space)( + obs_space.original_space + ) + + self.shared_layers = None + self.actor_layers = None + self.critic_layers = 
None + + self._value_out = None + + def forward(self, input_dict, state, seq_lens): + x = input_dict["obs"] + x = self.shared_layers(x) + # actor outputs + logits = self.actor_layers(x) + + # compute value + self._value_out = self.critic_layers(x) + return logits, None + + def value_function(self): + return self._value_out + + def compute_priors_and_value(self, obs): + obs = convert_to_tensor([self.preprocessor.transform(obs)]) + input_dict = restore_original_dimensions(obs, self.obs_space, "torch") + + with torch.no_grad(): + model_out = self.forward(input_dict, None, [1]) + logits, _ = model_out + value = self.value_function() + logits, value = torch.squeeze(logits), torch.squeeze(value) + priors = nn.Softmax(dim=-1)(logits) + + priors = priors.cpu().numpy() + value = value.cpu().numpy() + + return priors, value + + +class Flatten(nn.Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +class ConvNetModel(ActorCriticModel): + def __init__(self, obs_space, action_space, num_outputs, model_config, name): + ActorCriticModel.__init__( + self, obs_space, action_space, num_outputs, model_config, name + ) + + in_channels = model_config["custom_model_config"]["in_channels"] + feature_dim = model_config["custom_model_config"]["feature_dim"] + + self.shared_layers = nn.Sequential( + nn.Conv2d(in_channels, 32, kernel_size=4, stride=2), + nn.Conv2d(32, 64, kernel_size=2, stride=1), + nn.Conv2d(64, 64, kernel_size=2, stride=1), + Flatten(), + nn.Linear(1024, feature_dim), + ) + + self.actor_layers = nn.Sequential( + nn.Linear(in_features=feature_dim, out_features=action_space.n) + ) + + self.critic_layers = nn.Sequential( + nn.Linear(in_features=feature_dim, out_features=1) + ) + + self._value_out = None + + +class DenseModel(ActorCriticModel): + def __init__(self, obs_space, action_space, num_outputs, model_config, name): + ActorCriticModel.__init__( + self, obs_space, action_space, num_outputs, model_config, name + ) + + self.shared_layers = nn.Sequential( + nn.Linear( + in_features=obs_space.original_space["obs"].shape[0], out_features=256 + ), + nn.Linear(in_features=256, out_features=256), + ) + self.actor_layers = nn.Sequential( + nn.Linear(in_features=256, out_features=action_space.n) + ) + self.critic_layers = nn.Sequential(nn.Linear(in_features=256, out_features=1)) + self._value_out = None diff --git a/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/mcts.py b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/mcts.py new file mode 100644 index 0000000000000..72f9712bbf3a4 --- /dev/null +++ b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/mcts.py @@ -0,0 +1,157 @@ +""" +Mcts implementation modified from +https://github.com/brilee/python_uct/blob/master/numpy_impl.py +""" +import collections +import math + +import numpy as np + + +class Node: + def __init__(self, action, obs, done, reward, state, mcts, parent=None): + self.env = parent.env + self.action = action # Action used to go to this state + + self.is_expanded = False + self.parent = parent + self.children = {} + + self.action_space_size = self.env.action_space.n + self.child_total_value = np.zeros( + [self.action_space_size], dtype=np.float32 + ) # Q + self.child_priors = np.zeros([self.action_space_size], dtype=np.float32) # P + self.child_number_visits = np.zeros( + [self.action_space_size], dtype=np.float32 + ) # N + self.valid_actions = obs["action_mask"].astype(np.bool_) + + self.reward = reward + self.done = done + self.state = state + self.obs = obs + + self.mcts = mcts + + 
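+    # Per-child statistics (visit counts, total values, priors) are stored on
+    # each node in arrays indexed by action; a node's own statistics therefore
+    # live in its parent's arrays, which the properties below read and update.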
@property + def number_visits(self): + return self.parent.child_number_visits[self.action] + + @number_visits.setter + def number_visits(self, value): + self.parent.child_number_visits[self.action] = value + + @property + def total_value(self): + return self.parent.child_total_value[self.action] + + @total_value.setter + def total_value(self, value): + self.parent.child_total_value[self.action] = value + + def child_Q(self): + # TODO (weak todo) add "softmax" version of the Q-value + return self.child_total_value / (1 + self.child_number_visits) + + def child_U(self): + return ( + math.sqrt(self.number_visits) + * self.child_priors + / (1 + self.child_number_visits) + ) + + def best_action(self): + """ + :return: action + """ + child_score = self.child_Q() + self.mcts.c_puct * self.child_U() + masked_child_score = child_score + masked_child_score[~self.valid_actions] = -np.inf + return np.argmax(masked_child_score) + + def select(self): + current_node = self + while current_node.is_expanded: + best_action = current_node.best_action() + current_node = current_node.get_child(best_action) + return current_node + + def expand(self, child_priors): + self.is_expanded = True + self.child_priors = child_priors + + def get_child(self, action): + if action not in self.children: + self.env.set_state(self.state) + obs, reward, terminated, truncated, _ = self.env.step(action) + next_state = self.env.get_state() + self.children[action] = Node( + state=next_state, + action=action, + parent=self, + reward=reward, + done=terminated, + obs=obs, + mcts=self.mcts, + ) + return self.children[action] + + def backup(self, value): + current = self + while current.parent is not None: + current.number_visits += 1 + current.total_value += value + current = current.parent + + +class RootParentNode: + def __init__(self, env): + self.parent = None + self.child_total_value = collections.defaultdict(float) + self.child_number_visits = collections.defaultdict(float) + self.env = env + + +class MCTS: + def __init__(self, model, mcts_param): + self.model = model + self.temperature = mcts_param["temperature"] + self.dir_epsilon = mcts_param["dirichlet_epsilon"] + self.dir_noise = mcts_param["dirichlet_noise"] + self.num_sims = mcts_param["num_simulations"] + self.exploit = mcts_param["argmax_tree_policy"] + self.add_dirichlet_noise = mcts_param["add_dirichlet_noise"] + self.c_puct = mcts_param["puct_coefficient"] + + def compute_action(self, node): + for _ in range(self.num_sims): + leaf = node.select() + if leaf.done: + value = leaf.reward + else: + child_priors, value = self.model.compute_priors_and_value(leaf.obs) + if self.add_dirichlet_noise: + child_priors = (1 - self.dir_epsilon) * child_priors + child_priors += self.dir_epsilon * np.random.dirichlet( + [self.dir_noise] * child_priors.size + ) + + leaf.expand(child_priors) + leaf.backup(value) + + # Tree policy target (TPT) + tree_policy = node.child_number_visits / node.number_visits + tree_policy = tree_policy / np.max( + tree_policy + ) # to avoid overflows when computing softmax + tree_policy = np.power(tree_policy, self.temperature) + tree_policy = tree_policy / np.sum(tree_policy) + if self.exploit: + # if exploit then choose action that has the maximum + # tree policy probability + action = np.argmax(tree_policy) + else: + # otherwise sample an action according to tree policy probabilities + action = np.random.choice(np.arange(node.action_space_size), p=tree_policy) + return tree_policy, action, node.children[action] diff --git 
a/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/ranked_rewards.py b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/ranked_rewards.py new file mode 100644 index 0000000000000..198571a06d76e --- /dev/null +++ b/rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/ranked_rewards.py @@ -0,0 +1,78 @@ +from copy import deepcopy + +import numpy as np + + +class RankedRewardsBuffer: + def __init__(self, buffer_max_length, percentile): + self.buffer_max_length = buffer_max_length + self.percentile = percentile + self.buffer = [] + + def add_reward(self, reward): + if len(self.buffer) < self.buffer_max_length: + self.buffer.append(reward) + else: + self.buffer = self.buffer[1:] + [reward] + + def normalize(self, reward): + reward_threshold = np.percentile(self.buffer, self.percentile) + if reward < reward_threshold: + return -1.0 + else: + return 1.0 + + def get_state(self): + return np.array(self.buffer) + + def set_state(self, state): + if state is not None: + self.buffer = list(state) + + +def get_r2_env_wrapper(env_creator, r2_config): + class RankedRewardsEnvWrapper: + def __init__(self, env_config): + self.env = env_creator(env_config) + self.action_space = self.env.action_space + self.observation_space = self.env.observation_space + max_buffer_length = r2_config["buffer_max_length"] + percentile = r2_config["percentile"] + self.r2_buffer = RankedRewardsBuffer(max_buffer_length, percentile) + if r2_config["initialize_buffer"]: + self._initialize_buffer(r2_config["num_init_rewards"]) + + def _initialize_buffer(self, num_init_rewards=100): + # initialize buffer with random policy + for _ in range(num_init_rewards): + obs, info = self.env.reset() + terminated = truncated = False + while not terminated and not truncated: + mask = obs["action_mask"] + probs = mask / mask.sum() + action = np.random.choice(np.arange(mask.shape[0]), p=probs) + obs, reward, terminated, truncated, _ = self.env.step(action) + self.r2_buffer.add_reward(reward) + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + if terminated or truncated: + reward = self.r2_buffer.normalize(reward) + return obs, reward, terminated, truncated, info + + def get_state(self): + state = { + "env_state": self.env.get_state(), + "buffer_state": self.r2_buffer.get_state(), + } + return deepcopy(state) + + def reset(self, *, seed=None, options=None): + return self.env.reset() + + def set_state(self, state): + obs = self.env.set_state(state["env_state"]) + self.r2_buffer.set_state(state["buffer_state"]) + return obs + + return RankedRewardsEnvWrapper diff --git a/rllib_contrib/alpha_zero/tests/test_alpha_zero.py b/rllib_contrib/alpha_zero/tests/test_alpha_zero.py new file mode 100644 index 0000000000000..4579b44dd028f --- /dev/null +++ b/rllib_contrib/alpha_zero/tests/test_alpha_zero.py @@ -0,0 +1,44 @@ +import unittest + +import rllib_alpha_zero.alpha_zero as az +from rllib_alpha_zero.alpha_zero.custom_torch_models import DenseModel + +import ray +from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards +from ray.rllib.utils.test_utils import check_train_results, framework_iterator + + +class TestAlphaZero(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def test_alpha_zero_compilation(self): + """Test whether AlphaZero can be built with all frameworks.""" + config = ( + az.AlphaZeroConfig() + .environment(env=CartPoleSparseRewards) + 
.training(model={"custom_model": DenseModel}) + ) + num_iterations = 1 + + # Only working for torch right now. + for _ in framework_iterator(config, frameworks="torch"): + algo = config.build() + for i in range(num_iterations): + results = algo.train() + check_train_results(results) + print(results) + algo.stop() + + +if __name__ == "__main__": + import sys + + import pytest + + sys.exit(pytest.main(["-v", __file__])) From ef603f94ac90f7ed0f7a7ac1e3a9cad8fd253c9c Mon Sep 17 00:00:00 2001 From: Avnish Date: Thu, 6 Jul 2023 13:30:37 -0700 Subject: [PATCH 2/3] Add example Signed-off-by: Avnish --- .buildkite/pipeline.ml.yml | 1 + .../alpha_zero_cartpole_sparse_rewards.py | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py diff --git a/.buildkite/pipeline.ml.yml b/.buildkite/pipeline.ml.yml index 03d2ffd29ec0c..caaeb51c8a87f 100644 --- a/.buildkite/pipeline.ml.yml +++ b/.buildkite/pipeline.ml.yml @@ -571,6 +571,7 @@ - (cd rllib_contrib/alpha_zero && pip install -r requirements.txt && pip install -e .) - ./ci/env/env_info.sh - pytest rllib_contrib/alpha_zero/tests/ + - python alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py --run-as-test - label: ":exploding_death_star: RLlib Contrib: DDPG Tests" conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"] diff --git a/rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py b/rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py new file mode 100644 index 0000000000000..1b064591fcea8 --- /dev/null +++ b/rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py @@ -0,0 +1,62 @@ +import argparse + +from rllib_alpha_zero.alpha_zero import AlphaZero, AlphaZeroConfig +from rllib_alpha_zero.alpha_zero.custom_torch_models import DenseModel + +import ray +from ray import air, tune +from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards +from ray.rllib.utils.test_utils import check_learning_achieved + + +def get_cli_args(): + """Create CLI parser and return parsed arguments""" + parser = argparse.ArgumentParser() + parser.add_argument("--run-as-test", action="store_true", default=False) + args = parser.parse_args() + print(f"Running with following CLI args: {args}") + return args + + +if __name__ == "__main__": + args = get_cli_args() + + ray.init() + + config = ( + AlphaZeroConfig() + .rollouts(num_rollout_workers=6, rollout_fragment_length=50, ) + .framework("torch") + .environment(CartPoleSparseRewards) + .training(train_batch_size=500, + sgd_minibatch_size=64, + lr=1e-4, + num_sgd_iter=1, + mcts_config={"puct_coefficient": 1.5, + "num_simulations": 100, + "temperature": 1.0, + "dirichlet_epsilon": 0.20, + "dirichlet_noise": 0.03, + "argmax_tree_policy": False, + "add_dirichlet_noise": True,}, + ranked_rewards={"enable": True,}, + model={"custom_model": DenseModel,}) + ) + + stop_reward = 30. 
+ + tuner = tune.Tuner( + AlphaZero, + param_space=config.to_dict(), + run_config=air.RunConfig( + stop={ + "sampler_results/episode_reward_mean": stop_reward, + "timesteps_total": 100000, + }, + failure_config=air.FailureConfig(fail_fast="raise"), + ), + ) + results = tuner.fit() + + if args.run_as_test: + check_learning_achieved(results, stop_reward) \ No newline at end of file From 9b861a880eeb8532187b96e65b0b731f5fe3ebc4 Mon Sep 17 00:00:00 2001 From: Avnish Date: Thu, 6 Jul 2023 13:30:47 -0700 Subject: [PATCH 3/3] Lint Signed-off-by: Avnish --- .../alpha_zero_cartpole_sparse_rewards.py | 43 ++++++++++++------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py b/rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py index 1b064591fcea8..739effeec1781 100644 --- a/rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py +++ b/rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py @@ -25,25 +25,36 @@ def get_cli_args(): config = ( AlphaZeroConfig() - .rollouts(num_rollout_workers=6, rollout_fragment_length=50, ) + .rollouts( + num_rollout_workers=6, + rollout_fragment_length=50, + ) .framework("torch") .environment(CartPoleSparseRewards) - .training(train_batch_size=500, - sgd_minibatch_size=64, - lr=1e-4, - num_sgd_iter=1, - mcts_config={"puct_coefficient": 1.5, - "num_simulations": 100, - "temperature": 1.0, - "dirichlet_epsilon": 0.20, - "dirichlet_noise": 0.03, - "argmax_tree_policy": False, - "add_dirichlet_noise": True,}, - ranked_rewards={"enable": True,}, - model={"custom_model": DenseModel,}) + .training( + train_batch_size=500, + sgd_minibatch_size=64, + lr=1e-4, + num_sgd_iter=1, + mcts_config={ + "puct_coefficient": 1.5, + "num_simulations": 100, + "temperature": 1.0, + "dirichlet_epsilon": 0.20, + "dirichlet_noise": 0.03, + "argmax_tree_policy": False, + "add_dirichlet_noise": True, + }, + ranked_rewards={ + "enable": True, + }, + model={ + "custom_model": DenseModel, + }, + ) ) - stop_reward = 30. + stop_reward = 30.0 tuner = tune.Tuner( AlphaZero, @@ -59,4 +70,4 @@ def get_cli_args(): results = tuner.fit() if args.run_as_test: - check_learning_achieved(results, stop_reward) \ No newline at end of file + check_learning_achieved(results, stop_reward)