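# NOTE: the lines below are GitHub web-page residue captured together with
# the file; they are not part of the program.
# Skip to content
# Permalink
# Branch: master
# Find file Copy path
# Find file Copy path
# Fetching contributors...
# Cannot retrieve contributors at this time
# 89 lines (69 sloc) 2.79 KB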
"""Example of using RLlib's debug callbacks.
Here we use callbacks to track the average CartPole pole angle magnitude as a
custom metric.
"""
import argparse
import numpy as np
import ray
from ray import tune
def on_episode_start(info):
    """Announce a new episode and reset its pole-angle buffer.

    Args:
        info: Callback payload. ``info["episode"]`` must expose an
            ``episode_id`` attribute and a mutable ``user_data`` mapping.
    """
    episode = info["episode"]
    print("episode {} started".format(episode.episode_id))
    # Start from an empty list so on_episode_step can append one angle per
    # step and on_episode_end can average them.
    episode.user_data["pole_angles"] = []
def on_episode_step(info):
    """Record the magnitude of the current pole angle for this episode.

    Reads element 2 of both the last observation and the last raw
    observation, takes the absolute value of each, and appends the result
    to ``episode.user_data["pole_angles"]``.

    Args:
        info: Callback payload. ``info["episode"]`` must provide
            ``last_observation_for()``, ``last_raw_obs_for()`` and a
            ``user_data`` mapping that already holds a ``"pole_angles"``
            list.

    Raises:
        AssertionError: If the two observations disagree on the angle
            magnitude.
    """
    episode = info["episode"]
    pole_angle = abs(episode.last_observation_for()[2])
    raw_angle = abs(episode.last_raw_obs_for()[2])
    # Sanity check for this example: both views of the observation are
    # expected to report the same angle.
    assert pole_angle == raw_angle
    episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(info):
    """Publish the episode's mean pole-angle magnitude as a custom metric.

    Averages the values collected in ``episode.user_data["pole_angles"]``,
    prints a summary line, and stores the average under
    ``episode.custom_metrics["pole_angle"]``.

    Args:
        info: Callback payload. ``info["episode"]`` must expose
            ``episode_id``, ``length``, ``user_data`` and a mutable
            ``custom_metrics`` mapping.
    """
    episode = info["episode"]
    pole_angle = np.mean(episode.user_data["pole_angles"])
    print("episode {} ended with length {} and pole angles {}".format(
        episode.episode_id, episode.length, pole_angle))
    episode.custom_metrics["pole_angle"] = pole_angle
def on_sample_end(info):
    """Print the size of the sample batch that was just returned.

    Args:
        info: Callback payload. ``info["samples"]`` must expose a ``count``
            attribute.
    """
    print("returned sample batch of size {}".format(info["samples"].count))
def on_train_result(info):
    """Print a summary of a training result and tag it with ``callback_ok``.

    Args:
        info: Callback payload. ``info["trainer"]`` is formatted into the
            message; ``info["result"]`` must be a mutable mapping that
            contains ``"episodes_this_iter"``.
    """
    print("trainer.train() result: {} -> {} episodes".format(
        info["trainer"], info["result"]["episodes_this_iter"]))
    # The result dict may be mutated in place to add new fields to return;
    # the main script later checks that this key is present.
    info["result"]["callback_ok"] = True
def on_postprocess_traj(info):
    """Count the postprocessed batches seen for this episode.

    Prints the step count of the postprocessed batch and increments
    ``episode.custom_metrics["num_batches"]``, creating the counter on
    first use.

    Args:
        info: Callback payload. ``info["episode"]`` must expose a mutable
            ``custom_metrics`` mapping and ``info["post_batch"]`` a
            ``count`` attribute.
    """
    episode = info["episode"]
    batch = info["post_batch"]
    print("postprocessed {} steps".format(batch.count))
    metrics = episode.custom_metrics
    # A missing counter is treated as 0, so the first call stores 1.
    metrics["num_batches"] = metrics.get("num_batches", 0) + 1
if __name__ == "__main__":
    # Script entry point: train "PG" on CartPole-v0 with every callback
    # defined above installed, then check that the custom metrics and the
    # extra result field show up in the first trial's last result.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-iters",
        type=int,
        default=2000,
        help="Number of training iterations to run before stopping.")
    args = parser.parse_args()

    ray.init()
    trials = tune.run(
        "PG",
        stop={
            "training_iteration": args.num_iters,
        },
        config={
            "env": "CartPole-v0",
            # Each callback is wrapped with tune.function before being
            # placed in the config.
            "callbacks": {
                "on_episode_start": tune.function(on_episode_start),
                "on_episode_step": tune.function(on_episode_step),
                "on_episode_end": tune.function(on_episode_end),
                "on_sample_end": tune.function(on_sample_end),
                "on_train_result": tune.function(on_train_result),
                "on_postprocess_traj": tune.function(on_postprocess_traj),
            },
        },
        return_trials=True)

    # verify custom metrics for integration tests
    custom_metrics = trials[0].last_result["custom_metrics"]
    print(custom_metrics)
    # The per-episode "pole_angle" and "num_batches" values written by the
    # callbacks are expected to appear here with _mean/_min/_max suffixes.
    assert "pole_angle_mean" in custom_metrics
    assert "pole_angle_min" in custom_metrics
    assert "pole_angle_max" in custom_metrics
    assert "num_batches_mean" in custom_metrics
    assert type(custom_metrics["pole_angle_mean"]) is float
    # Set by on_train_result.
    assert "callback_ok" in trials[0].last_result
# (GitHub web-page residue, not part of the program:)
# You can't perform that action at this time.