Commit
[RLlib] Cleanup examples folder: Add example restoring 1 of n agents from a checkpoint. (#45462)
1 parent 7fb0ce1 · commit 5cb7c09
Showing 2 changed files with 144 additions and 131 deletions.
rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
257 changes: 135 additions & 122 deletions

@@ -1,140 +1,153 @@
-# TODO (sven): Move this example script into the new API stack.
-
-"""Simple example of how to restore only one of n agents from a trained
-multi-agent Algorithm using Ray tune.
-
-Control the number of agents and policies via --num-agents and --num-policies.
-"""
+"""An example script showing how to load module weights for 1 of n agents
+from a checkpoint.
+
+This example:
+    - Runs a multi-agent `Pendulum-v1` experiment with >= 2 policies.
+    - Saves a checkpoint of the `MultiAgentRLModule` used every
+      `--checkpoint-freq` iterations.
+    - Stops the experiment after the agents reach a combined return of `-800`.
+    - Picks the best checkpoint by combined return and restores policy 0
+      from it.
+    - Runs a second experiment with the restored `RLModule` for policy 0 and
+      a fresh `RLModule` for each of the other policies.
+    - Stops the second experiment after the agents reach a combined return
+      of `-800`.
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --num-agents=2
+--checkpoint-freq=20 --checkpoint-at-end`
+
+Control the number of agents and policies (RLModules) via `--num-agents` and
+`--num-policies`.
+
+Control the number of checkpoints by setting `--checkpoint-freq` to a value > 0.
+Note that the checkpoint frequency is per iteration, and this example needs at
+least one checkpoint from which to load the RLModule weights for policy 0.
+If `--checkpoint-at-end` is set, a checkpoint is also saved at the end of the
+experiment.
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+Results to expect
+-----------------
+You should expect a reward of -400.0 to eventually be achieved by a simple
+single PPO policy (no tuning, just RLlib's default settings). In the second
+run of the experiment, the MARL module weights for policy 0 are restored from
+the best checkpoint of the first run. The reward for a single agent should
+again reach -400.0, but training should take fewer iterations (around 30
+instead of 190).
+"""

-import argparse
-import gymnasium as gym
 import os
-import random

-import ray
-from ray import air, tune
 from ray.air.constants import TRAINING_ITERATION
-from ray.rllib.algorithms.algorithm import Algorithm
-from ray.rllib.algorithms.callbacks import DefaultCallbacks
-from ray.rllib.algorithms.ppo import PPOConfig
-from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
-from ray.rllib.policy.policy import Policy
-from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.metrics import (
-    ENV_RUNNER_RESULTS,
-    EPISODE_RETURN_MEAN,
-    NUM_ENV_STEPS_SAMPLED_LIFETIME,
-)
-from ray.rllib.utils.test_utils import check_learning_achieved
+from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
+from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
+from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+from ray.tune.registry import get_trainable_cls, register_env
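
Note on the import swap: on the old API stack, restoring an agent went through
a `Policy` object; on the new stack, weights are RLModule state loaded via a
module spec. A side-by-side sketch of the two patterns (paths illustrative,
not taken from a real checkpoint):

    # Old stack (removed above): restore a Policy, then push its weights into
    # the running algorithm (the old script did this from a callback).
    restored = Policy.from_checkpoint("<checkpoint_dir>/policies/policy_0")
    algorithm.set_weights({"policy_0": restored.get_weights()})

    # New stack (added above): point the module spec at saved module state;
    # the algorithm builds the RLModule from that state.
    module_spec.load_state_path = "<checkpoint_dir>/learner/module_state/p0"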

-tf1, tf, tfv = try_import_tf()
-
-parser = argparse.ArgumentParser()
-
-parser.add_argument("--num-agents", type=int, default=4)
-parser.add_argument("--num-policies", type=int, default=2)
-parser.add_argument("--pre-training-iters", type=int, default=5)
-parser.add_argument("--num-cpus", type=int, default=0)
-parser.add_argument(
-    "--framework",
-    choices=["tf", "tf2", "torch"],
-    default="torch",
-    help="The DL framework specifier.",
-)
-parser.add_argument(
-    "--as-test",
-    action="store_true",
-    help="Whether this script should be run as a test: --stop-reward must "
-    "be achieved within --stop-timesteps AND --stop-iters.",
-)
-parser.add_argument(
-    "--stop-iters", type=int, default=200, help="Number of iterations to train."
-)
-parser.add_argument(
-    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
-)
-parser.add_argument(
-    "--stop-reward", type=float, default=150.0, help="Reward at which we stop training."
-)
+parser = add_rllib_example_script_args(
+    default_iters=200,
+    default_timesteps=100000,
+    default_reward=-400.0,
+)
+# TODO (sven): This arg is currently ignored (hard-set to 2).
+parser.add_argument("--num-policies", type=int, default=2)


 if __name__ == "__main__":
     args = parser.parse_args()

-    ray.init(num_cpus=args.num_cpus or None)
-
-    # Get obs- and action Spaces.
-    single_env = gym.make("CartPole-v1")
-    obs_space = single_env.observation_space
-    act_space = single_env.action_space
-
-    # Setup PPO with an ensemble of `num_policies` different policies.
-    policies = {
-        f"policy_{i}": (None, obs_space, act_space, None)
-        for i in range(args.num_policies)
-    }
-    policy_ids = list(policies.keys())
-
-    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
-        pol_id = random.choice(policy_ids)
-        return pol_id
-
-    config = (
-        PPOConfig()
-        .environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents})
-        .framework(args.framework)
-        .training(num_sgd_iter=10)
-        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
-        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
-    )
+    # Register our environment with tune.
+    if args.num_agents > 1:
+        register_env(
+            "env",
+            lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
+        )
+    else:
+        raise ValueError(
+            f"`num_agents` must be > 1, but is {args.num_agents}. "
+            "Read the script docstring for more information."
+        )
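
`register_env` makes the environment constructible by name in every Ray worker
process, which is why the config below can refer to it simply as
`.environment("env")`. The same pattern with the creator written out (the
`env_creator` name is ours, for illustration only):

    def env_creator(env_config):
        # Tune passes the dict given via `config.environment(env_config=...)`
        # through to this factory.
        return MultiAgentPendulum(
            config={"num_agents": env_config.get("num_agents", 2)}
        )

    register_env("env", env_creator)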

+    assert args.checkpoint_freq > 0, (
+        "This example requires at least one checkpoint to load the RLModule "
+        "weights for policy 0."
+    )
+
-    # Do some training and store the checkpoint.
-    results = tune.Tuner(
-        "PPO",
-        param_space=config.to_dict(),
-        run_config=air.RunConfig(
-            stop={TRAINING_ITERATION: args.pre_training_iters},
-            verbose=1,
-            checkpoint_config=air.CheckpointConfig(
-                checkpoint_frequency=1, checkpoint_at_end=True
-            ),
-        ),
-    ).fit()
-    print("Pre-training done.")
-
-    best_checkpoint = results.get_best_result().checkpoint
-    print(f".. best checkpoint was: {best_checkpoint}")
-
-    policy_0_checkpoint = os.path.join(
-        best_checkpoint.to_directory(), "policies/policy_0"
-    )
-    restored_policy_0 = Policy.from_checkpoint(policy_0_checkpoint)
-    restored_policy_0_weights = restored_policy_0.get_weights()
-    print("Starting new tune.Tuner().fit()")
-
-    # Start our actual experiment.
-    stop = {
-        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
-        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
-        TRAINING_ITERATION: args.stop_iters,
-    }
+    base_config = (
+        get_trainable_cls(args.algo)
+        .get_default_config()
+        .environment("env")
+        .training(
+            train_batch_size_per_learner=512,
+            mini_batch_size_per_learner=64,
+            lambda_=0.1,
+            gamma=0.95,
+            lr=0.0003,
+            vf_clip_param=10.0,
+        )
+        .rl_module(
+            model_config_dict={"fcnet_activation": "relu"},
+        )
+    )
+
+    # Add a simple multi-agent setup.
+    if args.num_agents > 0:
+        base_config.multi_agent(
+            policies={f"p{i}" for i in range(args.num_agents)},
+            policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
+        )
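
The multi-agent block maps each agent to its own module purely by agent id:
`MultiAgentPendulum` evidently uses integer agent ids 0..n-1 (the mapping
would fail otherwise), so agent `i` always trains module `p{i}`. What the
mapping lambda does, spelled out:

    policy_mapping_fn = lambda aid, *args, **kwargs: f"p{aid}"
    assert policy_mapping_fn(0) == "p0"
    assert policy_mapping_fn(1) == "p1"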

+    # Augment the base config with further settings and train the agents.
+    results = run_rllib_example_script_experiment(base_config, args)
+
+    # Create an env instance to get the observation and action spaces.
+    env = MultiAgentPendulum(config={"num_agents": args.num_agents})
+    # Get the default module spec from the algorithm config.
+    module_spec = base_config.get_default_rl_module_spec()
+    module_spec.model_config_dict = base_config.model_config | {
+        "fcnet_activation": "relu",
+    }
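
Two details here: `get_default_rl_module_spec()` returns the single-module
spec the algorithm would otherwise build on its own, and `|` is the Python
3.9+ dict-union operator, merging the config's model defaults with the
override on the right. The merge semantics in isolation (default values shown
are illustrative):

    defaults = {"fcnet_activation": "tanh", "fcnet_hiddens": [256, 256]}
    merged = defaults | {"fcnet_activation": "relu"}
    assert merged["fcnet_activation"] == "relu"   # right-hand side wins
    assert merged["fcnet_hiddens"] == [256, 256]  # other defaults survive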

-    class RestoreWeightsCallback(DefaultCallbacks):
-        def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
-            algorithm.set_weights({"policy_0": restored_policy_0_weights})
-
-    # Make sure the policies other than policy_0 are not updated anymore.
-    config.policies_to_train = [pid for pid in policy_ids if pid != "policy_0"]
-    config.callbacks(RestoreWeightsCallback)
-
-    results = tune.run(
-        "PPO",
-        stop=stop,
-        config=config.to_dict(),
-        verbose=1,
-    )
+    module_spec.observation_space = env.envs[0].observation_space
+    module_spec.action_space = env.envs[0].action_space
+    # Create the module for each policy, except policy 0.
+    module_specs = {}
+    for i in range(1, args.num_agents or 1):
+        module_specs[f"p{i}"] = module_spec
+
+    # Now swap in the RLModule weights for policy 0.
+    chkpt_path = results.get_best_result().checkpoint.path
+    p_0_module_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0")
+    module_spec.load_state_path = p_0_module_state_path
+    module_specs["p0"] = module_spec

+    # Create the MARL module.
+    marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
+    # Define the MARL module in the base config.
+    base_config.rl_module(rl_module_spec=marl_module_spec)
+
+    # We need to re-register the environment when starting a new run.
+    register_env(
+        "env",
+        lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
+    )
+    # Define stopping criteria.
+    stop = {
+        # TODO (simon): Change to -800 once the metrics are fixed. Currently
+        # the combined return is not correctly computed.
+        f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400,
+        f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000,
+        TRAINING_ITERATION: 30,
+    }
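
The stop criteria are plain paths into the result dict that Tune evaluates
each iteration, and the trial stops as soon as any one of them is met. With
the metric constants imported above, the first key should resolve to
"env_runners/episode_return_mean" (an assumption about the constant's string
value, consistent with the new-stack results layout):

    stop = {
        f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400,  # mean-return threshold
        f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000,         # total env-step budget
        TRAINING_ITERATION: 30,                             # hard iteration cap
    }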

-    if args.as_test:
-        check_learning_achieved(results, args.stop_reward)
-
-    ray.shutdown()
+    # Run the experiment again with the restored MARL module.
+    run_rllib_example_script_experiment(base_config, args, stop=stop)
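
A possible follow-up the commit does not include: the old script froze the
restored policy via `config.policies_to_train`; the same effect is available
on the new stack through the `policies_to_train` argument of `.multi_agent()`.
A sketch, assuming you want the restored `p0` frozen during the second run:

    base_config.multi_agent(
        policies={f"p{i}" for i in range(args.num_agents)},
        policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
        # Train every module except the restored p0.
        policies_to_train=[f"p{i}" for i in range(1, args.num_agents)],
    )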