Skip to content

Commit

Permalink
[RLlib] Fix configs.py for framework != torch (e.g. tf2). (#35975)
Browse files Browse the repository at this point in the history
  • Loading branch information
RocketRider committed Sep 15, 2023
1 parent 515fbe8 commit 49cd01f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rllib/core/models/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def _validate(self, framework: str = "torch"):

@_framework_implemented()
def build(self, framework: str = "torch") -> "Model":
self._validate()
self._validate(framework)

if framework == "torch":
from ray.rllib.core.models.torch.heads import TorchCNNTransposeHead
Expand Down Expand Up @@ -682,7 +682,7 @@ def _validate(self, framework: str = "torch"):

@_framework_implemented()
def build(self, framework: str = "torch") -> "Model":
self._validate()
self._validate(framework)

if framework == "torch":
from ray.rllib.core.models.torch.encoder import TorchCNNEncoder
Expand Down Expand Up @@ -746,7 +746,7 @@ class MLPEncoderConfig(_MLPConfig):

@_framework_implemented()
def build(self, framework: str = "torch") -> "Encoder":
self._validate()
self._validate(framework)

if framework == "torch":
from ray.rllib.core.models.torch.encoder import TorchMLPEncoder
Expand Down

0 comments on commit 49cd01f

Please sign in to comment.