[RLlib] DreamerV3 on tf: Fix bug w/ reduce_fn still passed into `LearnerGroup.update_from_batch()`. (#45419)
sven1977 committed May 31, 2024
1 parent f9ab439 commit a95ec7f
Showing 38 changed files with 951 additions and 660 deletions.
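For context, the commit title describes DreamerV3's tf training path still handing a `reduce_fn` argument to `LearnerGroup.update_from_batch()`, which the new API stack no longer accepts. The snippet below is a hypothetical sketch of that failure mode using stand-in classes and dummy values, not RLlib's actual implementation:

# Hypothetical illustration of the bug named in the commit title.
class OldLearnerGroup:
    # Legacy signature: callers could pass a custom reduce_fn to aggregate
    # the per-Learner result dicts.
    def update_from_batch(self, batch, reduce_fn=None):
        results = [{"loss": 0.5}, {"loss": 0.7}]
        return reduce_fn(results) if reduce_fn else results


class NewLearnerGroup:
    # New signature: reduce_fn is gone; results are reduced by the metrics
    # pipeline inside the Learners, so passing it raises a TypeError.
    def update_from_batch(self, batch):
        return [{"loss": 0.5}, {"loss": 0.7}]


batch = {"obs": [1, 2, 3]}
OldLearnerGroup().update_from_batch(batch, reduce_fn=lambda rs: rs[0])  # old style
NewLearnerGroup().update_from_batch(batch)                              # fixed call
# NewLearnerGroup().update_from_batch(batch, reduce_fn=...)  # -> TypeError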
19 changes: 14 additions & 5 deletions rllib/algorithms/dqn/dqn.py
@@ -629,6 +629,11 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
env_runner_results, key=ENV_RUNNER_RESULTS
)

self.metrics.log_dict(
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, default={}),
key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)
self.metrics.log_value(
NUM_ENV_STEPS_SAMPLED_LIFETIME,
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED, default=0),
@@ -639,11 +644,6 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_EPISODES, default=0),
reduce="sum",
)
self.metrics.log_dict(
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, default={}),
key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
reduce="sum",
)
self.metrics.log_dict(
self.metrics.peek(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED, default={}),
key=NUM_MODULE_STEPS_SAMPLED_LIFETIME,
@@ -680,6 +680,14 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
learner_results = self.learner_group.update_from_episodes(
episodes=episodes,
timesteps={
NUM_ENV_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME)
),
NUM_AGENT_STEPS_SAMPLED_LIFETIME: (
self.metrics.peek(NUM_AGENT_STEPS_SAMPLED_LIFETIME)
),
},
)
# Isolate TD-errors from result dicts (we should not log these to
# disk or WandB, they might be very large).
@@ -730,6 +738,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
# Update the target networks, if necessary.
with self.metrics.log_time((TIMERS, LEARNER_ADDITIONAL_UPDATE_TIMER)):
modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
# TODO (sven): Move to Learner._after_gradient_based_update().
additional_results = self.learner_group.additional_update(
module_ids_to_update=modules_to_update,
timestep=current_ts,
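The reordered logging calls above follow the new-API-stack pattern of accumulating lifetime sample counts with reduce="sum" and then peek()-ing them to pass as `timesteps` into the learner update. A minimal, self-contained sketch of that pattern with a toy stand-in, not RLlib's MetricsLogger API:

class TinyMetrics:
    # Toy accumulator mimicking the log_value/peek usage shown above.
    def __init__(self):
        self._store = {}

    def log_value(self, key, value, reduce="sum"):
        if reduce == "sum":
            self._store[key] = self._store.get(key, 0) + value
        else:
            self._store[key] = value

    def peek(self, key, default=0):
        return self._store.get(key, default)


metrics = TinyMetrics()
for steps_this_iter in (512, 512, 256):
    metrics.log_value("num_env_steps_sampled_lifetime", steps_this_iter)

# The running total is what gets handed to the learners, e.g. as timesteps={...}.
print(metrics.peek("num_env_steps_sampled_lifetime"))  # -> 1280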
222 changes: 129 additions & 93 deletions rllib/algorithms/dreamerv3/dreamerv3.py

Large diffs are not rendered by default.

56 changes: 10 additions & 46 deletions rllib/algorithms/dreamerv3/dreamerv3_learner.py
@@ -7,61 +7,25 @@
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
https://arxiv.org/pdf/2010.02193.pdf
"""
from typing import Any, DefaultDict, Dict

from ray.rllib.algorithms.dreamerv3.dreamerv3 import DreamerV3Config
from ray.rllib.core.learner.learner import Learner
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModuleID, TensorType
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)


class DreamerV3Learner(Learner):
"""DreamerV3 specific Learner class.
Only implements the `additional_update_for_module()` method to define the logic
Only implements the `_after_gradient_based_update()` method to define the logic
for updating the critic EMA-copy after each training step.
"""

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(Learner)
def compile_results(
self,
*,
batch: MultiAgentBatch,
fwd_out: Dict[str, Any],
loss_per_module: Dict[str, TensorType],
metrics_per_module: DefaultDict[ModuleID, Dict[str, Any]],
) -> Dict[str, Any]:
results = super().compile_results(
batch=batch,
fwd_out=fwd_out,
loss_per_module=loss_per_module,
metrics_per_module=metrics_per_module,
)

# Add the predicted obs distributions for possible (video) summarization.
if self.config.report_images_and_videos:
for module_id, res in results.items():
if module_id in fwd_out:
res["WORLD_MODEL_fwd_out_obs_distribution_means_BxT"] = fwd_out[
module_id
]["obs_distribution_means_BxT"]
return results

@override(Learner)
def additional_update_for_module(
self,
*,
module_id: ModuleID,
config: DreamerV3Config,
timestep: int,
) -> None:
"""Updates the EMA weights of the critic network."""

# Call the base class' method.
super().additional_update_for_module(
module_id=module_id, config=config, timestep=timestep
)
def _after_gradient_based_update(self, timesteps):
super()._after_gradient_based_update(timesteps)

# Update EMA weights of the critic.
self.module[module_id].critic.update_ema()
for module_id, module in self.module._rl_modules.items():
module.critic.update_ema()
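The EMA update that the new `_after_gradient_based_update()` hook performs on each module's critic is, conceptually, an exponential moving average over the critic weights. A generic sketch follows; the decay constant and weight layout are assumptions for illustration, not DreamerV3's exact code:

import numpy as np

def update_ema(ema_weights, weights, decay=0.98):
    # w_ema <- decay * w_ema + (1 - decay) * w, applied per weight tensor.
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]

critic_weights = [np.random.randn(4, 4), np.random.randn(4)]
critic_ema = [np.zeros((4, 4)), np.zeros(4)]
critic_ema = update_ema(critic_ema, critic_weights)  # slowly tracks the critic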
86 changes: 61 additions & 25 deletions rllib/algorithms/dreamerv3/dreamerv3_rl_module.py
@@ -3,6 +3,7 @@
"""

import abc
from typing import Any, Dict

import gymnasium as gym
import numpy as np
@@ -29,6 +30,8 @@
class DreamerV3RLModule(RLModule, abc.ABC):
@override(RLModule)
def setup(self):
super().setup()

# Gather model-relevant settings.
B = 1
T = self.config.model_config_dict["batch_length_T"]
@@ -79,35 +82,39 @@ def setup(self):
self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework)

# Perform a test `call()` to force building the dreamer model's variables.
        test_obs = np.tile(
            np.expand_dims(self.config.observation_space.sample(), (0, 1)),
            reps=(B, T) + (1,) * len(self.config.observation_space.shape),
        )
        if isinstance(self.config.action_space, gym.spaces.Discrete):
            test_actions = np.tile(
                np.expand_dims(
                    one_hot(
                        self.config.action_space.sample(),
                        depth=self.config.action_space.n,
                    ),
                    (0, 1),
                ),
                reps=(B, T, 1),
            )
        else:
            test_actions = np.tile(
                np.expand_dims(self.config.action_space.sample(), (0, 1)),
                reps=(B, T, 1),
            )

        self.dreamer_model(
            None,
            _convert_to_tf(test_obs, dtype=tf.float32),
            _convert_to_tf(test_actions, dtype=tf.float32),
            _convert_to_tf(np.ones((B, T)), dtype=tf.bool),
            _convert_to_tf(np.zeros((B * T,)), dtype=tf.bool),
        )
        if self.framework == "tf2":
            test_obs = np.tile(
                np.expand_dims(self.config.observation_space.sample(), (0, 1)),
                reps=(B, T) + (1,) * len(self.config.observation_space.shape),
            )
            if isinstance(self.config.action_space, gym.spaces.Discrete):
                test_actions = np.tile(
                    np.expand_dims(
                        one_hot(
                            self.config.action_space.sample(),
                            depth=self.config.action_space.n,
                        ),
                        (0, 1),
                    ),
                    reps=(B, T, 1),
                )
            else:
                test_actions = np.tile(
                    np.expand_dims(self.config.action_space.sample(), (0, 1)),
                    reps=(B, T, 1),
                )

            self.dreamer_model(
                inputs=None,
                observations=_convert_to_tf(test_obs, dtype=tf.float32),
                actions=_convert_to_tf(test_actions, dtype=tf.float32),
                is_first=_convert_to_tf(np.ones((B, T)), dtype=tf.bool),
                start_is_terminated_BxT=_convert_to_tf(
                    np.zeros((B * T,)), dtype=tf.bool
                ),
            )

# Initialize the critic EMA net:
self.critic.init_ema()
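The test `call()` above exists because, under tf2/Keras, variables are only created on the first call; running a correctly shaped dummy batch through the model forces it to build before the EMA copy is initialized and training starts. The generic pattern, shown on a plain Keras layer rather than the dreamer model:

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(8)
print(len(layer.weights))                      # 0 -- nothing built yet
_ = layer(np.zeros((1, 4), dtype=np.float32))  # dummy call creates kernel + bias
print(len(layer.weights))                      # 2 -- variables now exist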

@@ -152,3 +159,32 @@ def output_specs_train(self) -> SpecDict:
# Deterministic, continuous h-states (t1 to T).
"h_states_BxT",
]

@override(RLModule)
def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]:
# Call the Dreamer-Model's forward_inference method and return a dict.
actions, next_state = self.dreamer_model.forward_inference(
observations=batch[Columns.OBS],
previous_states=batch[Columns.STATE_IN],
is_first=batch["is_first"],
)
return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}

@override(RLModule)
def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]:
# Call the Dreamer-Model's forward_exploration method and return a dict.
actions, next_state = self.dreamer_model.forward_exploration(
observations=batch[Columns.OBS],
previous_states=batch[Columns.STATE_IN],
is_first=batch["is_first"],
)
return {Columns.ACTIONS: actions, Columns.STATE_OUT: next_state}

@override(RLModule)
def _forward_train(self, batch: NestedDict):
# Call the Dreamer-Model's forward_train method and return its outputs as-is.
return self.dreamer_model.forward_train(
observations=batch[Columns.OBS],
actions=batch[Columns.ACTIONS],
is_first=batch["is_first"],
)
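The three `_forward_*` methods added above all route a batch of observations, the previous recurrent state, and an `is_first` flag into the dreamer model. A hypothetical sketch of how such a stateful module is stepped during inference; the stub class and key names are assumptions, not the actual RLlib API:

import numpy as np

class StubStatefulModule:
    # Stand-in for a stateful RLModule; not RLlib's implementation.
    def get_initial_state(self):
        return {"h": np.zeros(8, dtype=np.float32)}

    def _forward_inference(self, batch):
        # Toy dynamics: the recurrent state accumulates the observation mean.
        h = batch["state_in"]["h"] + batch["obs"].mean()
        return {"actions": np.zeros(1, dtype=np.int64), "state_out": {"h": h}}


module = StubStatefulModule()
state = module.get_initial_state()
for t in range(3):
    obs = np.random.random(4).astype(np.float32)
    out = module._forward_inference(
        {"obs": obs, "state_in": state, "is_first": np.array([t == 0])}
    )
    state = out["state_out"]  # carry the recurrent state into the next step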
126 changes: 63 additions & 63 deletions rllib/algorithms/dreamerv3/tests/test_dreamerv3.py
@@ -40,6 +40,7 @@ def test_dreamerv3_compilation(self):
# Build a DreamerV3Config object.
config = (
dreamerv3.DreamerV3Config()
.framework(eager_tracing=False)
.training(
# Keep things simple. Especially the long dream rollouts seem
# to take an enormous amount of time (initially).
@@ -51,78 +52,77 @@
use_float16=False,
)
.learners(
num_learners=2, # Try with 2 Learners.
num_learners=0, # TODO 2 # Try with 2 Learners.
num_cpus_per_learner=1,
num_gpus_per_learner=0,
)
)

num_iterations = 2

for _ in framework_iterator(config, frameworks="tf2"):
for env in [
"FrozenLake-v1",
"CartPole-v1",
"ALE/MsPacman-v5",
"Pendulum-v1",
]:
print("Env={}".format(env))
# Add one-hot observations for FrozenLake env.
if env == "FrozenLake-v1":

def env_creator(ctx):
import gymnasium as gym
from ray.rllib.algorithms.dreamerv3.utils.env_runner import (
OneHot,
)

return OneHot(gym.make("FrozenLake-v1"))

tune.register_env("frozen-lake-one-hot", env_creator)
env = "frozen-lake-one-hot"

config.environment(env)
algo = config.build()
obs_space = algo.workers.local_worker().env.single_observation_space
act_space = algo.workers.local_worker().env.single_action_space
rl_module = algo.workers.local_worker().module

for i in range(num_iterations):
results = algo.train()
print(results)
# Test dream trajectory w/ recreated observations.
sample = algo.replay_buffer.sample()
dream = rl_module.dreamer_model.dream_trajectory_with_burn_in(
start_states=rl_module.dreamer_model.get_initial_state(),
timesteps_burn_in=5,
timesteps_H=45,
observations=sample["obs"][:1], # B=1
actions=(
one_hot(
sample["actions"],
depth=act_space.n,
)
if isinstance(act_space, gym.spaces.Discrete)
else sample["actions"]
)[
:1
], # B=1
)
self.assertTrue(
    dream["actions_dreamed_t0_to_H_BxT"].shape
    == (46, 1)
    + (
        (act_space.n,)
        if isinstance(act_space, gym.spaces.Discrete)
        else tuple(act_space.shape)
    )
)
for env in [
"FrozenLake-v1",
"CartPole-v1",
"ALE/MsPacman-v5",
"Pendulum-v1",
]:
print("Env={}".format(env))
# Add one-hot observations for FrozenLake env.
if env == "FrozenLake-v1":

def env_creator(ctx):
import gymnasium as gym
from ray.rllib.algorithms.dreamerv3.utils.env_runner import (
OneHot,
)

return OneHot(gym.make("FrozenLake-v1"))

tune.register_env("frozen-lake-one-hot", env_creator)
env = "frozen-lake-one-hot"

config.environment(env)
algo = config.build()
obs_space = algo.workers.local_worker().env.single_observation_space
act_space = algo.workers.local_worker().env.single_action_space
rl_module = algo.workers.local_worker().module

for i in range(num_iterations):
results = algo.train()
print(results)
# Test dream trajectory w/ recreated observations.
sample = algo.replay_buffer.sample()
dream = rl_module.dreamer_model.dream_trajectory_with_burn_in(
start_states=rl_module.dreamer_model.get_initial_state(),
timesteps_burn_in=5,
timesteps_H=45,
observations=sample["obs"][:1], # B=1
actions=(
one_hot(
sample["actions"],
depth=act_space.n,
)
if isinstance(act_space, gym.spaces.Discrete)
else sample["actions"]
)[
:1
], # B=1
)
self.assertTrue(
    dream["actions_dreamed_t0_to_H_BxT"].shape
    == (46, 1)
    + (
        (act_space.n,)
        if isinstance(act_space, gym.spaces.Discrete)
        else tuple(act_space.shape)
    )
)
self.assertTrue(dream["continues_dreamed_t0_to_H_BxT"].shape == (46, 1))
self.assertTrue(
    dream["observations_dreamed_t0_to_H_BxT"].shape
    == [46, 1] + list(obs_space.shape)
)
algo.stop()

def test_dreamerv3_dreamer_model_sizes(self):
"""Tests, whether the different model sizes match the ones reported in [1]."""