[RLlib] Cleanup examples folder: Add example restoring 1 of n agents from a checkpoint. (#45462)
simonsays1980 committed May 24, 2024
1 parent 7fb0ce1 commit 5cb7c09
Showing 2 changed files with 144 additions and 131 deletions.
18 changes: 9 additions & 9 deletions rllib/BUILD
@@ -2157,15 +2157,6 @@ py_test(
srcs = ["examples/checkpoints/onnx_torch.py"],
)

#@OldAPIStack
py_test(
name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"],
args = ["--pre-training-iters=1", "--stop-iters=1", "--num-cpus=4"]
)

# subdirectory: connectors/
# ....................................
# Framestacking examples only run in smoke-test mode (a few iters only).
@@ -2751,6 +2742,15 @@ py_test(
# args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"],
# )

py_test(
name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint",
main = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py",
tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core", "no_main"],
size = "large",
srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"],
args = ["--enable-new-api-stack", "--num-agents=2", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4", "--algo=PPO"]
)

py_test(
name = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned",
main = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py",
257 changes: 135 additions & 122 deletions rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
@@ -1,140 +1,153 @@
# TODO (sven): Move this example script into the new API stack.

"""Simple example of how to restore only one of n agents from a trained
multi-agent Algorithm using Ray tune.
Control the number of agents and policies via --num-agents and --num-policies.
"""An example script showing how to load module weights for 1 of n agents
from checkpoint.
This example:
- Runs a multi-agent `Pendulum-v1` experiment with >= 2 policies.
- Saves a checkpoint of the `MultiAgentRLModule` used, every `--checkpoint-freq`
iterations.
- Stops the first experiment after the agents reach a combined return of `-800`.
- Picks the best checkpoint by combined return and restores policy 0 from it.
- Runs a second experiment with the restored `RLModule` for policy 0 and
a fresh `RLModule` for the other policies.
- Stops the second experiment after the agents reach a combined return of `-800`.
How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --num-agents=2
--checkpoint-freq=20 --checkpoint-at-end`
Control the number of agents and policies (RLModules) via --num-agents and
--num-policies.
Control the number of checkpoints by setting `--checkpoint-freq` to a value > 0.
Note that the checkpoint frequency is per iteration and this example needs at
least a single checkpoint to load the RLModule weights for policy 0.
If `--checkpoint-at-end` is set, a checkpoint will be saved at the end of the
experiment.
For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.
For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
Results to expect
-----------------
You should expect a mean return of -400.0 to be reached eventually by a simple
single PPO policy (no tuning, just RLlib's default settings). In the second
run of the experiment, the MARL module weights for policy 0 are restored from
the best checkpoint of the first run. The return for a single agent should
again reach -400.0, but training should take noticeably fewer iterations
(around 30 instead of 190).
"""

import argparse
import gymnasium as gym
import os
import random

import ray
from ray import air, tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
run_rllib_example_script_experiment,
)
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.registry import get_trainable_cls, register_env

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()

parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--pre-training-iters", type=int, default=5)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
"--framework",
choices=["tf", "tf2", "torch"],
default="torch",
help="The DL framework specifier.",
)
parser.add_argument(
"--as-test",
action="store_true",
help="Whether this script should be run as a test: --stop-reward must "
"be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
"--stop-iters", type=int, default=200, help="Number of iterations to train."
)
parser.add_argument(
"--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
"--stop-reward", type=float, default=150.0, help="Reward at which we stop training."
parser = add_rllib_example_script_args(
default_iters=200,
default_timesteps=100000,
default_reward=-400.0,
)
# TODO (sven): This arg is currently ignored (hard-set to 2).
parser.add_argument("--num-policies", type=int, default=2)


if __name__ == "__main__":
args = parser.parse_args()

ray.init(num_cpus=args.num_cpus or None)

# Get obs- and action Spaces.
single_env = gym.make("CartPole-v1")
obs_space = single_env.observation_space
act_space = single_env.action_space

# Setup PPO with an ensemble of `num_policies` different policies.
policies = {
f"policy_{i}": (None, obs_space, act_space, None)
for i in range(args.num_policies)
}
policy_ids = list(policies.keys())

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
pol_id = random.choice(policy_ids)
return pol_id

config = (
PPOConfig()
.environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents})
.framework(args.framework)
.training(num_sgd_iter=10)
.multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
# Register our environment with tune.
if args.num_agents > 1:
register_env(
"env",
lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
)
else:
raise ValueError(
f"`num_agents` must be > 1, but is {args.num_agents}."
"Read the script docstring for more information."
)

assert args.checkpoint_freq > 0, (
"This example requires at least one checkpoint to load the RLModule "
"weights for policy 0."
)

# Do some training and store the checkpoint.
results = tune.Tuner(
"PPO",
param_space=config.to_dict(),
run_config=air.RunConfig(
stop={TRAINING_ITERATION: args.pre_training_iters},
verbose=1,
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=1, checkpoint_at_end=True
),
),
).fit()
print("Pre-training done.")

best_checkpoint = results.get_best_result().checkpoint
print(f".. best checkpoint was: {best_checkpoint}")

policy_0_checkpoint = os.path.join(
best_checkpoint.to_directory(), "policies/policy_0"
base_config = (
get_trainable_cls(args.algo)
.get_default_config()
.environment("env")
.training(
train_batch_size_per_learner=512,
mini_batch_size_per_learner=64,
lambda_=0.1,
gamma=0.95,
lr=0.0003,
vf_clip_param=10.0,
)
.rl_module(
model_config_dict={"fcnet_activation": "relu"},
)
)
restored_policy_0 = Policy.from_checkpoint(policy_0_checkpoint)
restored_policy_0_weights = restored_policy_0.get_weights()
print("Starting new tune.Tuner().fit()")

# Start our actual experiment.
stop = {
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
TRAINING_ITERATION: args.stop_iters,
# Add a simple multi-agent setup.
if args.num_agents > 0:
base_config.multi_agent(
policies={f"p{i}" for i in range(args.num_agents)},
policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
)

# Augment the base config with further settings and train the agents.
results = run_rllib_example_script_experiment(base_config, args)

# Create an env instance to get the observation and action spaces.
env = MultiAgentPendulum(config={"num_agents": args.num_agents})
# Get the default module spec from the algorithm config.
module_spec = base_config.get_default_rl_module_spec()
module_spec.model_config_dict = base_config.model_config | {
"fcnet_activation": "relu",
}

class RestoreWeightsCallback(DefaultCallbacks):
def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
algorithm.set_weights({"policy_0": restored_policy_0_weights})

# Make sure, the non-1st policies are not updated anymore.
config.policies_to_train = [pid for pid in policy_ids if pid != "policy_0"]
config.callbacks(RestoreWeightsCallback)

results = tune.run(
"PPO",
stop=stop,
config=config.to_dict(),
verbose=1,
module_spec.observation_space = env.envs[0].observation_space
module_spec.action_space = env.envs[0].action_space
# Create the module spec for each policy except policy 0.
module_specs = {}
for i in range(1, args.num_agents or 1):
module_specs[f"p{i}"] = module_spec

# Now swap in the RLModule weights for policy 0.
chkpt_path = results.get_best_result().checkpoint.path
p_0_module_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0")
module_spec.load_state_path = p_0_module_state_path
module_specs["p0"] = module_spec

# Create the MARL module.
marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
# Define the MARL module in the base config.
base_config.rl_module(rl_module_spec=marl_module_spec)
# We need to re-register the environment when starting a new run.
register_env(
"env",
lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
)
# Define stopping criteria.
stop = {
# TODO (simon): Change to -800 once the metrics are fixed. Currently
# the combined return is not correctly computed.
f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400,
f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000,
TRAINING_ITERATION: 30,
}

if args.as_test:
check_learning_achieved(results, args.stop_reward)

ray.shutdown()
# Run the experiment again with the restored MARL module.
run_rllib_example_script_experiment(base_config, args, stop=stop)
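For readers skimming the diff, the restore pattern the new script demonstrates can be condensed into one helper: build one RLModule spec per policy, point the spec for module "p0" at the saved module state inside the chosen checkpoint, and wrap all specs in a MultiAgentRLModuleSpec. The sketch below is a hedged restatement of the code above, not part of the commit; the helper name is invented, the model_config_dict override from the script is omitted for brevity, and the same `learner/module_state/<module_id>` checkpoint layout and `MultiAgentPendulum` env instance are assumed.

import os

from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec


def build_restored_marl_spec(base_config, env, checkpoint_path, num_agents=2):
    """Build a MultiAgentRLModuleSpec that restores weights only for "p0"."""
    module_specs = {}
    for i in range(num_agents):
        spec = base_config.get_default_rl_module_spec()
        spec.observation_space = env.envs[0].observation_space
        spec.action_space = env.envs[0].action_space
        module_specs[f"p{i}"] = spec
    # Only module "p0" loads pre-trained weights; the other modules start fresh.
    module_specs["p0"].load_state_path = os.path.join(
        checkpoint_path, "learner", "module_state", "p0"
    )
    return MultiAgentRLModuleSpec(module_specs=module_specs)

Used in place of the inline spec-building above, the second run would then be configured via `base_config.rl_module(rl_module_spec=build_restored_marl_spec(base_config, env, results.get_best_result().checkpoint.path))` before calling `run_rllib_example_script_experiment(base_config, args, stop=stop)` again.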
