Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Fix stateful module errors with inference only mode. #45465

3 changes: 3 additions & 0 deletions rllib/algorithms/ppo/ppo_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class PPORLModule(RLModule, abc.ABC):
def setup(self):
# __sphinx_doc_begin__
catalog = self.config.get_catalog()
# If we have a stateful model states for the critic need to be collected
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we also use is_stateful() here? What if the user doesn't use the built-in use_lstm option, but comes with their own stateful model?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sven1977 this was my first intend, however at this point in time is_stateful() cannot be called, yet b/c the encoder is not yet defined.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this is not a nice solution, but at this point in the code we need to know, if the module is stateful or not, but the is_stateful() depends on the encoder which is defined depending on inference-only being True/False.

# during sampling and `inference-only` needs to be `False`.
self.inference_only = not self.config.model_config_dict["use_lstm"]
# If this is not a learner module, we use only a single value network. This
# network is then either the share encoder network from the learner module
# or the actor encoder network from the learner module (if the value network
Expand Down
3 changes: 2 additions & 1 deletion rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def setup(self):
def get_state(self, inference_only: bool = False) -> Dict[str, Any]:
state_dict = self.state_dict()
# If this module is not for inference, but the state dict is.
if not self.inference_only and inference_only:
# Note, for stateful modules, we need the full state dict.
if not self.inference_only and not self.is_stateful() and inference_only:
# Call the local hook to remove or rename the parameters.
return self._inference_only_get_state_hook(state_dict)
# Otherwise, the state dict is for checkpointing or saving the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def add(
for i in range(len(eps))
]
)
# Increase index.
# Increase index to the new length of `self._indices`.
j = len(self._indices)

@override(EpisodeReplayBuffer)
Expand Down
Loading