
Commit bb20ec3

[RLlib] Fix stateful module errors with inference only mode. (ray-project#45465)

Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
simonsays1980 authored and ryanaoleary committed Jun 6, 2024
1 parent 59465e0 commit bb20ec3
Showing 3 changed files with 13 additions and 3 deletions.
11 changes: 10 additions & 1 deletion rllib/algorithms/ppo/ppo_rl_module.py
@@ -6,6 +6,7 @@
 from typing import Any, Type
 
 from ray.rllib.core.columns import Columns
+from ray.rllib.core.models.configs import RecurrentEncoderConfig
 from ray.rllib.core.models.specs.specs_dict import SpecDict
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.models.distributions import Distribution
@@ -20,12 +21,20 @@ class PPORLModule(RLModule, abc.ABC):
     def setup(self):
         # __sphinx_doc_begin__
         catalog = self.config.get_catalog()
+        # If we have a stateful model, states for the critic need to be collected
+        # during sampling and `inference_only` needs to be `False`. Note, at this
+        # point the encoder is not built yet, and therefore `is_stateful()` does
+        # not work.
+        is_stateful = isinstance(
+            catalog.actor_critic_encoder_config.base_encoder_config,
+            RecurrentEncoderConfig,
+        )
+        self.inference_only &= not is_stateful
         # If this is not a learner module, we use only a single value network. This
         # network is then either the shared encoder network from the learner module
         # or the actor encoder network from the learner module (if the value network
         # is not shared with the actor network).
         if self.inference_only and self.framework == "torch":
-            # catalog._model_config_dict["vf_share_layers"] = True
             # We need to set the shared flag in the encoder config
             # b/c the catalog has already been built at this point.
             catalog.actor_critic_encoder_config.shared = True
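The hunk above forces `inference_only` off whenever the base encoder config is recurrent, because a stateful critic must produce and collect value-function states during sampling. Below is a minimal standalone sketch of that boolean masking, using simplified stand-in config classes rather than RLlib's real ones:

from dataclasses import dataclass


@dataclass
class MLPEncoderConfig:
    """Stand-in for a stateless encoder config."""
    hidden_dim: int = 256


@dataclass
class RecurrentEncoderConfig:
    """Stand-in for RLlib's recurrent (e.g. LSTM) encoder config."""
    cell_size: int = 128


class PPOLikeModule:
    def __init__(self, base_encoder_config, inference_only: bool):
        self.inference_only = inference_only
        # A recurrent base encoder means the module is stateful. At this point
        # the encoder itself is not built yet, so we inspect the config class
        # instead of calling `is_stateful()` on a built model.
        is_stateful = isinstance(base_encoder_config, RecurrentEncoderConfig)
        # Stateful modules must keep the full actor-critic pipeline so that
        # critic states can be collected during sampling.
        self.inference_only &= not is_stateful


# An inference-only request is honored for a stateless encoder ...
assert PPOLikeModule(MLPEncoderConfig(), inference_only=True).inference_only
# ... but downgraded to a full module for a recurrent one.
assert not PPOLikeModule(RecurrentEncoderConfig(), inference_only=True).inference_only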
3 changes: 2 additions & 1 deletion rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -32,7 +32,8 @@ def setup(self):
     def get_state(self, inference_only: bool = False) -> Dict[str, Any]:
         state_dict = self.state_dict()
         # If this module is not for inference, but the state dict is.
-        if not self.inference_only and inference_only:
+        # Note, for stateful modules, we need the full state dict.
+        if not self.inference_only and not self.is_stateful() and inference_only:
             # Call the local hook to remove or rename the parameters.
             return self._inference_only_get_state_hook(state_dict)
         # Otherwise, the state dict is for checkpointing or saving the model.
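The extra `not self.is_stateful()` guard keeps stateful modules from stripping value-function parameters out of a requested inference-only state dict. A rough sketch of the same guard with toy stand-ins (not the actual `TorchRLModule` API):

from typing import Any, Dict


class ToyPPOTorchModule:
    def __init__(self, inference_only: bool, stateful: bool):
        self.inference_only = inference_only
        self._stateful = stateful
        # Full state: actor weights plus critic (value-function) weights.
        self._state = {"pi.weight": [0.1], "vf.weight": [0.2]}

    def is_stateful(self) -> bool:
        return self._stateful

    def _inference_only_get_state_hook(self, state: Dict[str, Any]) -> Dict[str, Any]:
        # Drop critic parameters that are not needed for pure inference.
        return {k: v for k, v in state.items() if not k.startswith("vf.")}

    def get_state(self, inference_only: bool = False) -> Dict[str, Any]:
        state = dict(self._state)
        # Only strip parameters if this is a full (training) module, the caller
        # asked for an inference-only state, AND the module is not stateful.
        if not self.inference_only and not self.is_stateful() and inference_only:
            return self._inference_only_get_state_hook(state)
        return state


# Stateless module: the inference-only state drops the critic weights.
assert "vf.weight" not in ToyPPOTorchModule(False, stateful=False).get_state(inference_only=True)
# Stateful module: the full state dict is returned even for inference_only=True.
assert "vf.weight" in ToyPPOTorchModule(False, stateful=True).get_state(inference_only=True)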
@@ -288,7 +288,7 @@ def add(
                     for i in range(len(eps))
                 ]
             )
-            # Increase index.
+            # Increase index to the new length of `self._indices`.
             j = len(self._indices)
 
     @override(EpisodeReplayBuffer)
