Commit 8abd88a

[RLlib] Metrics do-over 05: Add example script for a custom `render()` method (with WandB videos). (ray-project#45107)
sven1977 authored and ryanaoleary committed Jun 7, 2024
1 parent af341dc commit 8abd88a
Showing 15 changed files with 280 additions and 25 deletions.
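The headline addition (per the new BUILD targets below) is the example script examples/envs/custom_env_render_method.py, which is not itself shown in this partial view. Roughly, a custom render() of the kind the commit title describes returns each frame as a uint8 RGB NumPy array, which the WandB logger can then assemble into episode videos. A minimal, hypothetical sketch (class name and shapes are made up for illustration):

import gymnasium as gym
import numpy as np

class MyRenderableEnv(gym.Env):
    # Hypothetical sketch only; the actual example script is not shown in this diff.
    def __init__(self, config=None):
        self.observation_space = gym.spaces.Box(-1.0, 1.0, (4,), np.float32)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}

    def render(self):
        # Return an RGB frame (height x width x 3, uint8). Frames like this
        # can be stitched into episode videos by the WandB logger.
        return np.zeros((64, 96, 3), dtype=np.uint8)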
18 changes: 18 additions & 0 deletions rllib/BUILD
@@ -2278,6 +2278,24 @@ py_test(
args = ["--enable-new-api-stack", "--as-test"]
)

py_test(
name = "examples/envs/custom_env_render_method",
main = "examples/envs/custom_env_render_method.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/envs/custom_env_render_method.py"],
args = ["--enable-new-api-stack", "--num-agents=0"]
)

py_test(
name = "examples/envs/custom_env_render_method_multi_agent",
main = "examples/envs/custom_env_render_method.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/envs/custom_env_render_method.py"],
args = ["--enable-new-api-stack", "--num-agents=2"]
)

py_test(
name = "examples/envs/env_rendering_and_recording",
srcs = ["examples/envs/env_rendering_and_recording.py"],
22 changes: 15 additions & 7 deletions rllib/algorithms/algorithm.py
@@ -796,7 +796,7 @@ def setup(self, config: AlgorithmConfig) -> None:
self.workers.sync_weights(inference_only=True)

# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self)
self.callbacks.on_algorithm_init(algorithm=self, metrics_logger=self.metrics)

@OverrideToImplementCustomLogic
@classmethod
@@ -999,7 +999,7 @@ def evaluate(
config=self.evaluation_config,
)

self.callbacks.on_evaluate_start(algorithm=self)
self.callbacks.on_evaluate_start(algorithm=self, metrics_logger=self.metrics)

env_steps = agent_steps = 0
batches = []
@@ -1097,7 +1097,11 @@ def evaluate(
eval_results["off_policy_estimator"][name] = avg_estimate

# Trigger `on_evaluate_end` callback.
self.callbacks.on_evaluate_end(algorithm=self, evaluation_metrics=eval_results)
self.callbacks.on_evaluate_end(
algorithm=self,
metrics_logger=self.metrics,
evaluation_metrics=eval_results,
)

# Also return the results here for convenience.
return eval_results
@@ -2447,9 +2451,13 @@ def load_checkpoint(self, checkpoint_dir: str) -> None:
def log_result(self, result: ResultDict) -> None:
# Log after the callback is invoked, so that the user has a chance
# to mutate the result.
# TODO: Remove `algorithm` arg at some point to fully deprecate the old
# signature.
self.callbacks.on_train_result(algorithm=self, result=result)
# TODO (sven): It might not make sense to pass in the MetricsLogger at this late
# point in time. In here, the result dict has already been "compiled" (reduced)
# by the MetricsLogger and there is probably no point in adding more Stats
# here.
self.callbacks.on_train_result(
algorithm=self, metrics_logger=self.metrics, result=result
)
# Then log according to Trainable's logging logic.
Trainable.log_result(self, result)

@@ -3264,7 +3272,7 @@ def _run_one_training_iteration_and_evaluation_in_parallel_wo_thread(
config=self.evaluation_config,
)

self.callbacks.on_evaluate_start(algorithm=self)
self.callbacks.on_evaluate_start(algorithm=self, metrics_logger=self.metrics)

env_steps = agent_steps = 0
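The pattern in this file is the same for every hook: the Algorithm's own MetricsLogger (self.metrics) is now forwarded to on_algorithm_init, on_evaluate_start, on_evaluate_end, and on_train_result. A hedged sketch of a user-side callback that could take advantage of the new argument (class name and metric keys are made up for illustration):

from ray.rllib.algorithms.callbacks import DefaultCallbacks

class MyMetricsCallbacks(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm, metrics_logger=None, **kwargs):
        # Count algorithm (re)initializations; reduce="sum" accumulates the value.
        if metrics_logger is not None:
            metrics_logger.log_value("custom/algo_inits", 1, reduce="sum")

    def on_train_result(self, *, algorithm, metrics_logger=None, result, **kwargs):
        # As the TODO above notes, `result` is already compiled at this point,
        # so post-process the dict directly rather than logging new Stats.
        result["custom/iteration_echo"] = algorithm.iteration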

2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_callbacks_on_algorithm.py
@@ -35,7 +35,7 @@ def on_workers_recreated(


class InitAndCheckpointRestoredCallbacks(DefaultCallbacks):
def on_algorithm_init(self, *, algorithm, **kwargs):
def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
self._on_init_was_called = True

def on_checkpoint_loaded(self, *, algorithm, **kwargs):
1 change: 1 addition & 0 deletions rllib/algorithms/tests/test_callbacks_on_env_runner.py
@@ -49,6 +49,7 @@ def on_episode_created(
episode,
worker=None,
env_runner=None,
metrics_logger=None,
base_env=None,
env=None,
policies=None,
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_worker_failures.py
@@ -225,7 +225,7 @@ class AddModuleCallback(DefaultCallbacks):
def __init__(self):
super().__init__()

def on_algorithm_init(self, *, algorithm, **kwargs):
def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
# Add a custom module to algorithm.
spec = algorithm.config.get_default_rl_module_spec()
spec.observation_space = gym.spaces.Box(low=0, high=1, shape=(8,))
12 changes: 11 additions & 1 deletion rllib/env/multi_agent_env.py
@@ -2,6 +2,8 @@
import logging
from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type

import numpy as np

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.annotations import (
@@ -554,7 +556,15 @@ def step(self, action_dict):

@override(MultiAgentEnv)
def render(self):
return self.envs[0].render(self.render_mode)
# This render method simply renders all n underlying individual single-agent
# envs and concatenates their images (on top of each other if the returned
# images have dims where [width] > [height], otherwise next to each other).
render_images = [e.render() for e in self.envs]
if render_images[0].shape[1] > render_images[0].shape[0]:
concat_dim = 0
else:
concat_dim = 1
return np.concatenate(render_images, axis=concat_dim)

return MultiEnv
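In other words, the grouped multi-agent env now renders every sub-env and tiles the frames into a single image, stacking them vertically when the individual frames are wider than they are tall, and horizontally otherwise. A quick standalone check of that rule with dummy frames (shapes chosen only for illustration):

import numpy as np

# Two dummy 400x600 RGB frames standing in for the sub-envs' render() output
# (height=400 < width=600, so they should be stacked on top of each other).
frames = [np.zeros((400, 600, 3), dtype=np.uint8) for _ in range(2)]
concat_dim = 0 if frames[0].shape[1] > frames[0].shape[0] else 1
combined = np.concatenate(frames, axis=concat_dim)
assert combined.shape == (800, 600, 3)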

9 changes: 5 additions & 4 deletions rllib/env/multi_agent_env_runner.py
@@ -60,6 +60,10 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Get the worker index on which this instance is running.
self.worker_index: int = kwargs.get("worker_index")

# Set up all metrics-related structures and counters.
self.metrics: Optional[MetricsLogger] = None
self._setup_metrics()

# Create our callbacks object.
self._callbacks: DefaultCallbacks = self.config.callbacks_class()

@@ -86,10 +90,6 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Create the two connector pipelines: env-to-module and module-to-env.
self._module_to_env = self.config.build_module_to_env_connector(self.env)

# Set up all metrics-related structures and counters.
self.metrics: Optional[MetricsLogger] = None
self._setup_metrics()

self._needs_initial_reset: bool = True
self._episode: Optional[MultiAgentEpisode] = None
self._shared_data = None
@@ -749,6 +749,7 @@ def make_env(self):
# Call the `on_environment_created` callback.
self._callbacks.on_environment_created(
env_runner=self,
metrics_logger=self.metrics,
env=self.env,
env_context=env_ctx,
)
15 changes: 8 additions & 7 deletions rllib/env/single_agent_env_runner.py
@@ -52,11 +52,16 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
"""
super().__init__(config=config)

self.worker_index = kwargs.get("worker_index")

# Create a MetricsLogger object for logging custom stats.
self.metrics = MetricsLogger()
# Initialize lifetime counts.
self.metrics.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0, reduce="sum")

# Create our callbacks object.
self._callbacks: DefaultCallbacks = self.config.callbacks_class()

self.worker_index = kwargs.get("worker_index")

# Create the vectorized gymnasium env.
self.env: Optional[gym.Wrapper] = None
self.num_envs: int = 0
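Creating the MetricsLogger (and its lifetime counter) before the callbacks object means it already exists when on_environment_created fires during env construction. As a side note, log_value with reduce="sum" keeps a running total, so initializing NUM_ENV_STEPS_SAMPLED_LIFETIME to 0 here only seeds that total; a rough standalone illustration (assuming MetricsLogger can be driven directly like this):

from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

logger = MetricsLogger()
logger.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0, reduce="sum")
logger.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, 250, reduce="sum")
logger.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, 250, reduce="sum")
# The summed value should now read 500.
print(logger.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME))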
@@ -98,11 +103,6 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Create the two connector pipelines: env-to-module and module-to-env.
self._module_to_env = self.config.build_module_to_env_connector(self.env)

# Create a MetricsLogger object for logging custom stats.
self.metrics = MetricsLogger()
# Initialize lifetime counts.
self.metrics.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0, reduce="sum")

# This should be the default.
self._needs_initial_reset: bool = True
self._episodes: List[Optional[SingleAgentEpisode]] = [
@@ -702,6 +702,7 @@ def make_env(self) -> None:
# Call the `on_environment_created` callback.
self._callbacks.on_environment_created(
env_runner=self,
metrics_logger=self.metrics,
env=self.env,
env_context=env_ctx,
)
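Both EnvRunner classes now also hand their MetricsLogger to on_environment_created. A hedged sketch of a callback hooking into that (class and key names are illustrative only):

from ray.rllib.algorithms.callbacks import DefaultCallbacks

class EnvCreationCounter(DefaultCallbacks):
    def on_environment_created(
        self, *, env_runner, metrics_logger=None, env, env_context, **kwargs
    ):
        # Count how often this EnvRunner (re)creates its gym env.
        if metrics_logger is not None:
            metrics_logger.log_value("custom/envs_created", 1, reduce="sum")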
@@ -46,7 +46,7 @@ def __init__(self, checkpoint_dir):
self._checkpoint_dir = checkpoint_dir
super().__init__()

def on_algorithm_init(self, *, algorithm, **kwargs):
def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
policy = Policy.from_checkpoint(
self._checkpoint_dir, policy_ids=[OPPONENT_POLICY_ID]
)
1 change: 1 addition & 0 deletions rllib/examples/curriculum/curriculum_learning.py
@@ -149,6 +149,7 @@ def on_train_result(
self,
*,
algorithm: Algorithm,
metrics_logger=None,
result: dict,
**kwargs,
) -> None: