From a95ec7fad786a3dfac995ccb126398f3f046e8df Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 31 May 2024 18:24:27 +0200 Subject: [PATCH] [RLlib] DreamerV3 on tf: Fix bug w/ `reduce_fn` still passed into `LearnerGroup.update_from_batch()`. (#45419) --- rllib/algorithms/dqn/dqn.py | 19 +- rllib/algorithms/dreamerv3/dreamerv3.py | 222 +++++++++------ .../algorithms/dreamerv3/dreamerv3_learner.py | 56 +--- .../dreamerv3/dreamerv3_rl_module.py | 86 ++++-- .../dreamerv3/tests/test_dreamerv3.py | 126 ++++----- .../dreamerv3/tf/dreamerv3_tf_learner.py | 33 ++- .../dreamerv3/tf/dreamerv3_tf_rl_module.py | 40 +-- .../dreamerv3/tf/models/actor_network.py | 8 +- .../tf/models/components/cnn_atari.py | 2 +- .../models/components/continue_predictor.py | 4 +- .../tf/models/components/reward_predictor.py | 4 +- .../tf/models/components/sequence_model.py | 5 +- .../tf/models/components/vector_decoder.py | 4 +- .../dreamerv3/tf/models/critic_network.py | 4 +- .../dreamerv3/tf/models/dreamer_model.py | 1 + rllib/algorithms/dreamerv3/utils/__init__.py | 2 +- .../algorithms/dreamerv3/utils/env_runner.py | 266 ++++++++++++------ rllib/algorithms/dreamerv3/utils/summaries.py | 253 ++++++++++------- rllib/algorithms/impala/impala.py | 7 + rllib/algorithms/ppo/ppo.py | 9 +- .../ppo/tests/test_ppo_with_env_runner.py | 8 +- .../ppo/tests/test_ppo_with_rl_module.py | 8 +- rllib/core/learner/learner.py | 121 ++++---- rllib/core/learner/learner_group.py | 38 ++- rllib/core/learner/torch/torch_learner.py | 2 +- rllib/core/models/torch/heads.py | 2 +- rllib/core/models/torch/primitives.py | 25 +- rllib/core/models/torch/utils.py | 23 +- rllib/core/rl_module/torch/torch_rl_module.py | 3 +- .../examples/catalogs/mobilenet_v2_encoder.py | 5 +- .../examples/evaluation/custom_evaluation.py | 1 + rllib/execution/rollout_ops.py | 9 +- rllib/tuned_examples/dreamerv3/atari_100k.py | 50 ++-- rllib/tuned_examples/dreamerv3/atari_200M.py | 53 ++-- .../dreamerv3/dm_control_suite_vision.py | 41 ++- rllib/utils/metrics/__init__.py | 1 + rllib/utils/metrics/metrics_logger.py | 30 +- rllib/utils/metrics/stats.py | 40 ++- 38 files changed, 951 insertions(+), 660 deletions(-) diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 2674f99dbbd5d..3bf51e0edd42c 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -629,6 +629,11 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict: env_runner_results, key=ENV_RUNNER_RESULTS ) + self.metrics.log_dict( + self.metrics.peek(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, default={}), + key=NUM_AGENT_STEPS_SAMPLED_LIFETIME, + reduce="sum", + ) self.metrics.log_value( NUM_ENV_STEPS_SAMPLED_LIFETIME, self.metrics.peek(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED, default=0), @@ -639,11 +644,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict: self.metrics.peek(ENV_RUNNER_RESULTS, NUM_EPISODES, default=0), reduce="sum", ) - self.metrics.log_dict( - self.metrics.peek(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, default={}), - key=NUM_AGENT_STEPS_SAMPLED_LIFETIME, - reduce="sum", - ) self.metrics.log_dict( self.metrics.peek(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED, default={}), key=NUM_MODULE_STEPS_SAMPLED_LIFETIME, @@ -680,6 +680,14 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict: with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): learner_results = self.learner_group.update_from_episodes( episodes=episodes, + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME) + ), + NUM_AGENT_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME) + ), + }, ) # Isolate TD-errors from result dicts (we should not log these to # disk or WandB, they might be very large). @@ -730,6 +738,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict: # Update the target networks, if necessary. with self.metrics.log_time((TIMERS, LEARNER_ADDITIONAL_UPDATE_TIMER)): modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES} + # TODO (sven): Move to Learner._after_gradient_based_update(). additional_results = self.learner_group.additional_update( module_ids_to_update=modules_to_update, timestep=current_ts, diff --git a/rllib/algorithms/dreamerv3/dreamerv3.py b/rllib/algorithms/dreamerv3/dreamerv3.py index abb163f8b6711..50bcce11ad909 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3.py +++ b/rllib/algorithms/dreamerv3/dreamerv3.py @@ -10,11 +10,9 @@ import gc import logging -import tree # pip install dm_tree -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union import gymnasium as gym -import numpy as np from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided @@ -22,30 +20,36 @@ from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs from ray.rllib.algorithms.dreamerv3.utils.env_runner import DreamerV3EnvRunner from ray.rllib.algorithms.dreamerv3.utils.summaries import ( + report_dreamed_eval_trajectory_vs_samples, report_predicted_vs_sampled_obs, report_sampling_and_replay_buffer, ) from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.metrics import ( - ALL_MODULES, + ENV_RUNNER_RESULTS, GARBAGE_COLLECTION_TIMER, LEARN_ON_BATCH_TIMER, + LEARNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, - NUM_AGENT_STEPS_TRAINED, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, NUM_ENV_STEPS_SAMPLED, - NUM_ENV_STEPS_TRAINED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_TRAINED_LIFETIME, + NUM_EPISODES, + NUM_EPISODES_LIFETIME, NUM_GRAD_UPDATES_LIFETIME, NUM_SYNCH_WORKER_WEIGHTS, - NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS, SAMPLE_TIMER, SYNCH_WORKER_WEIGHTS_TIMER, + TIMERS, ) from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer from ray.rllib.utils.typing import LearningRateOrSchedule, ResultDict @@ -145,10 +149,6 @@ def __init__(self, algo_class=None): self.env_runner_cls = DreamerV3EnvRunner self.num_env_runners = 0 self.rollout_fragment_length = 1 - # Since we are using a gymnasium-based EnvRunner, we can utilitze its - # vectorization capabilities w/o suffering performance losses (as we would - # with RLlib's `RemoteVectorEnv`). - self.remote_worker_envs = True # Dreamer only runs on the new API stack. self.enable_rl_module_and_learner = True self.enable_env_runner_and_connector_v2 = True @@ -506,8 +506,6 @@ def setup(self, config: AlgorithmConfig): @override(Algorithm) def training_step(self) -> ResultDict: - results = {} - env_runner = self.workers.local_worker() # Push enough samples into buffer initially before we start training. @@ -520,7 +518,7 @@ def training_step(self) -> ResultDict: # Have we sampled yet in this `training_step()` call? have_sampled = False - with self._timers[SAMPLE_TIMER]: + with self.metrics.log_time((TIMERS, SAMPLE_TIMER)): # Continue sampling from the actual environment (and add collected samples # to our replay buffer) as long as we: while ( @@ -535,45 +533,76 @@ def training_step(self) -> ResultDict: or not have_sampled ): # Sample using the env runner's module. - done_episodes, ongoing_episodes = env_runner.sample() + episodes, env_runner_results = synchronous_parallel_sample( + worker_set=self.workers, + max_agent_steps=( + self.config.rollout_fragment_length + * self.config.num_envs_per_env_runner + ), + sample_timeout_s=self.config.sample_timeout_s, + _uses_new_env_runners=True, + _return_metrics=True, + ) + self.metrics.merge_and_log_n_dicts( + env_runner_results, key=ENV_RUNNER_RESULTS + ) # Add ongoing and finished episodes into buffer. The buffer will # automatically take care of properly concatenating (by episode IDs) # the different chunks of the same episodes, even if they come in via # separate `add()` calls. - self.replay_buffer.add(episodes=done_episodes + ongoing_episodes) + self.replay_buffer.add(episodes=episodes) have_sampled = True # We took B x T env steps. - env_steps_last_regular_sample = sum( - len(eps) for eps in done_episodes + ongoing_episodes - ) + env_steps_last_regular_sample = sum(len(eps) for eps in episodes) total_sampled = env_steps_last_regular_sample # If we have never sampled before (just started the algo and not # recovered from a checkpoint), sample B random actions first. - if self._counters[NUM_AGENT_STEPS_SAMPLED] == 0: - d_, o_ = env_runner.sample( - num_timesteps=( + if self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0) == 0: + _episodes, _env_runner_results = synchronous_parallel_sample( + worker_set=self.workers, + max_agent_steps=( self.config.batch_size_B * self.config.batch_length_T - ) - - env_steps_last_regular_sample, + - env_steps_last_regular_sample + ), + sample_timeout_s=self.config.sample_timeout_s, random_actions=True, + _uses_new_env_runners=True, + _return_metrics=True, ) - self.replay_buffer.add(episodes=d_ + o_) - total_sampled += sum(len(eps) for eps in d_ + o_) - - self._counters[NUM_AGENT_STEPS_SAMPLED] += total_sampled - self._counters[NUM_ENV_STEPS_SAMPLED] += total_sampled + self.metrics.merge_and_log_n_dicts( + _env_runner_results, key=ENV_RUNNER_RESULTS + ) + self.replay_buffer.add(episodes=_episodes) + total_sampled += sum(len(eps) for eps in _episodes) + + # Update lifetime counts (now that we gathered results from all + # EnvRunners). + self.metrics.log_dict( + { + NUM_AGENT_STEPS_SAMPLED_LIFETIME: self.metrics.peek( + ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED + ), + NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek( + ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED + ), + NUM_EPISODES_LIFETIME: self.metrics.peek( + ENV_RUNNER_RESULTS, NUM_EPISODES + ), + }, + reduce="sum", + ) # Summarize environment interaction and buffer data. - results[ALL_MODULES] = report_sampling_and_replay_buffer( - replay_buffer=self.replay_buffer, + report_sampling_and_replay_buffer( + metrics=self.metrics, replay_buffer=self.replay_buffer ) # Continue sampling batch_size_B x batch_length_T sized batches from the buffer - # and using these to update our models (`LearnerGroup.update()`) until the - # computed `training_ratio` is larger than the configured one, meaning we should - # go back and collect more samples again from the actual environment. + # and using these to update our models (`LearnerGroup.update_from_batch()`) + # until the computed `training_ratio` is larger than the configured one, meaning + # we should go back and collect more samples again from the actual environment. # However, when calculating the `training_ratio` here, we use only the # trained steps in this very `training_step()` call over the most recent sample # amount (`env_steps_last_regular_sample`), not the global values. This is to @@ -584,7 +613,7 @@ def training_step(self) -> ResultDict: replayed_steps_this_iter / env_steps_last_regular_sample ) < self.config.training_ratio: # Time individual batch updates. - with self._timers[LEARN_ON_BATCH_TIMER]: + with self.metrics.log_time((TIMERS, LEARN_ON_BATCH_TIMER)): logger.info(f"\tSub-iteration {self.training_iteration}/{sub_iter})") # Draw a new sample from the replay buffer. @@ -603,65 +632,76 @@ def training_step(self) -> ResultDict: ) # Perform the actual update via our learner group. - train_results = self.learner_group.update_from_batch( + learner_results = self.learner_group.update_from_batch( batch=SampleBatch(sample).as_multi_agent(), - reduce_fn=self._reduce_results, + # TODO(sven): Maybe we should do this broadcase of global timesteps + # at the end, like for EnvRunner global env step counts. Maybe when + # we request the state from the Learners, we can - at the same + # time - send the current globally summed/reduced-timesteps. + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek( + NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0 + ) + }, ) - self._counters[NUM_AGENT_STEPS_TRAINED] += replayed_steps - self._counters[NUM_ENV_STEPS_TRAINED] += replayed_steps - - # Perform additional (non-gradient updates), such as the critic EMA-copy - # update. - with self._timers["critic_ema_update"]: - self.learner_group.additional_update( - timestep=self._counters[NUM_ENV_STEPS_SAMPLED], - reduce_fn=self._reduce_results, - ) - - if self.config.report_images_and_videos: - report_predicted_vs_sampled_obs( - # TODO (sven): DreamerV3 is single-agent only. - results=train_results[DEFAULT_MODULE_ID], - sample=sample, - batch_size_B=self.config.batch_size_B, - batch_length_T=self.config.batch_length_T, - symlog_obs=do_symlog_obs( - env_runner.env.single_observation_space, - self.config.symlog_obs, - ), - ) - - res = train_results[DEFAULT_MODULE_ID] - logger.info( - f"\t\tWORLD_MODEL_L_total={res['WORLD_MODEL_L_total']:.5f} (" - f"L_pred={res['WORLD_MODEL_L_prediction']:.5f} (" - f"decoder/obs={res['WORLD_MODEL_L_decoder']} " - f"L_rew={res['WORLD_MODEL_L_reward']} " - f"L_cont={res['WORLD_MODEL_L_continue']}); " - f"L_dyn/rep={res['WORLD_MODEL_L_dynamics']:.5f})" + self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS) + self.metrics.log_value( + NUM_ENV_STEPS_TRAINED_LIFETIME, replayed_steps, reduce="sum" ) - msg = "\t\t" - if self.config.train_actor: - msg += f"L_actor={res['ACTOR_L_total']:.5f} " - if self.config.train_critic: - msg += f"L_critic={res['CRITIC_L_total']:.5f} " - logger.info(msg) sub_iter += 1 - self._counters[NUM_GRAD_UPDATES_LIFETIME] += 1 + self.metrics.log_value(NUM_GRAD_UPDATES_LIFETIME, 1, reduce="sum") + + # Log videos showing how the decoder produces observation predictions + # from the posterior states. + # Only every n iterations and only for the first sampled batch row + # (videos are `config.batch_length_T` frames long). + report_predicted_vs_sampled_obs( + # TODO (sven): DreamerV3 is single-agent only. + metrics=self.metrics, + sample=sample, + batch_size_B=self.config.batch_size_B, + batch_length_T=self.config.batch_length_T, + symlog_obs=do_symlog_obs( + env_runner.env.single_observation_space, + self.config.symlog_obs, + ), + do_report=( + self.config.report_images_and_videos + and self.training_iteration % 100 == 0 + ), + ) + + # Log videos showing some of the dreamed trajectories and compare them with the + # actual trajectories from the train batch. + # Only every n iterations and only for the first sampled batch row AND first ts. + # (videos are `config.horizon_H` frames long originating from the observation + # at B=0 and T=0 in the train batch). + report_dreamed_eval_trajectory_vs_samples( + metrics=self.metrics, + sample=sample, + burn_in_T=0, + dreamed_T=self.config.horizon_H + 1, + dreamer_model=self.workers.local_worker().module.dreamer_model, + symlog_obs=do_symlog_obs( + env_runner.env.single_observation_space, + self.config.symlog_obs, + ), + do_report=( + self.config.report_dream_data and self.training_iteration % 100 == 0 + ), + ) # Update weights - after learning on the LearnerGroup - on all EnvRunner # workers. - with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)): # Only necessary if RLModule is not shared between (local) EnvRunner and # (local) Learner. if not self.config.share_module_between_env_runner_and_learner: - self._counters[ - NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS - ] = 0 - self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1 + self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum") self.workers.sync_weights( - from_worker_or_learner_group=self.learner_group + from_worker_or_learner_group=self.learner_group, + inference_only=True, ) # Try trick from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak- @@ -669,33 +709,29 @@ def training_step(self) -> ResultDict: if self.config.gc_frequency_train_steps and ( self.training_iteration % self.config.gc_frequency_train_steps == 0 ): - with self._timers[GARBAGE_COLLECTION_TIMER]: + with self.metrics.log_time((TIMERS, GARBAGE_COLLECTION_TIMER)): gc.collect() # Add train results and the actual training ratio to stats. The latter should # be close to the configured `training_ratio`. - results.update(train_results) - results[ALL_MODULES]["actual_training_ratio"] = self.training_ratio + self.metrics.log_value("actual_training_ratio", self.training_ratio, window=1) # Return all results. - return results + return self.metrics.reduce() @property def training_ratio(self) -> float: - """Returns the actual training ratio of this Algorithm. + """Returns the actual training ratio of this Algorithm (not the configured one). The training ratio is copmuted by dividing the total number of steps trained thus far (replayed from the buffer) over the total number of actual env steps taken thus far. """ - return self._counters[NUM_ENV_STEPS_TRAINED] / ( - self._counters[NUM_ENV_STEPS_SAMPLED] + eps = 0.0001 + return self.metrics.peek(NUM_ENV_STEPS_TRAINED_LIFETIME, default=0) / ( + (self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=eps) or eps) ) - @staticmethod - def _reduce_results(results: List[Dict[str, Any]]): - return tree.map_structure(lambda *s: np.mean(s, axis=0), *results) - # TODO (sven): Remove this once DreamerV3 is on the new SingleAgentEnvRunner. @PublicAPI def __setstate__(self, state) -> None: diff --git a/rllib/algorithms/dreamerv3/dreamerv3_learner.py b/rllib/algorithms/dreamerv3/dreamerv3_learner.py index 2ee6f4b16187c..684829e3f1946 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_learner.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_learner.py @@ -7,61 +7,25 @@ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba https://arxiv.org/pdf/2010.02193.pdf """ -from typing import Any, DefaultDict, Dict - -from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config from ray.rllib.core.learner.learner import Learner -from ray.rllib.policy.sample_batch import MultiAgentBatch -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import ModuleID, TensorType +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) class DreamerV3Learner(Learner): """DreamerV3 specific Learner class. - Only implements the `additional_update_for_module()` method to define the logic + Only implements the `_after_gradient_based_update()` method to define the logic for updating the critic EMA-copy after each training step. """ + @OverrideToImplementCustomLogic_CallToSuperRecommended @override(Learner) - def compile_results( - self, - *, - batch: MultiAgentBatch, - fwd_out: Dict[str, Any], - loss_per_module: Dict[str, TensorType], - metrics_per_module: DefaultDict[ModuleID, Dict[str, Any]], - ) -> Dict[str, Any]: - results = super().compile_results( - batch=batch, - fwd_out=fwd_out, - loss_per_module=loss_per_module, - metrics_per_module=metrics_per_module, - ) - - # Add the predicted obs distributions for possible (video) summarization. - if self.config.report_images_and_videos: - for module_id, res in results.items(): - if module_id in fwd_out: - res["WORLD_MODEL_fwd_out_obs_distribution_means_BxT"] = fwd_out[ - module_id - ]["obs_distribution_means_BxT"] - return results - - @override(Learner) - def additional_update_for_module( - self, - *, - module_id: ModuleID, - config: DreamerV3Config, - timestep: int, - ) -> None: - """Updates the EMA weights of the critic network.""" - - # Call the base class' method. - super().additional_update_for_module( - module_id=module_id, config=config, timestep=timestep - ) + def _after_gradient_based_update(self, timesteps): + super()._after_gradient_based_update(timesteps) # Update EMA weights of the critic. - self.module[module_id].critic.update_ema() + for module_id, module in self.module._rl_modules.items(): + module.critic.update_ema() diff --git a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py index 9d5bf26055297..c95363eaa9074 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py @@ -3,6 +3,7 @@ """ import abc +from typing import Any, Dict import gymnasium as gym import numpy as np @@ -29,6 +30,8 @@ class DreamerV3RLModule(RLModule, abc.ABC): @override(RLModule) def setup(self): + super().setup() + # Gather model-relevant settings. B = 1 T = self.config.model_config_dict["batch_length_T"] @@ -79,35 +82,39 @@ def setup(self): self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework) # Perform a test `call()` to force building the dreamer model's variables. - test_obs = np.tile( - np.expand_dims(self.config.observation_space.sample(), (0, 1)), - reps=(B, T) + (1,) * len(self.config.observation_space.shape), - ) - if isinstance(self.config.action_space, gym.spaces.Discrete): - test_actions = np.tile( - np.expand_dims( - one_hot( - self.config.action_space.sample(), - depth=self.config.action_space.n, + if self.framework == "tf2": + test_obs = np.tile( + np.expand_dims(self.config.observation_space.sample(), (0, 1)), + reps=(B, T) + (1,) * len(self.config.observation_space.shape), + ) + if isinstance(self.config.action_space, gym.spaces.Discrete): + test_actions = np.tile( + np.expand_dims( + one_hot( + self.config.action_space.sample(), + depth=self.config.action_space.n, + ), + (0, 1), ), - (0, 1), + reps=(B, T, 1), + ) + else: + test_actions = np.tile( + np.expand_dims(self.config.action_space.sample(), (0, 1)), + reps=(B, T, 1), + ) + + self.dreamer_model( + inputs=None, + observations=_convert_to_tf(test_obs, dtype=tf.float32), + actions=_convert_to_tf(test_actions, dtype=tf.float32), + is_first=_convert_to_tf(np.ones((B, T)), dtype=tf.bool), + start_is_terminated_BxT=_convert_to_tf( + np.zeros((B * T,)), dtype=tf.bool ), - reps=(B, T, 1), - ) - else: - test_actions = np.tile( - np.expand_dims(self.config.action_space.sample(), (0, 1)), - reps=(B, T, 1), + gamma=gamma, ) - self.dreamer_model( - None, - _convert_to_tf(test_obs, dtype=tf.float32), - _convert_to_tf(test_actions, dtype=tf.float32), - _convert_to_tf(np.ones((B, T)), dtype=tf.bool), - _convert_to_tf(np.zeros((B * T,)), dtype=tf.bool), - ) - # Initialize the critic EMA net: self.critic.init_ema() @@ -152,3 +159,32 @@ def output_specs_train(self) -> SpecDict: # Deterministic, continuous h-states (t1 to T). "h_states_BxT", ] + + @override(RLModule) + def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]: + # Call the Dreamer-Model's forward_inference method and return a dict. + actions, next_state = self.dreamer_model.forward_inference( + observations=batch[Columns.OBS], + previous_states=batch[Columns.STATE_IN], + is_first=batch["is_first"], + ) + return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state} + + @override(RLModule) + def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: + # Call the Dreamer-Model's forward_exploration method and return a dict. + actions, next_state = self.dreamer_model.forward_exploration( + observations=batch[Columns.OBS], + previous_states=batch[Columns.STATE_IN], + is_first=batch["is_first"], + ) + return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state} + + @override(RLModule) + def _forward_train(self, batch: NestedDict): + # Call the Dreamer-Model's forward_train method and return its outputs as-is. + return self.dreamer_model.forward_train( + observations=batch[Columns.OBS], + actions=batch[Columns.ACTIONS], + is_first=batch["is_first"], + ) diff --git a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py index 92bb33dda483d..0c8875b54f10b 100644 --- a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py +++ b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py @@ -40,6 +40,7 @@ def test_dreamerv3_compilation(self): # Build a DreamerV3Config object. config = ( dreamerv3.DreamerV3Config() + .framework(eager_tracing=False) .training( # Keep things simple. Especially the long dream rollouts seem # to take an enormous amount of time (initially). @@ -51,7 +52,7 @@ def test_dreamerv3_compilation(self): use_float16=False, ) .learners( - num_learners=2, # Try with 2 Learners. + num_learners=0, # TODO 2 # Try with 2 Learners. num_cpus_per_learner=1, num_gpus_per_learner=0, ) @@ -59,70 +60,69 @@ def test_dreamerv3_compilation(self): num_iterations = 2 - for _ in framework_iterator(config, frameworks="tf2"): - for env in [ - "FrozenLake-v1", - "CartPole-v1", - "ALE/MsPacman-v5", - "Pendulum-v1", - ]: - print("Env={}".format(env)) - # Add one-hot observations for FrozenLake env. - if env == "FrozenLake-v1": - - def env_creator(ctx): - import gymnasium as gym - from ray.rllib.algorithms.dreamerv3.utils.env_runner import ( - OneHot, - ) - - return OneHot(gym.make("FrozenLake-v1")) - - tune.register_env("frozen-lake-one-hot", env_creator) - env = "frozen-lake-one-hot" - - config.environment(env) - algo = config.build() - obs_space = algo.workers.local_worker().env.single_observation_space - act_space = algo.workers.local_worker().env.single_action_space - rl_module = algo.workers.local_worker().module - - for i in range(num_iterations): - results = algo.train() - print(results) - # Test dream trajectory w/ recreated observations. - sample = algo.replay_buffer.sample() - dream = rl_module.dreamer_model.dream_trajectory_with_burn_in( - start_states=rl_module.dreamer_model.get_initial_state(), - timesteps_burn_in=5, - timesteps_H=45, - observations=sample["obs"][:1], # B=1 - actions=( - one_hot( - sample["actions"], - depth=act_space.n, - ) - if isinstance(act_space, gym.spaces.Discrete) - else sample["actions"] - )[ - :1 - ], # B=1 - ) - self.assertTrue( - dream["actions_dreamed_t0_to_H_BxT"].shape - == (46, 1) - + ( - (act_space.n,) - if isinstance(act_space, gym.spaces.Discrete) - else tuple(act_space.shape) + for env in [ + "FrozenLake-v1", + "CartPole-v1", + "ALE/MsPacman-v5", + "Pendulum-v1", + ]: + print("Env={}".format(env)) + # Add one-hot observations for FrozenLake env. + if env == "FrozenLake-v1": + + def env_creator(ctx): + import gymnasium as gym + from ray.rllib.algorithms.dreamerv3.utils.env_runner import ( + OneHot, ) + + return OneHot(gym.make("FrozenLake-v1")) + + tune.register_env("frozen-lake-one-hot", env_creator) + env = "frozen-lake-one-hot" + + config.environment(env) + algo = config.build() + obs_space = algo.workers.local_worker().env.single_observation_space + act_space = algo.workers.local_worker().env.single_action_space + rl_module = algo.workers.local_worker().module + + for i in range(num_iterations): + results = algo.train() + print(results) + # Test dream trajectory w/ recreated observations. + sample = algo.replay_buffer.sample() + dream = rl_module.dreamer_model.dream_trajectory_with_burn_in( + start_states=rl_module.dreamer_model.get_initial_state(), + timesteps_burn_in=5, + timesteps_H=45, + observations=sample["obs"][:1], # B=1 + actions=( + one_hot( + sample["actions"], + depth=act_space.n, + ) + if isinstance(act_space, gym.spaces.Discrete) + else sample["actions"] + )[ + :1 + ], # B=1 + ) + self.assertTrue( + dream["actions_dreamed_t0_to_H_BxT"].shape + == (46, 1) + + ( + (act_space.n,) + if isinstance(act_space, gym.spaces.Discrete) + else tuple(act_space.shape) ) - self.assertTrue(dream["continues_dreamed_t0_to_H_BxT"].shape == (46, 1)) - self.assertTrue( - dream["observations_dreamed_t0_to_H_BxT"].shape - == [46, 1] + list(obs_space.shape) - ) - algo.stop() + ) + self.assertTrue(dream["continues_dreamed_t0_to_H_BxT"].shape == (46, 1)) + self.assertTrue( + dream["observations_dreamed_t0_to_H_BxT"].shape + == [46, 1] + list(obs_space.shape) + ) + algo.stop() def test_dreamerv3_dreamer_model_sizes(self): """Tests, whether the different model sizes match the ones reported in [1].""" diff --git a/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py b/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py index 2fa6dce2f02d7..d35717e4aa44d 100644 --- a/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py +++ b/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py @@ -18,7 +18,6 @@ from ray.rllib.core.learner.learner import ParamDict from ray.rllib.core.learner.tf.tf_learner import TfLearner from ray.rllib.utils.annotations import override -from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.tf_utils import symlog, two_hot, clip_gradients from ray.rllib.utils.typing import ModuleID, TensorType @@ -40,7 +39,7 @@ class DreamerV3TfLearner(DreamerV3Learner, TfLearner): @override(TfLearner) def configure_optimizers_for_module( - self, module_id: ModuleID, config: DreamerV3Config = None, hps=None + self, module_id: ModuleID, config: DreamerV3Config = None ): """Create the 3 optimizers for Dreamer learning: world_model, actor, critic. @@ -48,12 +47,6 @@ def configure_optimizers_for_module( - albeit probably not that important - are used by the author's own implementation. """ - if hps is not None: - deprecation_warning( - old="Learner.configure_optimizers_for_module(.., hps=..)", - help="Deprecated argument. Use `config` (AlgorithmConfig) instead.", - error=True, - ) dreamerv3_module = self._module[module_id] @@ -242,10 +235,20 @@ def compute_loss_for_module( key=module_id, window=1, # <- single items (should not be mean/ema-reduced over time). ) + + # Add the predicted obs distributions for possible (video) summarization. + if config.report_images_and_videos: + self.metrics.log_value( + (module_id, "WORLD_MODEL_fwd_out_obs_distribution_means_b0xT"), + fwd_out["obs_distribution_means_BxT"][: self.config.batch_length_T], + reduce=None, # No reduction, we want the tensor to stay in-tact. + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + if config.report_individual_batch_item_stats: # Log important world-model loss stats. self.metrics.log_dict( - metrics_dict={ + { "WORLD_MODEL_L_decoder_B_T": prediction_losses["L_decoder_B_T"], "WORLD_MODEL_L_reward_B_T": prediction_losses["L_reward_B_T"], "WORLD_MODEL_L_continue_B_T": prediction_losses["L_continue_B_T"], @@ -270,10 +273,10 @@ def compute_loss_for_module( "h": fwd_out["h_states_BxT"], "z": fwd_out["z_posterior_states_BxT"], }, - start_is_terminated=tf.reshape(batch["is_terminated"], [-1]), # ->BxT + start_is_terminated=tf.reshape(batch["is_terminated"], [-1]), # -> BxT ) if config.report_dream_data: - # To reduce this massive mount of data a little, slice out a T=1 piece + # To reduce this massive amount of data a little, slice out a T=1 piece # from each stats that has the shape (H, BxT), meaning convert e.g. # `rewards_dreamed_t0_to_H_BxT` into `rewards_dreamed_t0_to_H_Bx1`. # This will reduce the amount of data to be transferred and reported @@ -281,9 +284,9 @@ def compute_loss_for_module( self.metrics.log_dict( { # Replace 'T' with '1'. - "DREAM_DATA_" - + key[:-1] - + "1": (value[:, config.batch_size_B_per_learner]) + f"DREAM_DATA_{key[:-1]}1": ( + value[:, config.batch_size_B_per_learner] + ) for key, value in dream_data.items() if key.endswith("H_BxT") }, @@ -733,7 +736,7 @@ def _compute_critic_loss( :-1 ] - # Reduce over H- (time) axis (sum) and then B-axis (mean). + # Reduce over both H- (time) axis and B-axis (mean). L_critic = tf.reduce_mean(L_critic_H_B) # Log important critic loss stats. diff --git a/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py b/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py index a05c516d29bdd..44952e8297419 100644 --- a/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py +++ b/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_rl_module.py @@ -7,17 +7,8 @@ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba https://arxiv.org/pdf/2010.02193.pdf """ -from typing import Any, Dict - from ray.rllib.algorithms.dreamerv3.dreamerv3_rl_module import DreamerV3RLModule -from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.nested_dict import NestedDict - -tf1, tf, _ = try_import_tf() class DreamerV3TfRLModule(TfRLModule, DreamerV3RLModule): @@ -26,33 +17,4 @@ class DreamerV3TfRLModule(TfRLModule, DreamerV3RLModule): Serves mainly as a thin-wrapper around the `DreamerModel` (a tf.keras.Model) class. """ - framework: str = "tf2" - - @override(RLModule) - def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]: - # Call the Dreamer-Model's forward_inference method and return a dict. - actions, next_state = self.dreamer_model.forward_inference( - observations=batch[Columns.OBS], - previous_states=batch[Columns.STATE_IN], - is_first=batch["is_first"], - ) - return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state} - - @override(RLModule) - def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: - # Call the Dreamer-Model's forward_exploration method and return a dict. - actions, next_state = self.dreamer_model.forward_exploration( - observations=batch[Columns.OBS], - previous_states=batch[Columns.STATE_IN], - is_first=batch["is_first"], - ) - return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state} - - @override(RLModule) - def _forward_train(self, batch: NestedDict): - # Call the Dreamer-Model's forward_train method and return its outputs as-is. - return self.dreamer_model.forward_train( - observations=batch[Columns.OBS], - actions=batch[Columns.ACTIONS], - is_first=batch["is_first"], - ) + framework = "tf2" diff --git a/rllib/algorithms/dreamerv3/tf/models/actor_network.py b/rllib/algorithms/dreamerv3/tf/models/actor_network.py index 44785323711ff..6fe1a7ef5b712 100644 --- a/rllib/algorithms/dreamerv3/tf/models/actor_network.py +++ b/rllib/algorithms/dreamerv3/tf/models/actor_network.py @@ -3,8 +3,6 @@ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from typing import Optional - import gymnasium as gym from gymnasium.spaces import Box, Discrete import numpy as np @@ -35,15 +33,15 @@ class ActorNetwork(tf.keras.Model): def __init__( self, *, - model_size: Optional[str] = "XS", + model_size: str = "XS", action_space: gym.Space, ): """Initializes an ActorNetwork instance. Args: - model_size: The "Model Size" used according to [1] Appendinx B. + model_size: The "Model Size" used according to [1] Appendix B. Use None for manually setting the different network sizes. - action_space: The action space the our environment used. + action_space: The action space the our environment used. """ super().__init__(name="actor") diff --git a/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py b/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py index 16e733ce4b17f..c0f7ee09b092b 100644 --- a/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py +++ b/rllib/algorithms/dreamerv3/tf/models/components/cnn_atari.py @@ -23,7 +23,7 @@ def __init__( """Initializes a CNNAtari instance. Args: - model_size: The "Model Size" used according to [1] Appendinx B. + model_size: The "Model Size" used according to [1] Appendix B. Use None for manually setting the `cnn_multiplier`. cnn_multiplier: Optional override for the additional factor used to multiply the number of filters with each CNN layer. Starting with diff --git a/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py b/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py index dd948b9951f02..d5434d8aca315 100644 --- a/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py +++ b/rllib/algorithms/dreamerv3/tf/models/components/continue_predictor.py @@ -3,8 +3,6 @@ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from typing import Optional - from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.utils import ( get_gru_units, @@ -29,7 +27,7 @@ class ContinuePredictor(tf.keras.Model): terminal. """ - def __init__(self, *, model_size: Optional[str] = "XS"): + def __init__(self, *, model_size: str = "XS"): """Initializes a ContinuePredictor instance. Args: diff --git a/rllib/algorithms/dreamerv3/tf/models/components/reward_predictor.py b/rllib/algorithms/dreamerv3/tf/models/components/reward_predictor.py index c281565897cbf..3e7cb6de93f97 100644 --- a/rllib/algorithms/dreamerv3/tf/models/components/reward_predictor.py +++ b/rllib/algorithms/dreamerv3/tf/models/components/reward_predictor.py @@ -3,8 +3,6 @@ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from typing import Optional - from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor_layer import ( RewardPredictorLayer, @@ -28,7 +26,7 @@ class RewardPredictor(tf.keras.Model): def __init__( self, *, - model_size: Optional[str] = "XS", + model_size: str = "XS", num_buckets: int = 255, lower_bound: float = -20.0, upper_bound: float = 20.0, diff --git a/rllib/algorithms/dreamerv3/tf/models/components/sequence_model.py b/rllib/algorithms/dreamerv3/tf/models/components/sequence_model.py index 0d2de1970471a..7e21c9860578b 100644 --- a/rllib/algorithms/dreamerv3/tf/models/components/sequence_model.py +++ b/rllib/algorithms/dreamerv3/tf/models/components/sequence_model.py @@ -76,7 +76,6 @@ def __init__( num_gru_units, return_sequences=False, return_state=False, - time_major=True, # Note: Changing these activations is most likely a bad idea! # In experiments, setting one of both of them to silu deteriorated # performance significantly. @@ -139,7 +138,7 @@ def call(self, a, h, z): ) # Pass through pre-GRU layer. out = self.pre_gru_layer(out) - # Pass through (time-major) GRU. - h_next = self.gru_unit(tf.expand_dims(out, axis=0), initial_state=h) + # Pass through (batch-major) GRU (expand axis=1 as the time axis). + h_next = self.gru_unit(tf.expand_dims(out, axis=1), initial_state=h) # Return the GRU's output (the next h-state). return h_next diff --git a/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py b/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py index a384d1473bda1..e183561f9217e 100644 --- a/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py +++ b/rllib/algorithms/dreamerv3/tf/models/components/vector_decoder.py @@ -3,8 +3,6 @@ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from typing import Optional - import gymnasium as gym from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP @@ -28,7 +26,7 @@ class VectorDecoder(tf.keras.Model): def __init__( self, *, - model_size: Optional[str] = "XS", + model_size: str = "XS", observation_space: gym.Space, ): """Initializes a VectorDecoder instance. diff --git a/rllib/algorithms/dreamerv3/tf/models/critic_network.py b/rllib/algorithms/dreamerv3/tf/models/critic_network.py index e2b2d45d94358..4eb9b99401336 100644 --- a/rllib/algorithms/dreamerv3/tf/models/critic_network.py +++ b/rllib/algorithms/dreamerv3/tf/models/critic_network.py @@ -3,8 +3,6 @@ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from typing import Optional - from ray.rllib.algorithms.dreamerv3.tf.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.tf.models.components.reward_predictor_layer import ( RewardPredictorLayer, @@ -33,7 +31,7 @@ class CriticNetwork(tf.keras.Model): def __init__( self, *, - model_size: Optional[str] = "XS", + model_size: str = "XS", num_buckets: int = 255, lower_bound: float = -20.0, upper_bound: float = 20.0, diff --git a/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py b/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py index 26a44d1fb3b02..dc4eec8579ae2 100644 --- a/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py +++ b/rllib/algorithms/dreamerv3/tf/models/dreamer_model.py @@ -113,6 +113,7 @@ def call( actions, is_first, start_is_terminated_BxT, + gamma, ): """Main call method for building this model in order to generate its variables. diff --git a/rllib/algorithms/dreamerv3/utils/__init__.py b/rllib/algorithms/dreamerv3/utils/__init__.py index 592bbf9b32e82..fe7b58cf515ee 100644 --- a/rllib/algorithms/dreamerv3/utils/__init__.py +++ b/rllib/algorithms/dreamerv3/utils/__init__.py @@ -124,7 +124,7 @@ def get_num_curiosity_nets(model_size, override=None): num_curiosity_nets = { "nano": 8, "micro": 8, - "mini": 16, + "mini": 8, "XXS": 8, "XS": 8, "S": 8, diff --git a/rllib/algorithms/dreamerv3/utils/env_runner.py b/rllib/algorithms/dreamerv3/utils/env_runner.py index c0b73ef824fde..93ed2beb6240e 100644 --- a/rllib/algorithms/dreamerv3/utils/env_runner.py +++ b/rllib/algorithms/dreamerv3/utils/env_runner.py @@ -15,21 +15,42 @@ import numpy as np import tree # pip install dm_tree +import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core import DEFAULT_AGENT_ID, DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.env.wrappers.atari_wrappers import NoopResetEnv, MaxAndSkipEnv from ray.rllib.env.wrappers.dm_control_wrapper import DMCEnv from ray.rllib.env.utils import _gym_env_creator -from ray.rllib.evaluation.metrics import RolloutMetrics from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.env.single_agent_episode import SingleAgentEpisode -from ray.rllib.utils.numpy import one_hot +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.metrics import ( + EPISODE_DURATION_SEC_MEAN, + EPISODE_LEN_MAX, + EPISODE_LEN_MEAN, + EPISODE_LEN_MIN, + EPISODE_RETURN_MAX, + EPISODE_RETURN_MEAN, + EPISODE_RETURN_MIN, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, + NUM_EPISODES, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_MODULE_STEPS_SAMPLED, + NUM_MODULE_STEPS_SAMPLED_LIFETIME, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.numpy import convert_to_numpy, one_hot +from ray.rllib.utils.spaces.space_utils import batch, unbatch +from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.utils.typing import ResultDict from ray.tune.registry import ENV_CREATOR, _global_registry _, tf, _ = try_import_tf() +torch, _ = try_import_torch() # TODO (sven): Use SingleAgentEnvRunner instead of this as soon as we have the new @@ -70,23 +91,38 @@ def __init__( # However, in Danijar's repo, Atari100k experiments are configured as: # noop=30, 64x64x3 (no grayscaling), sticky actions=False, # full action space=False, - wrappers = [ - partial(gym.wrappers.TimeLimit, max_episode_steps=108000), - partial(resize_v1, x_size=64, y_size=64), # resize to 64x64 - NormalizedImageEnv, - NoopResetEnv, - MaxAndSkipEnv, - ] + + def _entry_point(): + return gym.make( + self.config.env, + **dict( + self.config.env_config, + **{ + # "sticky actions" but not according to Danijar's 100k + # configs. + "repeat_action_probability": 0.0, + # "full action space" but not according to Danijar's 100k + # configs. + "full_action_space": False, + # Already done by MaxAndSkip wrapper: "action repeat" == 4. + "frameskip": 1, + }, + ), + ) + + gym.register("rllib-single-agent-env-v0", entry_point=_entry_point) self.env = gym.vector.make( - "GymV26Environment-v0", - env_id=self.config.env, - wrappers=wrappers, + "rllib-single-agent-env-v0", num_envs=self.config.num_envs_per_env_runner, asynchronous=self.config.remote_worker_envs, - make_kwargs=dict( - self.config.env_config, **{"render_mode": "rgb_array"} - ), + wrappers=[ + partial(gym.wrappers.TimeLimit, max_episode_steps=108000), + partial(resize_v1, x_size=64, y_size=64), # resize to 64x64 + NormalizedImageEnv, + NoopResetEnv, + MaxAndSkipEnv, + ], ) # DeepMind Control. elif self.config.env.startswith("DMC/"): @@ -147,6 +183,24 @@ def __init__( # TODO (sven): DreamerV3 is currently single-agent only. self.module = self.marl_module_spec.build()[DEFAULT_MODULE_ID] + self.metrics = MetricsLogger() + + self._device = None + if ( + torch + and torch.cuda.is_available() + and self.config.framework_str == "torch" + and self.config.share_module_between_env_runner_and_learner + and self.config.num_gpus_per_learner > 0 + ): + gpu_ids = ray.get_gpu_ids() + self._device = f"cuda:{gpu_ids[0]}" + self.convert_to_tensor = ( + partial(convert_to_torch_tensor, device=self._device) + if self.config.framework_str == "torch" + else tf.convert_to_tensor + ) + self._needs_initial_reset = True self._episodes = [None for _ in range(self.num_envs)] self._states = [None for _ in range(self.num_envs)] @@ -158,7 +212,6 @@ def __init__( # via its replay buffer, etc..). self._done_episodes_for_metrics = [] self._ongoing_episodes_for_metrics = defaultdict(list) - self._ts_since_last_metrics = 0 @override(EnvRunner) def sample( @@ -228,7 +281,7 @@ def _sample_timesteps( explore: bool = True, random_actions: bool = False, force_reset: bool = False, - ) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]: + ) -> List[SingleAgentEpisode]: """Helper method to run n timesteps. See docstring of self.sample() for more details. @@ -238,47 +291,25 @@ def _sample_timesteps( # Get initial states for all `batch_size_B` rows in the forward batch. initial_states = tree.map_structure( lambda s: np.repeat(s, self.num_envs, axis=0), - self.module.get_initial_state(), + convert_to_numpy(self.module.get_initial_state()), ) # Have to reset the env (on all vector sub-envs). if force_reset or self._needs_initial_reset: obs, _ = self.env.reset() + self._needs_initial_reset = False self._episodes = [SingleAgentEpisode() for _ in range(self.num_envs)] - states = initial_states - # Set is_first to True for all rows (all sub-envs just got reset). - is_first = np.ones((self.num_envs,)) - self._needs_initial_reset = False # Set initial obs and states in the episodes. for i in range(self.num_envs): self._episodes[i].add_env_reset(observation=obs[i]) - self._states[i] = {k: s[i] for k, s in states.items()} + self._states[i] = None + # Don't reset existing envs; continue in already started episodes. else: # Pick up stored observations and states from previous timesteps. obs = np.stack([eps.observations[-1] for eps in self._episodes]) - # Compile the initial state for each batch row: If episode just started, use - # model's initial state, if not, use state stored last in - # SingleAgentEpisode. - states = { - k: np.stack( - [ - initial_states[k][i] - if self._states[i] is None - else self._states[i][k] - for i, eps in enumerate(self._episodes) - ] - ) - for k in initial_states.keys() - } - # If a batch row is at the beginning of an episode, set its `is_first` flag - # to 1.0, otherwise 0.0. - is_first = np.zeros((self.num_envs,)) - for i, eps in enumerate(self._episodes): - if len(eps) == 0: - is_first[i] = 1.0 # Loop through env for n timesteps. ts = 0 @@ -288,33 +319,35 @@ def _sample_timesteps( actions = self.env.action_space.sample() # Compute an action using our RLModule. else: - batch = { + is_first = np.zeros((self.num_envs,)) + for i, eps in enumerate(self._episodes): + if self._states[i] is None: + is_first[i] = 1.0 + self._states[i] = {k: s[i] for k, s in initial_states.items()} + to_module = { Columns.STATE_IN: tree.map_structure( - lambda s: tf.convert_to_tensor(s), states + lambda s: self.convert_to_tensor(s), batch(self._states) ), - Columns.OBS: tf.convert_to_tensor(obs), - "is_first": tf.convert_to_tensor(is_first), + Columns.OBS: self.convert_to_tensor(obs), + "is_first": self.convert_to_tensor(is_first), } # Explore or not. if explore: - outs = self.module.forward_exploration(batch) + outs = self.module.forward_exploration(to_module) else: - outs = self.module.forward_inference(batch) + outs = self.module.forward_inference(to_module) # Model outputs one-hot actions (if discrete). Convert to int actions # as well. - actions = outs[Columns.ACTIONS].numpy() + actions = convert_to_numpy(outs[Columns.ACTIONS]) if isinstance(self.env.single_action_space, gym.spaces.Discrete): actions = np.argmax(actions, axis=-1) - states = tree.map_structure( - lambda s: s.numpy(), outs[Columns.STATE_OUT] - ) + self._states = unbatch(convert_to_numpy(outs[Columns.STATE_OUT])) obs, rewards, terminateds, truncateds, infos = self.env.step(actions) ts += self.num_envs for i in range(self.num_envs): - s = {k: s[i] for k, s in states.items()} # The last entry in self.observations[i] is already the reset # obs of the new episode. if terminateds[i] or truncateds[i]: @@ -327,12 +360,7 @@ def _sample_timesteps( terminated=terminateds[i], truncated=truncateds[i], ) - self._states[i] = s - # Reset h-states to the model's initial ones b/c we are starting a - # new episode. - for k, v in self.module.get_initial_state().items(): - states[k][i] = v.numpy() - is_first[i] = True + self._states[i] = None done_episodes_to_return.append(self._episodes[i]) # Create a new episode object. self._episodes[i] = SingleAgentEpisode(observations=[obs[i]]) @@ -342,9 +370,6 @@ def _sample_timesteps( action=actions[i], reward=rewards[i], ) - is_first[i] = False - - self._states[i] = s # Return done episodes ... self._done_episodes_for_metrics.extend(done_episodes_to_return) @@ -356,9 +381,9 @@ def _sample_timesteps( for eps in ongoing_episodes: self._ongoing_episodes_for_metrics[eps.id_].append(eps) - self._ts_since_last_metrics += ts + self._increase_sampled_metrics(ts) - return done_episodes_to_return, ongoing_episodes + return done_episodes_to_return + ongoing_episodes def _sample_episodes( self, @@ -378,7 +403,7 @@ def _sample_episodes( # Multiply states n times according to our vector env batch size (num_envs). states = tree.map_structure( lambda s: np.repeat(s, self.num_envs, axis=0), - self.module.get_initial_state(), + convert_to_numpy(self.module.get_initial_state()), ) is_first = np.ones((self.num_envs,)) @@ -392,10 +417,10 @@ def _sample_episodes( else: batch = { Columns.STATE_IN: tree.map_structure( - lambda s: tf.convert_to_tensor(s), states + lambda s: self.convert_to_tensor(s), states ), - Columns.OBS: tf.convert_to_tensor(obs), - "is_first": tf.convert_to_tensor(is_first), + Columns.OBS: self.convert_to_tensor(obs), + "is_first": self.convert_to_tensor(is_first), } if explore: @@ -403,12 +428,10 @@ def _sample_episodes( else: outs = self.module.forward_inference(batch) - actions = outs[Columns.ACTIONS].numpy() + actions = convert_to_numpy(outs[Columns.ACTIONS]) if isinstance(self.env.single_action_space, gym.spaces.Discrete): actions = np.argmax(actions, axis=-1) - states = tree.map_structure( - lambda s: s.numpy(), outs[Columns.STATE_OUT] - ) + states = convert_to_numpy(outs[Columns.STATE_OUT]) obs, rewards, terminateds, truncateds, infos = self.env.step(actions) @@ -434,8 +457,10 @@ def _sample_episodes( # Reset h-states to the model's initial ones b/c we are starting a # new episode. - for k, v in self.module.get_initial_state().items(): - states[k][i] = v.numpy() + for k, v in convert_to_numpy( + self.module.get_initial_state() + ).items(): + states[k][i] = v is_first[i] = True episodes[i] = SingleAgentEpisode(observations=[obs[i]]) @@ -448,41 +473,50 @@ def _sample_episodes( is_first[i] = False self._done_episodes_for_metrics.extend(done_episodes_to_return) - self._ts_since_last_metrics += sum(len(eps) for eps in done_episodes_to_return) # If user calls sample(num_timesteps=..) after this, we must reset again # at the beginning. self._needs_initial_reset = True + ts = sum(map(len, done_episodes_to_return)) + self._increase_sampled_metrics(ts) + return done_episodes_to_return - # TODO (sven): Remove the requirement for EnvRunners/RolloutWorkers to have this - # API. Instead Algorithm should compile episode metrics itself via its local - # buffer. - def get_metrics(self) -> List[RolloutMetrics]: + def get_metrics(self) -> ResultDict: # Compute per-episode metrics (only on already completed episodes). - metrics = [] for eps in self._done_episodes_for_metrics: + assert eps.is_done + episode_length = len(eps) - episode_reward = eps.get_return() + episode_return = eps.get_return() + episode_duration_s = eps.get_duration_s() + # Don't forget about the already returned chunks of this episode. if eps.id_ in self._ongoing_episodes_for_metrics: for eps2 in self._ongoing_episodes_for_metrics[eps.id_]: episode_length += len(eps2) - episode_reward += eps2.get_return() + episode_return += eps2.get_return() del self._ongoing_episodes_for_metrics[eps.id_] - metrics.append( - RolloutMetrics( - episode_length=episode_length, - episode_reward=episode_reward, - ) + self._log_episode_metrics( + episode_length, episode_return, episode_duration_s ) + # Log num episodes counter for this iteration. + self.metrics.log_value( + NUM_EPISODES, + len(self._done_episodes_for_metrics), + reduce="sum", + # Reset internal data on `reduce()` call below (not a lifetime count). + clear_on_reduce=True, + ) + + # Now that we have logged everything, clear cache of done episodes. self._done_episodes_for_metrics.clear() - self._ts_since_last_metrics = 0 - return metrics + # Return reduced metrics. + return self.metrics.reduce() # TODO (sven): Remove the requirement for EnvRunners/RolloutWorkers to have this # API. Replace by proper state overriding via `EnvRunner.set_state()` @@ -503,6 +537,52 @@ def stop(self): # Close our env object via gymnasium's API. self.env.close() + def _increase_sampled_metrics(self, num_steps): + # Per sample cycle stats. + self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED, DEFAULT_AGENT_ID), + num_steps, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_MODULE_STEPS_SAMPLED, DEFAULT_MODULE_ID), + num_steps, + reduce="sum", + clear_on_reduce=True, + ) + # Lifetime stats. + self.metrics.log_value(NUM_ENV_STEPS_SAMPLED_LIFETIME, num_steps, reduce="sum") + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID), + num_steps, + reduce="sum", + ) + self.metrics.log_value( + (NUM_MODULE_STEPS_SAMPLED_LIFETIME, DEFAULT_MODULE_ID), + num_steps, + reduce="sum", + ) + return num_steps + + def _log_episode_metrics(self, length, ret, sec): + # Log general episode metrics. + # To mimick the old API stack behavior, we'll use `window` here for + # these particular stats (instead of the default EMA). + win = self.config.metrics_num_episodes_for_smoothing + self.metrics.log_value(EPISODE_LEN_MEAN, length, window=win) + self.metrics.log_value(EPISODE_RETURN_MEAN, ret, window=win) + self.metrics.log_value(EPISODE_DURATION_SEC_MEAN, sec, window=win) + + # For some metrics, log min/max as well. + self.metrics.log_value(EPISODE_LEN_MIN, length, reduce="min") + self.metrics.log_value(EPISODE_RETURN_MIN, ret, reduce="min") + self.metrics.log_value(EPISODE_LEN_MAX, length, reduce="max") + self.metrics.log_value(EPISODE_RETURN_MAX, ret, reduce="max") + class NormalizedImageEnv(gym.ObservationWrapper): def __init__(self, *args, **kwargs): diff --git a/rllib/algorithms/dreamerv3/utils/summaries.py b/rllib/algorithms/dreamerv3/utils/summaries.py index f78876c83fe76..dd36adbb31604 100644 --- a/rllib/algorithms/dreamerv3/utils/summaries.py +++ b/rllib/algorithms/dreamerv3/utils/summaries.py @@ -13,16 +13,16 @@ create_cartpole_dream_image, create_frozenlake_dream_image, ) +from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import ( + LEARNER_RESULTS, + REPLAY_BUFFER_RESULTS, +) from ray.rllib.utils.tf_utils import inverse_symlog - -def _summarize(*, results, data_to_summarize, keys_to_log, include_histograms=False): - for k in keys_to_log: - if data_to_summarize[k].shape == (): - results.update({k: data_to_summarize[k]}) - elif include_histograms: - results.update({k: data_to_summarize[k]}) +torch, _ = try_import_torch() def reconstruct_obs_from_h_and_z( @@ -39,10 +39,18 @@ def reconstruct_obs_from_h_and_z( # Note that the last h-state (T+1) is NOT used here as it's already part of # a new trajectory. # Use mean() of the Gaussian, no sample! -> No need to construct dist object here. - reconstructed_obs_distr_means_TxB = dreamer_model.world_model.decoder( - # Fold time rank. - h=np.reshape(h_t0_to_H, (T * B, -1)), - z=np.reshape(z_t0_to_H, (T * B,) + z_t0_to_H.shape[2:]), + device = next(iter(dreamer_model.world_model.decoder.parameters())).device + reconstructed_obs_distr_means_TxB = ( + dreamer_model.world_model.decoder( + # Fold time rank. + h=torch.from_numpy(h_t0_to_H).reshape((T * B, -1)).to(device), + z=torch.from_numpy(z_t0_to_H) + .reshape((T * B,) + z_t0_to_H.shape[2:]) + .to(device), + ) + .detach() + .cpu() + .numpy() ) # Unfold time rank again. reconstructed_obs_T_B = np.reshape( @@ -115,11 +123,12 @@ def report_dreamed_trajectory( def report_predicted_vs_sampled_obs( *, - results, + metrics, sample, batch_size_B, batch_length_T, symlog_obs: bool = True, + do_report: bool = True, ): """Summarizes sampled data (from the replay buffer) vs world-model predictions. @@ -133,127 +142,188 @@ def report_predicted_vs_sampled_obs( Continues: Compute MSE (sampled vs predicted). Args: - results: The results dict that was returned by - `LearnerGroup.update_from_batch()`. + metrics: The MetricsLogger object of the DreamerV3 algo. sample: The sampled data (dict) from the replay buffer. Already tf-tensor converted. batch_size_B: The batch size (B). This is the number of trajectories sampled from the buffer. batch_length_T: The batch length (T). This is the length of an individual trajectory sampled from the buffer. + do_report: Whether to actually log the report (default). If this is set to + False, this function serves as a clean-up on the given metrics, making sure + they do NOT contain anymore any (spacious) data relevant for producing + the report/videos. """ - predicted_observation_means_BxT = results[ - "WORLD_MODEL_fwd_out_obs_distribution_means_BxT" - ] + fwd_output_key = ( + LEARNER_RESULTS, + DEFAULT_MODULE_ID, + "WORLD_MODEL_fwd_out_obs_distribution_means_b0xT", + ) + # logged as a non-reduced item (still a list) + predicted_observation_means_single_example = metrics.peek( + fwd_output_key, default=[None] + )[-1] + metrics.delete(fwd_output_key, key_error=False) + + final_result_key = ( + f"WORLD_MODEL_sampled_vs_predicted_posterior_b0x{batch_length_T}_videos" + ) + if not do_report: + metrics.delete(final_result_key, key_error=False) + return + _report_obs( - results=results, + metrics=metrics, computed_float_obs_B_T_dims=np.reshape( - predicted_observation_means_BxT, - (batch_size_B, batch_length_T) + sample[Columns.OBS].shape[2:], + predicted_observation_means_single_example, + # WandB videos need to be channels first. + (1, batch_length_T) + sample[Columns.OBS].shape[2:], ), - sampled_obs_B_T_dims=sample[Columns.OBS], - descr_prefix="WORLD_MODEL", - descr_obs=f"predicted_posterior_T{batch_length_T}", + sampled_obs_B_T_dims=sample[Columns.OBS][0:1], + metrics_key=final_result_key, symlog_obs=symlog_obs, ) def report_dreamed_eval_trajectory_vs_samples( *, - results, - dream_data, + metrics, sample, burn_in_T, dreamed_T, dreamer_model, symlog_obs: bool = True, -): + do_report: bool = True, +) -> None: + """Logs dreamed observations, rewards, continues and compares them vs sampled data. + + For obs, we'll try to create videos (side-by-side comparison) of the dreamed, + recreated-from-prior obs vs the sampled ones (over dreamed_T timesteps). + + Args: + metrics: The MetricsLogger object of the DreamerV3 algo. + sample: The sampled data (dict) from the replay buffer. Already tf-tensor + converted. + burn_in_T: The number of burn-in timesteps (these will be skipped over in the + reported video comparisons and MSEs). + dreamed_T: The number of timesteps to produce dreamed data for. + dreamer_model: The DreamerModel to use to create observation vectors/images + from dreamed h- and (prior) z-states. + symlog_obs: Whether to inverse-symlog the computed observations or not. Set this + to True for environments, in which we should symlog the observations. + do_report: Whether to actually log the report (default). If this is set to + False, this function serves as a clean-up on the given metrics, making sure + they do NOT contain anymore any (spacious) data relevant for producing + the report/videos. + """ + dream_data = metrics.peek( + LEARNER_RESULTS, + DEFAULT_MODULE_ID, + "dream_data", + default={}, + ) + metrics.delete(LEARNER_RESULTS, DEFAULT_MODULE_ID, "dream_data", key_error=False) + + final_result_key_obs = f"EVALUATION_sampled_vs_dreamed_prior_H{dreamed_T}_obs" + final_result_key_rew = ( + f"EVALUATION_sampled_vs_dreamed_prior_H{dreamed_T}_rewards_MSE" + ) + final_result_key_cont = ( + f"EVALUATION_sampled_vs_dreamed_prior_H{dreamed_T}_continues_MSE" + ) + if not do_report: + metrics.delete(final_result_key_obs, key_error=False) + metrics.delete(final_result_key_rew, key_error=False) + metrics.delete(final_result_key_cont, key_error=False) + return + # Obs MSE. - dreamed_obs_T_B = reconstruct_obs_from_h_and_z( - h_t0_to_H=dream_data["h_states_t0_to_H_BxT"], - z_t0_to_H=dream_data["z_states_prior_t0_to_H_BxT"], + dreamed_obs_H_B = reconstruct_obs_from_h_and_z( + h_t0_to_H=dream_data["h_states_t0_to_H_Bx1"][0], + z_t0_to_H=dream_data["z_states_prior_t0_to_H_Bx1"][0], dreamer_model=dreamer_model, obs_dims_shape=sample[Columns.OBS].shape[2:], ) - t0 = burn_in_T - 1 + t0 = burn_in_T tH = t0 + dreamed_T # Observation MSE and - if applicable - images comparisons. - mse_sampled_vs_dreamed_obs = _report_obs( - results=results, - # Have to transpose b/c dreamed data is time-major. - computed_float_obs_B_T_dims=np.transpose( - dreamed_obs_T_B, - axes=[1, 0] + list(range(2, len(dreamed_obs_T_B.shape))), - ), - sampled_obs_B_T_dims=sample[Columns.OBS][:, t0 : tH + 1], - descr_prefix="EVALUATION", - descr_obs=f"dreamed_prior_H{dreamed_T}", + _report_obs( + metrics=metrics, + # WandB videos need to be 5D (B, L, c, h, w) -> transpose/swap H and B axes. + computed_float_obs_B_T_dims=np.swapaxes(dreamed_obs_H_B, 0, 1)[ + 0:1 + ], # for now: only B=1 + sampled_obs_B_T_dims=sample[Columns.OBS][0:1, t0:tH], + metrics_key=final_result_key_obs, symlog_obs=symlog_obs, ) # Reward MSE. _report_rewards( - results=results, - computed_rewards=dream_data["rewards_dreamed_t0_to_H_BxT"], - sampled_rewards=sample[Columns.REWARDS][:, t0 : tH + 1], - descr_prefix="EVALUATION", - descr_reward=f"dreamed_prior_H{dreamed_T}", + metrics=metrics, + computed_rewards=dream_data["rewards_dreamed_t0_to_H_Bx1"][0], + sampled_rewards=sample[Columns.REWARDS][:, t0:tH], + metrics_key=final_result_key_rew, ) # Continues MSE. _report_continues( - results=results, - computed_continues=dream_data["continues_dreamed_t0_to_H_BxT"], - sampled_continues=(1.0 - sample["is_terminated"])[:, t0 : tH + 1], - descr_prefix="EVALUATION", - descr_cont=f"dreamed_prior_H{dreamed_T}", + metrics=metrics, + computed_continues=dream_data["continues_dreamed_t0_to_H_Bx1"][0], + sampled_continues=(1.0 - sample["is_terminated"])[:, t0:tH], + metrics_key=final_result_key_cont, ) - return mse_sampled_vs_dreamed_obs -def report_sampling_and_replay_buffer(*, replay_buffer): +def report_sampling_and_replay_buffer(*, metrics, replay_buffer): episodes_in_buffer = replay_buffer.get_num_episodes() ts_in_buffer = replay_buffer.get_num_timesteps() replayed_steps = replay_buffer.get_sampled_timesteps() added_steps = replay_buffer.get_added_timesteps() # Summarize buffer, sampling, and train ratio stats. - return { - "BUFFER_capacity": replay_buffer.capacity, - "BUFFER_size_num_episodes": episodes_in_buffer, - "BUFFER_size_timesteps": ts_in_buffer, - "BUFFER_replayed_steps": replayed_steps, - "BUFFER_added_steps": added_steps, - } + metrics.log_dict( + { + "capacity": replay_buffer.capacity, + "size_num_episodes": episodes_in_buffer, + "size_timesteps": ts_in_buffer, + "replayed_steps": replayed_steps, + "added_steps": added_steps, + }, + key=REPLAY_BUFFER_RESULTS, + window=1, + ) # window=1 b/c these are current (total count/state) values. def _report_obs( *, - results, + metrics, computed_float_obs_B_T_dims, sampled_obs_B_T_dims, - descr_prefix=None, - descr_obs, + metrics_key, symlog_obs, ): """Summarizes computed- vs sampled observations: MSE and (if applicable) images. Args: + metrics: The MetricsLogger object of the DreamerV3 algo. computed_float_obs_B_T_dims: Computed float observations (not clipped, not cast'd). Shape=(B, T, [dims ...]). sampled_obs_B_T_dims: Sampled observations (as-is from the environment, meaning this could be uint8, 0-255 clipped images). Shape=(B, T, [dims ...]). - B: The batch size B (see shapes of `computed_float_obs_B_T_dims` and - `sampled_obs_B_T_dims` above). - T: The batch length T (see shapes of `computed_float_obs_B_T_dims` and - `sampled_obs_B_T_dims` above). - descr: A string used to describe the computed data to be used in the TB - summaries. + metrics_key: The metrics key (or key sequence) under which to log ths resulting + video sequence. + symlog_obs: Whether to inverse-symlog the computed observations or not. Set this + to True for environments, in which we should symlog the observations. + """ # Videos: Create summary, comparing computed images with actual sampled ones. # 4=[B, T, w, h] grayscale image; 5=[B, T, w, h, C] RGB image. if len(sampled_obs_B_T_dims.shape) in [4, 5]: - descr_prefix = (descr_prefix + "_") if descr_prefix else "" + # WandB videos need to be channels first. + transpose_axes = ( + (0, 1, 4, 2, 3) if len(sampled_obs_B_T_dims.shape) == 5 else (0, 3, 1, 2) + ) if symlog_obs: computed_float_obs_B_T_dims = inverse_symlog(computed_float_obs_B_T_dims) @@ -265,68 +335,63 @@ def _report_obs( sampled_obs_B_T_dims = np.clip(sampled_obs_B_T_dims, 0.0, 255.0).astype( np.uint8 ) + sampled_obs_B_T_dims = np.transpose(sampled_obs_B_T_dims, transpose_axes) computed_images = np.clip(computed_float_obs_B_T_dims, 0.0, 255.0).astype( np.uint8 ) + computed_images = np.transpose(computed_images, transpose_axes) # Concat sampled and computed images along the height axis (3) such that # real images show below respective predicted ones. # (B, T, C, h, w) sampled_vs_computed_images = np.concatenate( [computed_images, sampled_obs_B_T_dims], - axis=3, + axis=-1, # concat on width axis (looks nicer) ) # Add grayscale dim, if necessary. if len(sampled_obs_B_T_dims.shape) == 2 + 2: sampled_vs_computed_images = np.expand_dims(sampled_vs_computed_images, -1) - results.update( - {f"{descr_prefix}sampled_vs_{descr_obs}_videos": sampled_vs_computed_images} + metrics.log_value( + metrics_key, + sampled_vs_computed_images, + reduce=None, # No reduction, we want the obs tensor to stay in-tact. + window=1, ) - # return mse_sampled_vs_computed_obs - def _report_rewards( *, - results, + metrics, computed_rewards, sampled_rewards, - descr_prefix=None, - descr_reward, + metrics_key, ): - descr_prefix = (descr_prefix + "_") if descr_prefix else "" mse_sampled_vs_computed_rewards = np.mean( np.square(computed_rewards - sampled_rewards) ) mse_sampled_vs_computed_rewards = np.mean(mse_sampled_vs_computed_rewards) - results.update( - { - f"{descr_prefix}sampled_vs_{descr_reward}_rewards_mse": ( - mse_sampled_vs_computed_rewards - ), - } + metrics.log_value( + metrics_key, + mse_sampled_vs_computed_rewards, + window=1, ) def _report_continues( *, - results, + metrics, computed_continues, sampled_continues, - descr_prefix=None, - descr_cont, + metrics_key, ): - descr_prefix = (descr_prefix + "_") if descr_prefix else "" # Continue MSE. mse_sampled_vs_computed_continues = np.mean( np.square( computed_continues - sampled_continues.astype(computed_continues.dtype) ) ) - results.update( - { - f"{descr_prefix}sampled_vs_{descr_cont}_continues_mse": ( - mse_sampled_vs_computed_continues - ), - } + metrics.log_value( + metrics_key, + mse_sampled_vs_computed_continues, + window=1, ) diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index e1d414c3179b6..b506dcf546aa1 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -37,6 +37,7 @@ NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED, NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_ENV_STEPS_TRAINED, NUM_MODULE_STEPS_TRAINED, NUM_SYNCH_WORKER_WEIGHTS, @@ -720,6 +721,7 @@ def training_step(self) -> ResultDict: if self.config.enable_rl_module_and_learner: train_results = self.learn_on_processed_samples() module_ids_to_update = set(train_results.keys()) - {ALL_MODULES} + # TODO (sven): Move to Learner._after_gradient_based_update(). additional_results = self.learner_group.additional_update( module_ids_to_update=module_ids_to_update, timestep=self._counters[ @@ -959,6 +961,11 @@ def learn_on_processed_samples(self) -> ResultDict: for batch in batches: result = self.learner_group.update_from_batch( batch=batch, + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME) + ), + }, async_update=async_update, num_iters=self.config.num_sgd_iter, minibatch_size=self.config.minibatch_size, diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index ce491f9ca09d7..60dfe4b6eed6a 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -479,6 +479,11 @@ def _training_step_new_api_stack(self) -> ResultDict: with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)): learner_results = self.learner_group.update_from_episodes( episodes=episodes, + timesteps={ + NUM_ENV_STEPS_SAMPLED_LIFETIME: ( + self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME) + ), + }, minibatch_size=( self.config.mini_batch_size_per_learner or self.config.sgd_minibatch_size @@ -515,7 +520,6 @@ def _training_step_new_api_stack(self) -> ResultDict: # Sync weights from learner_group to all rollout workers. from_worker_or_learner_group=self.learner_group, policies=modules_to_update, - global_vars=None, inference_only=True, ) else: @@ -542,7 +546,8 @@ def _training_step_new_api_stack(self) -> ResultDict: ) kl_dict[mid] = kl - # triggers a special update method on RLOptimizer to update the KL values. + # TODO (sven): Move to Learner._after_gradient_based_update(). + # Triggers a special update method on RLOptimizer to update the KL values. additional_results = self.learner_group.additional_update( module_ids_to_update=modules_to_update, sampled_kl_values=kl_dict, diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_env_runner.py b/rllib/algorithms/ppo/tests/test_ppo_with_env_runner.py index f3defd5f7520e..f7c89f167f8bd 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_env_runner.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_env_runner.py @@ -7,9 +7,7 @@ ) from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.core import DEFAULT_MODULE_ID -from ray.rllib.core.learner.learner import ( - LEARNER_RESULTS_CURR_LR_KEY, -) +from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER, LR_KEY from ray.rllib.utils.metrics import LEARNER_RESULTS from ray.rllib.utils.test_utils import ( @@ -44,7 +42,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): # Learning rate should decrease by 0.0001/4 per iteration. check( - stats[LEARNER_RESULTS_CURR_LR_KEY], + stats[DEFAULT_OPTIMIZER + "_" + LR_KEY], 0.0000075 if algorithm.iteration == 1 else 0.000005, ) # Compare reported curr lr vs the actual lr found in the optimizer object. @@ -54,7 +52,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): if algorithm.config.framework_str == "torch" else optim.lr ) - check(stats[LEARNER_RESULTS_CURR_LR_KEY], actual_optimizer_lr) + check(stats[DEFAULT_OPTIMIZER + "_" + LR_KEY], actual_optimizer_lr) class TestPPO(unittest.TestCase): diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 7c89be47b1898..4f9b320418297 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -10,9 +10,7 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.algorithms.ppo.tests.test_ppo import PENDULUM_FAKE_BATCH from ray.rllib.core import DEFAULT_MODULE_ID -from ray.rllib.core.learner.learner import ( - LEARNER_RESULTS_CURR_LR_KEY, -) +from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER, LR_KEY from ray.rllib.evaluation.postprocessing import ( compute_gae_for_sample_batch, ) @@ -50,7 +48,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): # Learning rate should decrease by 0.0001/4 per iteration. check( - stats[LEARNER_RESULTS_CURR_LR_KEY], + stats[DEFAULT_OPTIMIZER + "_" + LR_KEY], 0.0000075 if algorithm.iteration == 1 else 0.000005, ) # Compare reported curr lr vs the actual lr found in the optimizer object. @@ -60,7 +58,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): if algorithm.config.framework_str == "torch" else optim.lr ) - check(stats[LEARNER_RESULTS_CURR_LR_KEY], actual_optimizer_lr) + check(stats[DEFAULT_OPTIMIZER + "_" + LR_KEY], actual_optimizer_lr) class TestPPO(unittest.TestCase): diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index 887be4f2bf92a..b71877fdf8928 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -43,6 +43,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.metrics import ( ALL_MODULES, + NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_ENV_STEPS_TRAINED, NUM_MODULE_STEPS_TRAINED, ) @@ -85,7 +86,7 @@ ENTROPY_KEY = "entropy" # Additional update keys -LEARNER_RESULTS_CURR_LR_KEY = "curr_lr" +LR_KEY = "learning_rate" @dataclass @@ -287,6 +288,10 @@ def __init__( # and return the resulting (reduced) dict. self.metrics = MetricsLogger() + # TODO (sven): Do we really need this API? It seems like LearnerGroup constructs + # all Learner workers and then immediately builds them any ways? Seems to make + # thing more complicated. Unless there is a reason related to Train worker group + # setup. @OverrideToImplementCustomLogic_CallToSuperRecommended def build(self) -> None: """Builds the Learner. @@ -528,7 +533,6 @@ def postprocess_gradients_for_module( module_id: ModuleID, config: Optional["AlgorithmConfig"] = None, module_gradients_dict: ParamDict, - hps=None, ) -> ParamDict: """Applies postprocessing operations on the gradients of the given module. @@ -547,13 +551,6 @@ def postprocess_gradients_for_module( A dictionary with the updated gradients and the exact same (flat) structure as the incoming `module_gradients_dict` arg. """ - if hps is not None: - deprecation_warning( - old="Learner.postprocess_gradients_for_module(.., hps=..)", - help="Deprecated argument. Use `config` (AlgorithmConfig) instead.", - error=True, - ) - postprocessed_grads = {} if config.grad_clip is None: @@ -1029,7 +1026,6 @@ def additional_update_for_module( module_id: ModuleID, config: Optional["AlgorithmConfig"] = None, timestep: int, - hps=None, **kwargs, ) -> None: """Apply additional non-gradient based updates for a single module. @@ -1045,36 +1041,14 @@ def additional_update_for_module( Returns: A dictionary of results from the update """ - if hps is not None: - deprecation_warning( - old="Learner.additional_update_for_module(.., hps=..)", - help="Deprecated argument. Use `config` (AlgorithmConfig) instead.", - error=True, - ) - - # Only cover the optimizer mapped to this particular module. - for optimizer_name, optimizer in self.get_optimizers_for_module(module_id): - # Only update this optimizer's lr, if a scheduler has been registered - # along with it. - if optimizer in self._optimizer_lr_schedules: - new_lr = self._optimizer_lr_schedules[optimizer].update( - timestep=timestep - ) - self._set_optimizer_lr(optimizer, lr=new_lr) - - # Make sure our returned results differentiate by optimizer name - # (if not the default name). - stats_name = LEARNER_RESULTS_CURR_LR_KEY - if optimizer_name != DEFAULT_OPTIMIZER: - stats_name += "_" + optimizer_name - self.metrics.log_value( - key=(module_id, stats_name), value=new_lr, window=1 - ) + pass def update_from_batch( self, batch: MultiAgentBatch, *, + # TODO (sven): Make this a more formal structure with its own type. + timesteps: Optional[Dict[str, Any]] = None, # TODO (sven): Deprecate these in favor of config attributes for only those # algos that actually need (and know how) to do minibatching. minibatch_size: Optional[int] = None, @@ -1090,6 +1064,9 @@ def update_from_batch( Args: batch: A batch of training data to update from. + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. minibatch_size: The size of the minibatch to use for each update. num_iters: The number of complete passes over all the sub-batches in the input multi-agent batch. @@ -1113,7 +1090,7 @@ def update_from_batch( ) return self._update_from_batch_or_episodes( batch=batch, - episodes=None, + timesteps=timesteps, minibatch_size=minibatch_size, num_iters=num_iters, ) @@ -1122,6 +1099,8 @@ def update_from_episodes( self, episodes: List[EpisodeType], *, + # TODO (sven): Make this a more formal structure with its own type. + timesteps: Optional[Dict[str, Any]] = None, # TODO (sven): Deprecate these in favor of config attributes for only those # algos that actually need (and know how) to do minibatching. minibatch_size: Optional[int] = None, @@ -1138,6 +1117,9 @@ def update_from_episodes( Args: episodes: An list of episode objects to update from. + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. minibatch_size: The size of the minibatch to use for each update. num_iters: The number of complete passes over all the sub-batches in the input multi-agent batch. @@ -1167,8 +1149,8 @@ def update_from_episodes( error=True, ) return self._update_from_batch_or_episodes( - batch=None, episodes=episodes, + timesteps=timesteps, minibatch_size=minibatch_size, num_iters=num_iters, min_total_mini_batches=min_total_mini_batches, @@ -1270,17 +1252,16 @@ def _update_from_batch_or_episodes( # as well for simplicity. batch: Optional[MultiAgentBatch] = None, episodes: Optional[List[EpisodeType]] = None, + # TODO (sven): Make this a more formal structure with its own type. + timesteps: Optional[Dict[str, Any]] = None, # TODO (sven): Deprecate these in favor of config attributes for only those # algos that actually need (and know how) to do minibatching. minibatch_size: Optional[int] = None, num_iters: int = 1, min_total_mini_batches: int = 0, ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: - self._check_is_built() - if num_iters < 1: - # We must do at least one pass on the batch for training. - raise ValueError("`num_iters` must be >= 1") + self._check_is_built() # Call the learner connector. if self._learner_connector is not None and episodes is not None: @@ -1310,6 +1291,7 @@ def _update_from_batch_or_episodes( f"Found IDs: {unknown_module_ids}" ) + # TODO: Move this into LearnerConnector pipeline? # Filter out those RLModules from the final train batch that should not be # updated. for module_id in list(batch.policy_batches.keys()): @@ -1326,14 +1308,10 @@ def _update_from_batch_or_episodes( { (ALL_MODULES, NUM_ENV_STEPS_TRAINED): batch.env_steps(), (ALL_MODULES, NUM_MODULE_STEPS_TRAINED): batch.agent_steps(), - }, - reduce="sum", - clear_on_reduce=True, - ) - self.metrics.log_dict( - { - (mid, NUM_MODULE_STEPS_TRAINED): len(b) - for mid, b in batch.policy_batches.items() + **{ + (mid, NUM_MODULE_STEPS_TRAINED): len(b) + for mid, b in batch.policy_batches.items() + }, }, reduce="sum", clear_on_reduce=True, @@ -1388,11 +1366,48 @@ def _update_from_batch_or_episodes( self._set_slicing_by_batch_id(batch, value=False) + # Call `_after_gradient_based_update` to allow for non-gradient based + # cleanups-, logging-, and update logic to happen. + self._after_gradient_based_update(timesteps) + + # Reduce results across all minibatch update steps. + return self.metrics.reduce() + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def _after_gradient_based_update(self, timesteps: Dict[str, Any]) -> None: + """Called after gradient-based updates are completed. + + Should be overridden to implement custom cleanup-, logging-, or non-gradient- + based Learner/RLModule update logic after(!) gradient-based updates have been + completed. + + Args: + timesteps: Timesteps dict, which must have the key + `NUM_ENV_STEPS_SAMPLED_LIFETIME`. + # TODO (sven): Make this a more formal structure with its own type. + """ + timesteps = timesteps or {} + + # Only update this optimizer's lr, if a scheduler has been registered + # along with it. + for module_id, optimizer_names in self._module_optimizers.items(): + for optimizer_name in optimizer_names: + optimizer = self._named_optimizers[optimizer_name] + lr_schedule = self._optimizer_lr_schedules.get(optimizer) + if lr_schedule is None: + continue + new_lr = lr_schedule.update( + timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0) + ) + self._set_optimizer_lr(optimizer, lr=new_lr) + # Log all current learning rates of all our optimizers (registered under the # different ModuleIDs). self.metrics.log_dict( { - (mid, f"{full_name[len(mid) + 1 :]}_lr"): convert_to_numpy( + # Cut out the module ID from the beginning since it's already part of + # the key sequence: (ModuleID, "[optim name]_lr"). + (mid, f"{full_name[len(mid) + 1:]}_{LR_KEY}"): convert_to_numpy( self._get_optimizer_lr(self._named_optimizers[full_name]) ) for mid, full_names in self._module_optimizers.items() @@ -1401,9 +1416,6 @@ def _update_from_batch_or_episodes( window=1, ) - # Reduce results across all minibatch update steps. - return self.metrics.reduce() - def _set_slicing_by_batch_id( self, batch: MultiAgentBatch, *, value: bool ) -> MultiAgentBatch: @@ -1456,7 +1468,6 @@ def _save_optimizers(self, path: Union[str, pathlib.Path]) -> None: Args: path: The path to the directory to save the state to. - """ pass @@ -1465,7 +1476,6 @@ def _load_optimizers(self, path: Union[str, pathlib.Path]) -> None: Args: path: The path to the directory to load the state from. - """ pass @@ -1491,7 +1501,6 @@ def save_state(self, path: Union[str, pathlib.Path]) -> None: Args: path: The path to the directory to save the state to. - """ self._check_is_built() path = pathlib.Path(path) diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py index acc42ef83deac..5496cb0ce5ec1 100644 --- a/rllib/core/learner/learner_group.py +++ b/rllib/core/learner/learner_group.py @@ -220,6 +220,7 @@ def update_from_batch( self, batch: MultiAgentBatch, *, + timesteps: Optional[Dict[str, Any]] = None, async_update: bool = False, # TODO (sven): Deprecate the following args. They should be extracted from # self.config of those specific algorithms that actually require these @@ -266,7 +267,7 @@ def update_from_batch( ) return self._update( batch=batch, - episodes=None, + timesteps=timesteps, async_update=async_update, minibatch_size=minibatch_size, num_iters=num_iters, @@ -276,6 +277,7 @@ def update_from_episodes( self, episodes: List[EpisodeType], *, + timesteps: Optional[Dict[str, Any]] = None, async_update: bool = False, # TODO (sven): Deprecate the following args. They should be extracted from # self.config of those specific algorithms that actually require these @@ -322,8 +324,8 @@ def update_from_episodes( ) return self._update( - batch=None, episodes=episodes, + timesteps=timesteps, async_update=async_update, minibatch_size=minibatch_size, num_iters=num_iters, @@ -334,6 +336,7 @@ def _update( *, batch: Optional[MultiAgentBatch] = None, episodes: Optional[List[EpisodeType]] = None, + timesteps: Optional[Dict[str, Any]] = None, async_update: bool = False, minibatch_size: Optional[int] = None, num_iters: int = 1, @@ -342,22 +345,25 @@ def _update( # Define function to be called on all Learner actors (or the local learner). def _learner_update( learner: Learner, - batch_shard=None, - episodes_shard=None, - min_total_mini_batches=0, + _batch_shard=None, + _episodes_shard=None, + _timesteps=None, + _min_total_mini_batches=0, ): - if batch_shard is not None: + if _batch_shard is not None: return learner.update_from_batch( - batch=batch_shard, + batch=_batch_shard, + timesteps=_timesteps, minibatch_size=minibatch_size, num_iters=num_iters, ) else: return learner.update_from_episodes( - episodes=episodes_shard, + episodes=_episodes_shard, + timesteps=_timesteps, minibatch_size=minibatch_size, num_iters=num_iters, - min_total_mini_batches=min_total_mini_batches, + min_total_mini_batches=_min_total_mini_batches, ) # Local Learner worker: Don't shard batch/episodes, just run data as-is through @@ -372,8 +378,9 @@ def _learner_update( results = [ _learner_update( learner=self._learner, - batch_shard=batch, - episodes_shard=episodes, + _batch_shard=batch, + _episodes_shard=episodes, + _timesteps=timesteps, ) ] # One or more remote Learners: Shard batch/episodes into equal pieces (roughly @@ -387,7 +394,9 @@ def _learner_update( # "lockstep"), the `ShardBatchIterator` should not be used. if episodes is None: partials = [ - partial(_learner_update, batch_shard=batch_shard) + partial( + _learner_update, _batch_shard=batch_shard, _timesteps=timesteps + ) for batch_shard in ShardBatchIterator(batch, len(self._workers)) ] # Single- or MultiAgentEpisodes: Shard into equal pieces (only roughly equal @@ -424,8 +433,9 @@ def _learner_update( partials = [ partial( _learner_update, - episodes_shard=eps_shard, - min_total_mini_batches=min_total_mini_batches, + _episodes_shard=eps_shard, + _timesteps=timesteps, + _min_total_mini_batches=min_total_mini_batches, ) for eps_shard in eps_shards ] diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py index cacce6f90b063..6dc586c12077d 100644 --- a/rllib/core/learner/torch/torch_learner.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -136,7 +136,7 @@ def compute_gradients( self, loss_per_module: Dict[ModuleID, TensorType], **kwargs ) -> ParamDict: for optim in self._optimizer_parameters: - # set_to_none is a faster way to zero out the gradients + # `set_to_none=True` is a faster way to zero out the gradients. optim.zero_grad(set_to_none=True) loss_per_module[ALL_MODULES].backward() grads = {pid: p.grad for pid, p in self._params.items()} diff --git a/rllib/core/models/torch/heads.py b/rllib/core/models/torch/heads.py index 95f61e58c4fd4..d634ff2ef4c84 100644 --- a/rllib/core/models/torch/heads.py +++ b/rllib/core/models/torch/heads.py @@ -218,7 +218,7 @@ def __init__(self, config: CNNTransposeHeadConfig) -> None: if initial_dense_weights_initializer: initial_dense_weights_initializer( self.initial_dense.weight, - **config.initial_dense_initializer_config or {}, + **config.initial_dense_weights_initializer_config or {}, ) # Initialized dense layer bais, if necessary. if initial_dense_bias_initializer: diff --git a/rllib/core/models/torch/primitives.py b/rllib/core/models/torch/primitives.py index a8b8d9b27fd28..0be0937302795 100644 --- a/rllib/core/models/torch/primitives.py +++ b/rllib/core/models/torch/primitives.py @@ -161,7 +161,8 @@ def __init__( # Insert a layer normalization in between layer's output and # the activation. if hidden_layer_use_layernorm: - layers.append(nn.LayerNorm(dims[i + 1])) + # We use an epsilon of 0.001 here to mimick the Tf default behavior. + layers.append(nn.LayerNorm(dims[i + 1], eps=0.001)) # Add the activation function. if hidden_activation is not None: layers.append(hidden_activation()) @@ -294,7 +295,8 @@ def __init__( # Layernorm. if cnn_use_layernorm: - layers.append(nn.LayerNorm((out_depth, out_size[0], out_size[1]))) + # We use an epsilon of 0.001 here to mimick the Tf default behavior. + layers.append(LayerNorm1D(out_depth, eps=0.001)) # Activation. if cnn_activation is not None: layers.append(cnn_activation()) @@ -446,7 +448,7 @@ def __init__( layers.append(layer) # Layernorm (never for final layer). if cnn_transpose_use_layernorm and not is_final_layer: - layers.append(nn.LayerNorm((out_depth, out_size[0], out_size[1]))) + layers.append(LayerNorm1D(out_depth, eps=0.001)) # Last layer is never activated (regardless of config). if cnn_transpose_activation is not None and not is_final_layer: layers.append(cnn_transpose_activation()) @@ -464,3 +466,20 @@ def forward(self, inputs): out = inputs.permute(0, 3, 1, 2) out = self.cnn_transpose(out.type(self.expected_input_dtype)) return out.permute(0, 2, 3, 1) + + +class LayerNorm1D(nn.Module): + def __init__(self, num_features, **kwargs): + super().__init__() + self.layer_norm = nn.LayerNorm(num_features, **kwargs) + + def forward(self, x): + # x shape: (B, dim, dim, channels). + batch_size, channels, h, w = x.size() + # Reshape to (batch_size * height * width, channels) for LayerNorm + x = x.permute(0, 2, 3, 1).reshape(-1, channels) + # Apply LayerNorm + x = self.layer_norm(x) + # Reshape back to (batch_size, dim, dim, channels) + x = x.reshape(batch_size, h, w, channels).permute(0, 3, 1, 2) + return x diff --git a/rllib/core/models/torch/utils.py b/rllib/core/models/torch/utils.py index f9da0adab31a1..1bdbdef016f4c 100644 --- a/rllib/core/models/torch/utils.py +++ b/rllib/core/models/torch/utils.py @@ -45,28 +45,33 @@ def __init__(self, width, height, stride_w, stride_h): self.stride_w = stride_w self.stride_h = stride_h - self.zeros = torch.zeros( - size=( - self.width * self.stride_w - (self.stride_w - 1), - self.height * self.stride_h - (self.stride_h - 1), + self.register_buffer( + "zeros", + torch.zeros( + size=( + self.width * self.stride_w - (self.stride_w - 1), + self.height * self.stride_h - (self.stride_h - 1), + ), + dtype=torch.float32, ), - dtype=torch.float32, ) + self.out_width, self.out_height = self.zeros.shape[0], self.zeros.shape[1] # Squeeze in batch and channel dims. self.zeros = self.zeros.unsqueeze(0).unsqueeze(0) - self.where_template = torch.zeros( + where_template = torch.zeros( (self.stride_w, self.stride_h), dtype=torch.float32 ) # Set upper/left corner to 1.0. - self.where_template[0][0] = 1.0 + where_template[0][0] = 1.0 # then tile across the entire (strided) image size. - self.where_template = self.where_template.repeat((self.height, self.width))[ + where_template = where_template.repeat((self.height, self.width))[ : -(self.stride_w - 1), : -(self.stride_h - 1) ] # Squeeze in batch and channel dims and convert to bool. - self.where_template = self.where_template.unsqueeze(0).unsqueeze(0).bool() + where_template = where_template.unsqueeze(0).unsqueeze(0).bool() + self.register_buffer("where_template", where_template) def forward(self, x): # Repeat incoming image stride(w/h) times to match the strided output template. diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 9cb4d2bda6c40..0ced418785521 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -110,7 +110,8 @@ def get_state(self, inference_only: bool = False) -> Mapping[str, Any]: @override(RLModule) def set_state(self, state_dict: Mapping[str, Any]) -> None: - self.load_state_dict(convert_to_torch_tensor(state_dict)) + state_dict = convert_to_torch_tensor(state_dict) + self.load_state_dict(state_dict) def _module_state_file_name(self) -> pathlib.Path: return pathlib.Path("module_state.pt") diff --git a/rllib/examples/catalogs/mobilenet_v2_encoder.py b/rllib/examples/catalogs/mobilenet_v2_encoder.py index beebdb79f773e..ca44215b8bef0 100644 --- a/rllib/examples/catalogs/mobilenet_v2_encoder.py +++ b/rllib/examples/catalogs/mobilenet_v2_encoder.py @@ -44,7 +44,10 @@ def _get_encoder_config( # Create a generic config with our enhanced Catalog ppo_config = ( PPOConfig() - .api_stack(enable_rl_module_and_learner=True) + .api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) .rl_module( rl_module_spec=SingleAgentRLModuleSpec( catalog_class=MobileNetEnhancedPPOCatalog diff --git a/rllib/examples/evaluation/custom_evaluation.py b/rllib/examples/evaluation/custom_evaluation.py index d396ffee04dfc..76aad3eccdf4d 100644 --- a/rllib/examples/evaluation/custom_evaluation.py +++ b/rllib/examples/evaluation/custom_evaluation.py @@ -149,6 +149,7 @@ def custom_eval_function( eval_results = algorithm.metrics.reduce( key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS) ) + # Alternatively, you could manually reduce over the n returned `env_runner_metrics` # dicts, but this would be much harder as you might not know, which metrics # to sum up, which ones to average over, etc.. diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 4b3582b9968b1..a42a1991453de 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -24,6 +24,7 @@ def synchronous_parallel_sample( max_env_steps: Optional[int] = None, concat: bool = True, sample_timeout_s: Optional[float] = 60.0, + random_actions: bool = False, _uses_new_env_runners: bool = False, _return_metrics: bool = False, ) -> Union[List[SampleBatchType], SampleBatchType, List[EpisodeType], EpisodeType]: @@ -81,6 +82,8 @@ def synchronous_parallel_sample( sample_batches_or_episodes = [] all_stats_dicts = [] + random_action_kwargs = {} if not random_actions else {"random_actions": True} + # Stop collecting batches as soon as one criterium is met. while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or ( max_agent_or_env_steps is not None @@ -89,16 +92,16 @@ def synchronous_parallel_sample( # No remote workers in the set -> Use local worker for collecting # samples. if worker_set.num_remote_workers() <= 0: - sampled_data = [worker_set.local_worker().sample()] + sampled_data = [worker_set.local_worker().sample(**random_action_kwargs)] if _return_metrics: stats_dicts = [worker_set.local_worker().get_metrics()] # Loop over remote workers' `sample()` method in parallel. else: sampled_data = worker_set.foreach_worker( ( - (lambda w: w.sample()) + (lambda w: w.sample(**random_action_kwargs)) if not _return_metrics - else (lambda w: (w.sample(), w.get_metrics())) + else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics())) ), local_worker=False, timeout_seconds=sample_timeout_s, diff --git a/rllib/tuned_examples/dreamerv3/atari_100k.py b/rllib/tuned_examples/dreamerv3/atari_100k.py index 23a46fcbf3e73..68bdc97451366 100644 --- a/rllib/tuned_examples/dreamerv3/atari_100k.py +++ b/rllib/tuned_examples/dreamerv3/atari_100k.py @@ -9,17 +9,27 @@ """ # Run with: -# python run_regression_tests.py --dir [this file] --env ALE/[gym ID e.g. Pong-v5] +# python [this script name].py --env ALE/[gym ID e.g. Pong-v5] -from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config +# To see all available options: +# python [this script name].py --help +from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config +from ray.rllib.utils.test_utils import add_rllib_example_script_args -# Number of GPUs to run on. -num_gpus = 1 +parser = add_rllib_example_script_args( + default_iters=1000000, + default_reward=20.0, + default_timesteps=1000000, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values toset up `config` below. +args = parser.parse_args() config = ( DreamerV3Config() .environment( + env=args.env, # [2]: "We follow the evaluation protocol of Machado et al. (2018) with 200M # environment steps, action repeat of 4, a time limit of 108,000 steps per # episode that correspond to 30 minutes of game play, no access to life @@ -34,31 +44,35 @@ "full_action_space": False, # Already done by MaxAndSkip wrapper: "action repeat" == 4. "frameskip": 1, - } - ) - .resources( - num_cpus_for_main_process=1, - ) - .learners( - num_learners=0 if num_gpus == 1 else num_gpus, - num_gpus_per_learner=1 if num_gpus else 0, + }, ) .env_runners( + num_env_runners=(args.num_env_runners or 0), # If we use >1 GPU and increase the batch size accordingly, we should also # increase the number of envs per worker. - num_envs_per_env_runner=(num_gpus or 1), - remote_worker_envs=True, + num_envs_per_env_runner=(args.num_gpus or 1), + remote_worker_envs=(args.num_gpus > 1), + ) + .learners( + num_learners=0 if args.num_gpus == 1 else args.num_gpus, + num_gpus_per_learner=1 if args.num_gpus else 0, ) .reporting( - metrics_num_episodes_for_smoothing=(num_gpus or 1), - report_images_and_videos=False, - report_dream_data=False, + metrics_num_episodes_for_smoothing=(args.num_gpus or 1), + report_images_and_videos=True, + report_dream_data=True, report_individual_batch_item_stats=False, ) # See Appendix A. .training( model_size="S", training_ratio=1024, - batch_size_B=16 * (num_gpus or 1), + batch_size_B=16 * (args.num_gpus or 1), ) ) + + +if __name__ == "__main__": + from ray.rllib.utils.test_utils import run_rllib_example_script_experiment + + run_rllib_example_script_experiment(config, args, stop={}, keep_config=True) diff --git a/rllib/tuned_examples/dreamerv3/atari_200M.py b/rllib/tuned_examples/dreamerv3/atari_200M.py index 2fb1e48b09294..2339d345d2f86 100644 --- a/rllib/tuned_examples/dreamerv3/atari_200M.py +++ b/rllib/tuned_examples/dreamerv3/atari_200M.py @@ -9,13 +9,22 @@ """ # Run with: -# python run_regression_tests.py --dir [this file] --env ALE/[gym ID e.g. Pong-v5] +# python [this script name].py --env ALE/[gym ID e.g. Pong-v5] -from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config +# To see all available options: +# python [this script name].py --help +from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config +from ray.rllib.utils.test_utils import add_rllib_example_script_args -# Number of GPUs to run on. -num_gpus = 1 +parser = add_rllib_example_script_args( + default_iters=1000000, + default_reward=20.0, + default_timesteps=1000000, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values toset up `config` below. +args = parser.parse_args() config = ( DreamerV3Config() @@ -23,19 +32,10 @@ # For each (parallelized) env, we should provide a CPU. Lower this number # if you don't have enough CPUs. num_cpus_for_main_process=8 - * (num_gpus or 1), - ) - .learners( - num_learners=0 if num_gpus == 1 else num_gpus, - num_gpus_per_learner=1 if num_gpus else 0, - ) - .env_runners( - # If we use >1 GPU and increase the batch size accordingly, we should also - # increase the number of envs per worker. - num_envs_per_env_runner=8 * (num_gpus or 1), - remote_worker_envs=True, + * (args.num_gpus or 1), ) .environment( + env=args.env, # [2]: "We follow the evaluation protocol of Machado et al. (2018) with 200M # environment steps, action repeat of 4, a time limit of 108,000 steps per # episode that correspond to 30 minutes of game play, no access to life @@ -50,10 +50,21 @@ "full_action_space": False, # Already done by MaxAndSkip wrapper: "action repeat" == 4. "frameskip": 1, - } + }, + ) + .env_runners( + num_env_runners=(args.num_env_runners or 0), + # If we use >1 GPU and increase the batch size accordingly, we should also + # increase the number of envs per worker. + num_envs_per_env_runner=8 * (args.num_gpus or 1), + remote_worker_envs=True, + ) + .learners( + num_learners=0 if args.num_gpus == 1 else args.num_gpus, + num_gpus_per_learner=1 if args.num_gpus else 0, ) .reporting( - metrics_num_episodes_for_smoothing=(num_gpus or 1), + metrics_num_episodes_for_smoothing=(args.num_gpus or 1), report_images_and_videos=False, report_dream_data=False, report_individual_batch_item_stats=False, @@ -62,6 +73,12 @@ .training( model_size="XL", training_ratio=64, - batch_size_B=16 * (num_gpus or 1), + batch_size_B=16 * (args.num_gpus or 1), ) ) + + +if __name__ == "__main__": + from ray.rllib.utils.test_utils import run_rllib_example_script_experiment + + run_rllib_example_script_experiment(config, args, keep_config=True) diff --git a/rllib/tuned_examples/dreamerv3/dm_control_suite_vision.py b/rllib/tuned_examples/dreamerv3/dm_control_suite_vision.py index b201900da5f6a..21c1a435a0345 100644 --- a/rllib/tuned_examples/dreamerv3/dm_control_suite_vision.py +++ b/rllib/tuned_examples/dreamerv3/dm_control_suite_vision.py @@ -7,30 +7,45 @@ D. Hafner, T. Lillicrap, M. Norouzi, J. Ba https://arxiv.org/pdf/2010.02193.pdf """ + # Run with: -# python run_regression_tests.py --dir [this file] --env DMC/[task]/[domain] -# e.g. --env=DMC/cartpole/swingup +# python [this script name].py --env DMC/[task]/[domain] (e.g. DMC/cartpole/swingup) -from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config +# To see all available options: +# python [this script name].py --help +from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config +from ray.rllib.utils.test_utils import add_rllib_example_script_args -# Number of GPUs to run on. -num_gpus = 1 +parser = add_rllib_example_script_args( + default_iters=1000000, + default_reward=800.0, + default_timesteps=1000000, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values toset up `config` below. +args = parser.parse_args() config = ( DreamerV3Config() # Use image observations. - .environment(env_config={"from_pixels": True}) - .resources( - num_cpus_for_main_process=1, + .environment( + env=args.env, + env_config={"from_pixels": True}, ) .learners( - num_learners=0 if num_gpus == 1 else num_gpus, - num_gpus_per_learner=1 if num_gpus else 0, + num_learners=0 if args.num_gpus == 1 else args.num_gpus, + num_gpus_per_learner=1 if args.num_gpus else 0, + ) + .env_runners( + num_env_runners=(args.num_env_runners or 0), + # If we use >1 GPU and increase the batch size accordingly, we should also + # increase the number of envs per worker. + num_envs_per_env_runner=4 * (args.num_gpus or 1), + remote_worker_envs=True, ) - .env_runners(num_envs_per_env_runner=4 * (num_gpus or 1), remote_worker_envs=True) .reporting( - metrics_num_episodes_for_smoothing=(num_gpus or 1), + metrics_num_episodes_for_smoothing=(args.num_gpus or 1), report_images_and_videos=False, report_dream_data=False, report_individual_batch_item_stats=False, @@ -39,6 +54,6 @@ .training( model_size="S", training_ratio=512, - batch_size_B=16 * (num_gpus or 1), + batch_size_B=16 * (args.num_gpus or 1), ) ) diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py index 39e087da9434c..764edabdb8a2c 100644 --- a/rllib/utils/metrics/__init__.py +++ b/rllib/utils/metrics/__init__.py @@ -1,6 +1,7 @@ # Algorithm ResultDict keys. EVALUATION_RESULTS = "evaluation" ENV_RUNNER_RESULTS = "env_runners" +REPLAY_BUFFER_RESULTS = "replay_buffer" LEARNER_RESULTS = "learners" FAULT_TOLERANCE_STATS = "fault_tolerance" TIMERS = "timers" diff --git a/rllib/utils/metrics/metrics_logger.py b/rllib/utils/metrics/metrics_logger.py index 8de81deb0038e..611d6f7d9b8cc 100644 --- a/rllib/utils/metrics/metrics_logger.py +++ b/rllib/utils/metrics/metrics_logger.py @@ -596,9 +596,9 @@ def deactivate_tensor_mode(self): def tensors_to_numpy(self, tensor_metrics): """Converts all previously logged and returned tensors back to numpy values.""" - for key, value in tensor_metrics.items(): + for key, values in tensor_metrics.items(): assert self._key_in_stats(key) - self._get_key(key).numpy(value) + self._get_key(key).set_to_numpy_values(values) @property def tensor_mode(self): @@ -727,6 +727,19 @@ def set_value( clear_on_reduce=clear_on_reduce, ) + def delete(self, *key: Tuple[str], key_error: bool = True) -> None: + """Deletes th egiven `key` from this metrics logger's stats. + + Args: + key: The key or key sequence (for nested location within self.stats), + to delete from this MetricsLogger's stats. + key_error: Whether to throw a KeyError if `key` cannot be found in `self`. + + Raises: + KeyError: If `key` cannot be found in `self` AND `key_error` is True. + """ + self._del_key(key, key_error) + def reduce( self, key: Optional[Union[str, Tuple[str]]] = None, @@ -894,6 +907,19 @@ def _set_key(self, flat_key, stats): _dict[key] = {} _dict = _dict[key] + def _del_key(self, flat_key, key_error=False): + flat_key = force_tuple(tree.flatten(flat_key)) + _dict = self.stats + try: + for i, key in enumerate(flat_key): + if i == len(flat_key) - 1: + del _dict[key] + return + _dict = _dict[key] + except KeyError as e: + if key_error: + raise e + @Deprecated(new="MetricsLogger.merge_and_log_n_dicts()", error=True) def log_n_dicts(self, *args, **kwargs): pass diff --git a/rllib/utils/metrics/stats.py b/rllib/utils/metrics/stats.py index eec5845fd1a9d..5a7358fde45e5 100644 --- a/rllib/utils/metrics/stats.py +++ b/rllib/utils/metrics/stats.py @@ -472,18 +472,19 @@ def merge_in_parallel(self, *others: "Stats") -> None: self.values = list(reversed(new_values)) - def numpy(self, value: Any = None) -> "Stats": - """Converts all of self's internal values to numpy (if a tensor).""" - if value is not None: - if self._reduce_method is None: - assert isinstance(value, list) and len(self.values) >= len(value) - self.values = convert_to_numpy(value) - else: - assert len(self.values) > 0 - self.values = [convert_to_numpy(value)] + def set_to_numpy_values(self, values) -> None: + """Converts `self.values` from tensors to actual numpy values. + + Args: + values: The (numpy) values to set `self.values` to. + """ + numpy_values = convert_to_numpy(values) + if self._reduce_method is None: + assert isinstance(values, list) and len(self.values) >= len(values) + self.values = numpy_values else: - self.values = convert_to_numpy(self.values) - return self + assert len(self.values) > 0 + self.values = [numpy_values] def __len__(self) -> int: """Returns the length of the internal values list.""" @@ -613,8 +614,21 @@ def _reduced_values(self, values=None, window=None) -> Tuple[Any, Any]: reduce_in = reduce_in.float() reduced = reduce_meth(reduce_in) elif tf and tf.is_tensor(values[0]): - reduce_meth = getattr(tf, "reduce_" + self._reduce_method) - reduced = reduce_meth(values) + # TODO (sven): Currently, tensor metrics only work with window=1. + # We might want o enforce it more formally, b/c it's probably not a + # good idea to have MetricsLogger or Stats tinker with the actual + # computation graph that users are trying to build in their loss + # functions. + assert len(values) == 1 + # TODO (sven) If the shape is (), do NOT even use the reduce method. + # Using `tf.reduce_mean()` here actually lead to a completely broken + # DreamerV3 (for a still unknown exact reason). + if len(values[0].shape) == 0: + reduced = values[0] + else: + reduce_meth = getattr(tf, "reduce_" + self._reduce_method) + reduced = reduce_meth(values) + else: reduce_meth = getattr(np, "nan" + self._reduce_method) reduced = reduce_meth(values)