diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 93ea6134aca..e56e63e6d77 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -9,6 +9,7 @@ CountFramesLog, LogReward, LogScalar, + LogTiming, LogValidationReward, mask_batch, OptimizerHook, @@ -29,6 +30,7 @@ "CountFramesLog", "LogReward", "LogScalar", + "LogTiming", "LogValidationReward", "mask_batch", "OptimizerHook", diff --git a/torchrl/trainers/algorithms/sac.py b/torchrl/trainers/algorithms/sac.py index 6bf4956bcd5..3d4176c5374 100644 --- a/torchrl/trainers/algorithms/sac.py +++ b/torchrl/trainers/algorithms/sac.py @@ -125,6 +125,7 @@ def __init__( log_observations: bool = False, target_net_updater: TargetNetUpdater | None = None, async_collection: bool = False, + log_timings: bool = False, ) -> None: warnings.warn( "SACTrainer is an experimental/prototype feature. The API may change in future versions. " @@ -151,6 +152,7 @@ def __init__( log_interval=log_interval, save_trainer_file=save_trainer_file, async_collection=async_collection, + log_timings=log_timings, ) self.replay_buffer = replay_buffer self.async_collection = async_collection