diff --git a/torchrl/trainers/algorithms/configs/trainers.py b/torchrl/trainers/algorithms/configs/trainers.py index da467c61461..a35e686d124 100644 --- a/torchrl/trainers/algorithms/configs/trainers.py +++ b/torchrl/trainers/algorithms/configs/trainers.py @@ -101,7 +101,9 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer: elif replay_buffer is not None: collector = collector(replay_buffer=replay_buffer) elif getattr(collector, "replay_buffer", None) is None: - if collector.replay_buffer is None or replay_buffer is None: + if async_collection and ( + collector.replay_buffer is None or replay_buffer is None + ): raise ValueError( "replay_buffer must be provided when async_collection is True" ) @@ -230,7 +232,7 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer: collector = collector() else: collector = collector(replay_buffer=replay_buffer) - elif getattr(collector, "replay_buffer", None) is None: + elif async_collection and getattr(collector, "replay_buffer", None) is None: raise RuntimeError( "replay_buffer must be provided when async_collection is True" )