From 4d6443e8fef11635339a225e98707e07873d5bc6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 14 Oct 2025 09:41:02 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/trainers/__init__.py | 2 ++ torchrl/trainers/algorithms/sac.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 93ea6134aca..e56e63e6d77 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -9,6 +9,7 @@ CountFramesLog, LogReward, LogScalar, + LogTiming, LogValidationReward, mask_batch, OptimizerHook, @@ -29,6 +30,7 @@ "CountFramesLog", "LogReward", "LogScalar", + "LogTiming", "LogValidationReward", "mask_batch", "OptimizerHook", diff --git a/torchrl/trainers/algorithms/sac.py b/torchrl/trainers/algorithms/sac.py index 6bf4956bcd5..3d4176c5374 100644 --- a/torchrl/trainers/algorithms/sac.py +++ b/torchrl/trainers/algorithms/sac.py @@ -125,6 +125,7 @@ def __init__( log_observations: bool = False, target_net_updater: TargetNetUpdater | None = None, async_collection: bool = False, + log_timings: bool = False, ) -> None: warnings.warn( "SACTrainer is an experimental/prototype feature. The API may change in future versions. " @@ -151,6 +152,7 @@ def __init__( log_interval=log_interval, save_trainer_file=save_trainer_file, async_collection=async_collection, + log_timings=log_timings, ) self.replay_buffer = replay_buffer self.async_collection = async_collection