From 49cd01f03ae43bc3a0c29b6d6831fabf36c15469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20M=C3=B6bius?= Date: Fri, 15 Sep 2023 15:33:25 +0200 Subject: [PATCH] [RLlib] Fix configs.py for `framework != torch` (e.g. `tf2`). (#35975) --- rllib/core/models/configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rllib/core/models/configs.py b/rllib/core/models/configs.py index 0920c7ad49063..75475f7889985 100644 --- a/rllib/core/models/configs.py +++ b/rllib/core/models/configs.py @@ -520,7 +520,7 @@ def _validate(self, framework: str = "torch"): @_framework_implemented() def build(self, framework: str = "torch") -> "Model": - self._validate() + self._validate(framework) if framework == "torch": from ray.rllib.core.models.torch.heads import TorchCNNTransposeHead @@ -682,7 +682,7 @@ def _validate(self, framework: str = "torch"): @_framework_implemented() def build(self, framework: str = "torch") -> "Model": - self._validate() + self._validate(framework) if framework == "torch": from ray.rllib.core.models.torch.encoder import TorchCNNEncoder @@ -746,7 +746,7 @@ class MLPEncoderConfig(_MLPConfig): @_framework_implemented() def build(self, framework: str = "torch") -> "Encoder": - self._validate() + self._validate(framework) if framework == "torch": from ray.rllib.core.models.torch.encoder import TorchMLPEncoder