[RLlib] Metrics do-over 05: Add example script for a custom render() method (with WandB videos). #45107

Merged
18 changes: 18 additions & 0 deletions rllib/BUILD
@@ -2278,6 +2278,24 @@ py_test(
args = ["--enable-new-api-stack", "--as-test"]
)

py_test(
name = "examples/envs/custom_env_render_method",
main = "examples/envs/custom_env_render_method.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/envs/custom_env_render_method.py"],
args = ["--enable-new-api-stack", "--num-agents=0"]
)

py_test(
name = "examples/envs/custom_env_render_method_multi_agent",
main = "examples/envs/custom_env_render_method.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/envs/custom_env_render_method.py"],
args = ["--enable-new-api-stack", "--num-agents=2"]
)

py_test(
name = "examples/envs/env_rendering_and_recording",
srcs = ["examples/envs/env_rendering_and_recording.py"],
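Note: both new py_test targets run the same script, examples/envs/custom_env_render_method.py, once in single-agent mode (--num-agents=0) and once with two agents (--num-agents=2). Outside of Bazel this presumably corresponds to running the example directly, e.g. python rllib/examples/envs/custom_env_render_method.py --enable-new-api-stack --num-agents=2.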
22 changes: 15 additions & 7 deletions rllib/algorithms/algorithm.py
@@ -796,7 +796,7 @@ def setup(self, config: AlgorithmConfig) -> None:
self.workers.sync_weights(inference_only=True)

# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self)
self.callbacks.on_algorithm_init(algorithm=self, metrics_logger=self.metrics)

@OverrideToImplementCustomLogic
@classmethod
@@ -999,7 +999,7 @@ def evaluate(
config=self.evaluation_config,
)

self.callbacks.on_evaluate_start(algorithm=self)
self.callbacks.on_evaluate_start(algorithm=self, metrics_logger=self.metrics)

env_steps = agent_steps = 0
batches = []
@@ -1097,7 +1097,11 @@ def evaluate(
eval_results["off_policy_estimator"][name] = avg_estimate

# Trigger `on_evaluate_end` callback.
self.callbacks.on_evaluate_end(algorithm=self, evaluation_metrics=eval_results)
self.callbacks.on_evaluate_end(
algorithm=self,
metrics_logger=self.metrics,
evaluation_metrics=eval_results,
)

# Also return the results here for convenience.
return eval_results
@@ -2447,9 +2451,13 @@ def load_checkpoint(self, checkpoint_dir: str) -> None:
def log_result(self, result: ResultDict) -> None:
# Log after the callback is invoked, so that the user has a chance
# to mutate the result.
# TODO: Remove `algorithm` arg at some point to fully deprecate the old
# signature.
self.callbacks.on_train_result(algorithm=self, result=result)
# TODO (sven): It might not make sense to pass in the MetricsLogger at this late
# point in time. In here, the result dict has already been "compiled" (reduced)
# by the MetricsLogger and there is probably no point in adding more Stats
# here.
self.callbacks.on_train_result(
algorithm=self, metrics_logger=self.metrics, result=result
)
# Then log according to Trainable's logging logic.
Trainable.log_result(self, result)

@@ -3264,7 +3272,7 @@ def _run_one_training_iteration_and_evaluation_in_parallel_wo_thread(
config=self.evaluation_config,
)

self.callbacks.on_evaluate_start(algorithm=self)
self.callbacks.on_evaluate_start(algorithm=self, metrics_logger=self.metrics)

env_steps = agent_steps = 0

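With these changes, the Algorithm-level callbacks (on_algorithm_init, on_evaluate_start, on_evaluate_end, on_train_result) all receive the Algorithm's MetricsLogger via the new metrics_logger kwarg. A minimal sketch of a user callback picking it up (the stat keys are made up for illustration; only log_value with reduce="sum" is shown, as used elsewhere in this PR):

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class MyMetricsCallbacks(DefaultCallbacks):
    def on_algorithm_init(self, *, algorithm, metrics_logger=None, **kwargs):
        # `metrics_logger` is the Algorithm's own MetricsLogger (algorithm.metrics).
        if metrics_logger is not None:
            # Hypothetical custom lifetime counter.
            metrics_logger.log_value("my_stats/algo_inits", 1, reduce="sum")

    def on_evaluate_end(
        self, *, algorithm, metrics_logger=None, evaluation_metrics=None, **kwargs
    ):
        if metrics_logger is not None:
            # Hypothetical counter of completed evaluation rounds.
            metrics_logger.log_value("my_stats/eval_rounds", 1, reduce="sum")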
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_callbacks_on_algorithm.py
@@ -35,7 +35,7 @@ def on_workers_recreated(


class InitAndCheckpointRestoredCallbacks(DefaultCallbacks):
def on_algorithm_init(self, *, algorithm, **kwargs):
def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
self._on_init_was_called = True

def on_checkpoint_loaded(self, *, algorithm, **kwargs):
1 change: 1 addition & 0 deletions rllib/algorithms/tests/test_callbacks_on_env_runner.py
@@ -49,6 +49,7 @@ def on_episode_created(
episode,
worker=None,
env_runner=None,
metrics_logger=None,
base_env=None,
env=None,
policies=None,
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_worker_failures.py
@@ -225,7 +225,7 @@ class AddModuleCallback(DefaultCallbacks):
def __init__(self):
super().__init__()

def on_algorithm_init(self, *, algorithm, **kwargs):
def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
# Add a custom module to algorithm.
spec = algorithm.config.get_default_rl_module_spec()
spec.observation_space = gym.spaces.Box(low=0, high=1, shape=(8,))
12 changes: 11 additions & 1 deletion rllib/env/multi_agent_env.py
@@ -2,6 +2,8 @@
import logging
from typing import Callable, Dict, List, Tuple, Optional, Union, Set, Type

import numpy as np

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.utils.annotations import (
@@ -554,7 +556,15 @@ def step(self, action_dict):

@override(MultiAgentEnv)
def render(self):
return self.envs[0].render(self.render_mode)
# This render method simply renders all n underlying individual single-agent
# envs and concatenates their images (on top of each other if the returned
# images have dims where [width] > [height], otherwise next to each other).
render_images = [e.render() for e in self.envs]
if render_images[0].shape[1] > render_images[0].shape[0]:
concat_dim = 0
else:
concat_dim = 1
return np.concatenate(render_images, axis=concat_dim)

return MultiEnv

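The axis choice in the new render() can be illustrated in isolation; a small standalone sketch with made-up frame sizes:

import numpy as np

# Two "wide" frames (height=200, width=400), e.g. from two identical sub-envs.
frames = [np.zeros((200, 400, 3), dtype=np.uint8) for _ in range(2)]

# width > height -> stack the frames on top of each other (axis 0),
# otherwise place them next to each other (axis 1).
concat_dim = 0 if frames[0].shape[1] > frames[0].shape[0] else 1
combined = np.concatenate(frames, axis=concat_dim)
print(combined.shape)  # (400, 400, 3)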
9 changes: 5 additions & 4 deletions rllib/env/multi_agent_env_runner.py
@@ -60,6 +60,10 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Get the worker index on which this instance is running.
self.worker_index: int = kwargs.get("worker_index")

# Set up all metrics-related structures and counters.
self.metrics: Optional[MetricsLogger] = None
self._setup_metrics()

# Create our callbacks object.
self._callbacks: DefaultCallbacks = self.config.callbacks_class()

@@ -86,10 +90,6 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Create the two connector pipelines: env-to-module and module-to-env.
self._module_to_env = self.config.build_module_to_env_connector(self.env)

# Set up all metrics-related structures and counters.
self.metrics: Optional[MetricsLogger] = None
self._setup_metrics()

self._needs_initial_reset: bool = True
self._episode: Optional[MultiAgentEpisode] = None
self._shared_data = None
@@ -749,6 +749,7 @@ def make_env(self):
# Call the `on_environment_created` callback.
self._callbacks.on_environment_created(
env_runner=self,
metrics_logger=self.metrics,
env=self.env,
env_context=env_ctx,
)
15 changes: 8 additions & 7 deletions rllib/env/single_agent_env_runner.py
@@ -52,11 +52,16 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
"""
super().__init__(config=config)

self.worker_index = kwargs.get("worker_index")

# Create a MetricsLogger object for logging custom stats.
self.metrics = MetricsLogger()
# Initialize lifetime counts.
self.metrics.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0, reduce="sum")

# Create our callbacks object.
self._callbacks: DefaultCallbacks = self.config.callbacks_class()

self.worker_index = kwargs.get("worker_index")

# Create the vectorized gymnasium env.
self.env: Optional[gym.Wrapper] = None
self.num_envs: int = 0
@@ -98,11 +103,6 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Create the two connector pipelines: env-to-module and module-to-env.
self._module_to_env = self.config.build_module_to_env_connector(self.env)

# Create a MetricsLogger object for logging custom stats.
self.metrics = MetricsLogger()
# Initialize lifetime counts.
self.metrics.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0, reduce="sum")

# This should be the default.
self._needs_initial_reset: bool = True
self._episodes: List[Optional[SingleAgentEpisode]] = [
@@ -702,6 +702,7 @@ def make_env(self) -> None:
# Call the `on_environment_created` callback.
self._callbacks.on_environment_created(
env_runner=self,
metrics_logger=self.metrics,
env=self.env,
env_context=env_ctx,
)
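On the EnvRunner side, the MetricsLogger is now created before the callbacks object, so it already exists when on_environment_created fires and can be passed along. A sketch of a callback that uses the new kwarg (the stat key is hypothetical):

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class EnvCreationLogger(DefaultCallbacks):
    def on_environment_created(
        self, *, env_runner, metrics_logger=None, env=None, env_context=None, **kwargs
    ):
        # Count how often (vector) envs get (re)created on this EnvRunner.
        if metrics_logger is not None:
            metrics_logger.log_value("my_stats/envs_created", 1, reduce="sum")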
@@ -46,7 +46,7 @@ def __init__(self, checkpoint_dir):
self._checkpoint_dir = checkpoint_dir
super().__init__()

def on_algorithm_init(self, *, algorithm, **kwargs):
def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
policy = Policy.from_checkpoint(
self._checkpoint_dir, policy_ids=[OPPONENT_POLICY_ID]
)
1 change: 1 addition & 0 deletions rllib/examples/curriculum/curriculum_learning.py
@@ -149,6 +149,7 @@ def on_train_result(
self,
*,
algorithm: Algorithm,
metrics_logger=None,
result: dict,
**kwargs,
) -> None: