# Permalink
# Fetching contributors…
# Cannot retrieve contributors at this time
# 78 lines (60 sloc) 2.37 KB
"""Example of using RLlib's debug callbacks.

Here we use callbacks to track the average magnitude of the CartPole pole
angle over each episode and report it as a custom metric.
"""
import argparse
import numpy as np
import ray
from ray import tune
def on_episode_start(info):
    """Set up per-episode storage when a new episode begins.

    Args:
        info: Callback payload. ``info["episode"]`` is the episode object;
            this callback reads its ``episode_id`` and writes to its
            ``user_data`` mapping.
    """
    episode = info["episode"]
    print("episode {} started".format(episode.episode_id))
    # Start from an empty list so on_episode_step can append one angle per
    # step and on_episode_end can average them.
    episode.user_data["pole_angles"] = []
def on_episode_step(info):
    """Record the magnitude of the pole angle for the current step.

    Args:
        info: Callback payload. ``info["episode"]`` is the episode object;
            this callback calls its ``last_observation_for()`` method and
            appends to ``user_data["pole_angles"]``, which
            ``on_episode_start`` must have created.
    """
    episode = info["episode"]
    # Element 2 of the observation is used as the pole angle (CartPole
    # layout); only its magnitude is tracked, so the sign is dropped.
    pole_angle = abs(episode.last_observation_for()[2])
    episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(info):
    """Publish the episode's mean pole-angle magnitude as a custom metric.

    Args:
        info: Callback payload. ``info["episode"]`` is the episode object;
            this callback reads ``user_data["pole_angles"]``,
            ``episode_id`` and ``length``, and writes the mean to
            ``custom_metrics["pole_angle"]``.
    """
    episode = info["episode"]
    pole_angle = np.mean(episode.user_data["pole_angles"])
    print("episode {} ended with length {} and pole angles {}".format(
        episode.episode_id, episode.length, pole_angle))
    # The script's final checks look for pole_angle_mean/_min/_max in the
    # trial result, i.e. aggregates derived from this per-episode value.
    episode.custom_metrics["pole_angle"] = pole_angle
def on_sample_end(info):
    """Log the size of the sample batch that was just collected.

    Args:
        info: Callback payload. ``info["samples"]`` is the returned batch;
            only its ``count`` attribute is read.
    """
    print("returned sample batch of size {}".format(info["samples"].count))
def on_train_result(info):
    """Log a training result and tag it to show the callback ran.

    Args:
        info: Callback payload. ``info["agent"]`` is the object being
            trained and ``info["result"]`` is the mutable result dict for
            the iteration; ``result["episodes_this_iter"]`` is read and
            ``result["callback_ok"]`` is set to ``True``.
    """
    print("agent.train() result: {} -> {} episodes".format(
        info["agent"], info["result"]["episodes_this_iter"]))
    # you can mutate the result dict to add new fields to return
    info["result"]["callback_ok"] = True
if __name__ == "__main__":
    # Command-line entry point: train PG on CartPole with every callback
    # registered, then check that the custom metric reached the results.
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-iters", type=int, default=2000)
    args = parser.parse_args()

    ray.init()
    trials = tune.run_experiments({
        "test": {
            "env": "CartPole-v0",
            "run": "PG",
            "stop": {
                "training_iteration": args.num_iters,
            },
            "config": {
                # Each callback is wrapped in tune.function before being
                # placed in the config dict.
                "callbacks": {
                    "on_episode_start": tune.function(on_episode_start),
                    "on_episode_step": tune.function(on_episode_step),
                    "on_episode_end": tune.function(on_episode_end),
                    "on_sample_end": tune.function(on_sample_end),
                    "on_train_result": tune.function(on_train_result),
                },
            },
        }
    })

    # verify custom metrics for integration tests
    custom_metrics = trials[0].last_result["custom_metrics"]
    print(custom_metrics)
    # The per-episode "pole_angle" metric is expected to show up as
    # mean/min/max aggregates in the trial result.
    assert "pole_angle_mean" in custom_metrics
    assert "pole_angle_min" in custom_metrics
    assert "pole_angle_max" in custom_metrics
    assert type(custom_metrics["pole_angle_mean"]) is float
    # Set by on_train_result; confirms result-dict mutation is propagated.
    assert "callback_ok" in trials[0].last_result