From f5be451bf8e57ae1076c8794f70e303275cf861e Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 24 Sep 2025 16:40:24 +0100 Subject: [PATCH] [BugFix] Fix wrong assertion about collector and buffer --- torchrl/trainers/algorithms/configs/trainers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrl/trainers/algorithms/configs/trainers.py b/torchrl/trainers/algorithms/configs/trainers.py index da467c61461..a35e686d124 100644 --- a/torchrl/trainers/algorithms/configs/trainers.py +++ b/torchrl/trainers/algorithms/configs/trainers.py @@ -101,7 +101,9 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: elif replay_buffer is not None: collector = collector(replay_buffer=replay_buffer) elif getattr(collector, "replay_buffer", None) is None: - if collector.replay_buffer is None or replay_buffer is None: + if async_collection and ( + collector.replay_buffer is None or replay_buffer is None + ): raise ValueError( "replay_buffer must be provided when async_collection is True" ) @@ -230,7 +232,7 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer: collector = collector() else: collector = collector(replay_buffer=replay_buffer) - elif getattr(collector, "replay_buffer", None) is None: + elif async_collection and getattr(collector, "replay_buffer", None) is None: raise RuntimeError( "replay_buffer must be provided when async_collection is True" )