[RLlib] Fix issues with action masking examples #38095

Merged
merged 14 commits on Aug 9, 2023
4 changes: 2 additions & 2 deletions rllib/BUILD
@@ -2977,12 +2977,12 @@ py_test(
# --------------------------------------------------------------------

py_test(
name = "examples/action_masking_tf",
name = "examples/action_masking_tf2",
main = "examples/action_masking.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "small",
srcs = ["examples/action_masking.py"],
args = ["--stop-iter=2", "--framework=tf"]
args = ["--stop-iter=2", "--framework=tf2"]
)

py_test(
137 changes: 45 additions & 92 deletions rllib/examples/action_masking.py
@@ -41,31 +41,21 @@

from gymnasium.spaces import Box, Discrete
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.examples.env.action_mask_env import ActionMaskEnv
from ray.rllib.examples.models.action_mask_model import (
ActionMaskModel,
TorchActionMaskModel,
from ray.rllib.examples.rl_module.action_masking_rlm import (
TorchActionMaskRLM,
TFActionMaskRLM,
)
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

from ray.tune.logger import pretty_print


def get_cli_args():
"""Create CLI parser and return parsed arguments"""
parser = argparse.ArgumentParser()

# example-specific args
parser.add_argument(
"--no-masking",
action="store_true",
help="Do NOT mask invalid actions. This will likely lead to errors.",
)

# general args
parser.add_argument(
"--run", type=str, default="APPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
"--framework",
@@ -76,24 +66,6 @@ def get_cli_args():
parser.add_argument(
"--stop-iters", type=int, default=10, help="Number of iterations to train."
)
parser.add_argument(
"--stop-timesteps",
type=int,
default=10000,
help="Number of timesteps to train.",
)
parser.add_argument(
"--stop-reward",
type=float,
default=80.0,
help="Reward at which we stop training.",
)
parser.add_argument(
"--no-tune",
action="store_true",
help="Run without Tune using a manual train loop instead. Here,"
"there is no TensorBoard support.",
)
parser.add_argument(
"--local-mode",
action="store_true",
@@ -110,6 +82,15 @@

ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

if args.framework == "torch":
rlm_class = TorchActionMaskRLM
elif args.framework == "tf2":
rlm_class = TFActionMaskRLM
else:
raise ValueError(f"Unsupported framework: {args.framework}")

rlm_spec = SingleAgentRLModuleSpec(module_class=rlm_class)

# main part: configure the ActionMaskEnv and the action-masking RLModule
config = (
ppo.PPOConfig()
@@ -119,75 +100,47 @@
ActionMaskEnv,
env_config={
"action_space": Discrete(100),
# This is not going to be the observation space that our RLModule sees.
# It's only the configuration provided to the environment.
# The environment will instead create Dict observations with
# the keys "observations" and "action_mask".
"observation_space": Box(-1.0, 1.0, (5,)),
},
)
.training(
# the ActionMaskModel retrieves the invalid actions and avoids them
model={
"custom_model": ActionMaskModel
if args.framework != "torch"
else TorchActionMaskModel,
# disable action masking according to CLI
"custom_model_config": {"no_masking": args.no_masking},
},
)
# We need to disable preprocessing of observations, because preprocessing
# would flatten the observation dict of the environment.
.experimental(_disable_preprocessor_api=True)
.framework(args.framework)
.resources(
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))
)
.rl_module(rl_module_spec=rlm_spec)
)

stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}

# manual training loop (no Ray tune)
if args.no_tune:
if args.run not in {"APPO", "PPO"}:
raise ValueError("This example only supports APPO and PPO.")

algo = config.build()

# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = algo.train()
print(pretty_print(result))
# stop training if the target train steps or reward are reached
if (
result["timesteps_total"] >= args.stop_timesteps
or result["episode_reward_mean"] >= args.stop_reward
):
break

# manual test loop
print("Finished training. Running manual test/inference loop.")
# prepare environment with max 10 steps
config["env_config"]["max_episode_len"] = 10
env = ActionMaskEnv(config["env_config"])
obs, info = env.reset()
done = False
# run one iteration until done
print(f"ActionMaskEnv with {config['env_config']}")
while not done:
action = algo.compute_single_action(obs)
next_obs, reward, done, truncated, _ = env.step(action)
# observations contain original observations and the action mask
# reward is random and irrelevant here and therefore not printed
print(f"Obs: {obs}, Action: {action}")
obs = next_obs

# Run with tune for auto Algorithm creation, stopping, TensorBoard, etc.
else:
tuner = tune.Tuner(
args.run,
param_space=config.to_dict(),
run_config=air.RunConfig(stop=stop, verbose=2),
)
tuner.fit()
algo = config.build()

# run manual training loop and print results after each iteration
for _ in range(args.stop_iters):
result = algo.train()
print(pretty_print(result))

# manual test loop
print("Finished training. Running manual test/inference loop.")
# prepare environment with max 10 steps
config["env_config"]["max_episode_len"] = 10
env = ActionMaskEnv(config["env_config"])
obs, info = env.reset()
done = False
# run one iteration until done
print(f"ActionMaskEnv with {config['env_config']}")
while not done:
action = algo.compute_single_action(obs)
next_obs, reward, done, truncated, _ = env.step(action)
# observations contain original observations and the action mask
# reward is random and irrelevant here and therefore not printed
print(f"Obs: {obs}, Action: {action}")
obs = next_obs

print("Finished successfully without selecting invalid actions.")
ray.shutdown()
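For context, the env_config above makes ActionMaskEnv expose a Dict observation with the keys "observations" (the Box(-1.0, 1.0, (5,)) space configured above) and "action_mask" (one 0/1 entry per Discrete(100) action). A minimal, hypothetical environment with the same observation layout could look roughly like the sketch below; the class name, episode length, and the way the mask is sampled are illustrative only and are not taken from ActionMaskEnv's actual implementation.

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Dict, Discrete


class MinimalActionMaskEnv(gym.Env):
    """Illustrative stand-in for ActionMaskEnv (not RLlib's implementation)."""

    def __init__(self, config=None):
        config = config or {}
        self.max_episode_len = config.get("max_episode_len", 10)
        self.action_space = Discrete(100)
        # The action-masking RLModule expects a Dict observation:
        # the mask plus the original observation.
        self.observation_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": Box(-1.0, 1.0, shape=(5,)),
            }
        )
        self._t = 0

    def _obs(self):
        # Mark a random subset of actions as valid (1.0), the rest invalid (0.0).
        mask = np.zeros(self.action_space.n, dtype=np.float32)
        valid = np.random.choice(self.action_space.n, size=10, replace=False)
        mask[valid] = 1.0
        return {
            "action_mask": mask,
            "observations": self.observation_space["observations"].sample(),
        }

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._t = 0
        return self._obs(), {}

    def step(self, action):
        self._t += 1
        terminated = self._t >= self.max_episode_len
        # Reward is irrelevant for the masking demo.
        return self._obs(), 0.0, terminated, False, {}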
121 changes: 121 additions & 0 deletions rllib/examples/rl_module/action_masking_rlm.py
@@ -0,0 +1,121 @@
import gymnasium as gym

from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch, try_import_tf
from ray.rllib.utils.torch_utils import FLOAT_MIN

torch, nn = try_import_torch()
_, tf, _ = try_import_tf()


class ActionMaskRLMBase(RLModule):
    def __init__(self, config: RLModuleConfig):
        if not isinstance(config.observation_space, gym.spaces.Dict):
            raise ValueError(
                "This model requires the environment to provide a "
                "gym.spaces.Dict observation space."
            )
        # We need to adjust the observation space for this RL Module so that, when
        # building the default models, the RLModule does not "see" the action mask but
        # only the original observation space without the action mask. This tricks it
        # into building models that are compatible with the original observation space.
        config.observation_space = config.observation_space["observations"]

        # The PPORLModule, in its constructor, will build models for the modified
        # observation space.
        super().__init__(config)


class TorchActionMaskRLM(ActionMaskRLMBase, PPOTorchRLModule):
    def _forward_inference(self, batch, **kwargs):
        return mask_forward_fn_torch(super()._forward_inference, batch, **kwargs)

    def _forward_train(self, batch, *args, **kwargs):
        return mask_forward_fn_torch(super()._forward_train, batch, **kwargs)

    def _forward_exploration(self, batch, *args, **kwargs):
        return mask_forward_fn_torch(super()._forward_exploration, batch, **kwargs)


class TFActionMaskRLM(ActionMaskRLMBase, PPOTfRLModule):
    def _forward_inference(self, batch, **kwargs):
        return mask_forward_fn_tf(super()._forward_inference, batch, **kwargs)

    def _forward_train(self, batch, *args, **kwargs):
        return mask_forward_fn_tf(super()._forward_train, batch, **kwargs)

    def _forward_exploration(self, batch, *args, **kwargs):
        return mask_forward_fn_tf(super()._forward_exploration, batch, **kwargs)


def mask_forward_fn_torch(forward_fn, batch, **kwargs):
    _check_batch(batch)

    # Extract the available actions tensor from the observation.
    action_mask = batch[SampleBatch.OBS]["action_mask"]

    # Modify the incoming batch so that the default models can compute logits and
    # values as usual.
    batch[SampleBatch.OBS] = batch[SampleBatch.OBS]["observations"]

    outputs = forward_fn(batch, **kwargs)

    # Mask logits
    logits = outputs[SampleBatch.ACTION_DIST_INPUTS]
    # Convert action_mask into a [0.0 || -inf]-type mask.
    inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
    masked_logits = logits + inf_mask

    # Replace original values with masked values.
    outputs[SampleBatch.ACTION_DIST_INPUTS] = masked_logits

    return outputs


def mask_forward_fn_tf(forward_fn, batch, **kwargs):
    _check_batch(batch)

    # Extract the available actions tensor from the observation.
    action_mask = batch[SampleBatch.OBS]["action_mask"]

    # Modify the incoming batch so that the default models can compute logits and
    # values as usual.
    batch[SampleBatch.OBS] = batch[SampleBatch.OBS]["observations"]

    outputs = forward_fn(batch, **kwargs)

    # Mask logits
    logits = outputs[SampleBatch.ACTION_DIST_INPUTS]
    # Convert action_mask into a [0.0 || -inf]-type mask.
    inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
    masked_logits = logits + inf_mask

    # Replace original values with masked values.
    outputs[SampleBatch.ACTION_DIST_INPUTS] = masked_logits

    return outputs


def _check_batch(batch):
    """Check whether the batch contains the required keys."""
    if "action_mask" not in batch[SampleBatch.OBS]:
        raise ValueError(
            "Action mask not found in observation. This model requires "
            "the environment to provide observations that include an "
            "action mask (i.e. an observation space of the Dict space "
            "type that looks as follows: \n"
            "{'action_mask': Box(0.0, 1.0, shape=(self.action_space.n,)),"
            "'observations': <observation_space>}"
        )
if "observations" not in batch[SampleBatch.OBS]:
raise ValueError(
"Observations not found in observation.This model requires "
"the environment to provide observations that include a "
" (i.e. an observation space of the Dict space "
"type that looks as follows: \n"
"{'action_mask': Box(0.0, 1.0, shape=(self.action_space.n,)),"
"'observations': <observation_space>}"
)
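The masking itself relies on a log-of-the-mask trick, used identically in the torch and tf forward functions above: log(1.0) = 0.0 leaves the logits of valid actions unchanged, while log(0.0) = -inf is clamped to a very large negative float, so invalid actions end up with numerically zero probability after the softmax. A standalone sketch of the torch variant, using torch.finfo(torch.float32).min as an assumed stand-in for the FLOAT_MIN constant imported above:

import torch

# Assumed stand-in for the imported FLOAT_MIN constant (most negative float32).
FLOAT_MIN = torch.finfo(torch.float32).min

# Raw logits for 4 actions, and a mask that marks actions 1 and 3 as invalid.
logits = torch.tensor([1.0, 2.0, 0.5, -1.0])
action_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Same conversion as in mask_forward_fn_torch: valid -> 0.0, invalid -> FLOAT_MIN.
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
masked_logits = logits + inf_mask

probs = torch.softmax(masked_logits, dim=-1)
print(probs)  # entries 1 and 3 come out as 0.0; entries 0 and 2 sum to 1.0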
10 changes: 0 additions & 10 deletions rllib/models/tests/test_distributions.py
@@ -129,16 +129,6 @@ def test_categorical(self):
expected = (probs * (probs / probs2).log()).sum(dim=-1)
check(dist_with_probs.kl(dist2), expected)

# test temperature
dist_with_logits = TorchCategorical(logits=logits, temperature=1e-20)
samples = dist_with_logits.sample()
rsamples = dist_with_logits.rsample()
# expected is armax of logits
expected = logits.argmax(dim=-1)
check(samples, expected)
# rsample should be the same as sample, but one-hot encoded
check(samples, torch.argmax(rsamples, dim=-1))

def test_diag_gaussian(self):
batch_size = 128
ndim = 4