diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 36ff4d66331..063a0f26735 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -183,6 +183,7 @@ Trainer and hooks Trainer TrainerHookBase UpdateWeights + TargetNetUpdaterHook Algorithm-specific trainers (Experimental) @@ -202,37 +203,54 @@ into complete training solutions with sensible defaults and comprehensive config :template: rl_template.rst PPOTrainer + SACTrainer -PPOTrainer -~~~~~~~~~~ +Algorithm Trainers +~~~~~~~~~~~~~~~~~~ -The :class:`~torchrl.trainers.algorithms.PPOTrainer` provides a complete PPO training solution -with configurable defaults and a comprehensive configuration system built on Hydra. +TorchRL provides high-level algorithm trainers that offer complete training solutions with minimal code. +These trainers feature comprehensive configuration systems built on Hydra, enabling both simple usage +and sophisticated customization. + +**Currently Available:** + +- :class:`~torchrl.trainers.algorithms.PPOTrainer` - Proximal Policy Optimization +- :class:`~torchrl.trainers.algorithms.SACTrainer` - Soft Actor-Critic **Key Features:** -- Complete training pipeline with environment setup, data collection, and optimization -- Extensive configuration system using dataclasses and Hydra -- Built-in logging for rewards, actions, and training statistics -- Modular design built on existing TorchRL components -- **Minimal code**: Complete SOTA implementation in just ~20 lines! +- **Complete pipeline**: Environment setup, data collection, and optimization +- **Hydra configuration**: Extensive dataclass-based configuration system +- **Built-in logging**: Rewards, actions, and algorithm-specific metrics +- **Modular design**: Built on existing TorchRL components +- **Minimal code**: Complete SOTA implementations in ~20 lines! .. warning:: - This is an experimental feature. The API may change in future versions. - We welcome feedback and contributions to help improve this implementation! + Algorithm trainers are experimental features. The API may change in future versions. + We welcome feedback and contributions to help improve these implementations! -**Quick Start - Command Line Interface:** +Quick Start Examples +^^^^^^^^^^^^^^^^^^^^ + +**PPO Training:** .. code-block:: bash - # Basic usage - train PPO on Pendulum-v1 with default settings + # Train PPO on Pendulum-v1 with default settings python sota-implementations/ppo_trainer/train.py +**SAC Training:** + +.. code-block:: bash + + # Train SAC on a continuous control task + python sota-implementations/sac_trainer/train.py + **Custom Configuration:** .. code-block:: bash - # Override specific parameters via command line + # Override parameters for any algorithm python sota-implementations/ppo_trainer/train.py \ trainer.total_frames=2000000 \ training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \ @@ -243,32 +261,34 @@ with configurable defaults and a comprehensive configuration system built on Hyd .. code-block:: bash - # Switch to a different environment and logger - python sota-implementations/ppo_trainer/train.py \ - env=gym \ + # Switch environment and logger for any trainer + python sota-implementations/sac_trainer/train.py \ training_env.create_env_fn.base_env.env_name=Walker2d-v4 \ - logger=tensorboard + logger=tensorboard \ + logger.exp_name=sac_walker2d -**See All Options:** +**View Configuration Options:** .. 
code-block:: bash - # View all available configuration options + # See all available options for any trainer python sota-implementations/ppo_trainer/train.py --help + python sota-implementations/sac_trainer/train.py --help -**Configuration Groups:** +Universal Configuration System +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The PPOTrainer configuration is organized into logical groups: +All algorithm trainers share a unified configuration architecture organized into logical groups: -- **Environment**: ``env_cfg__env_name``, ``env_cfg__backend``, ``env_cfg__device`` -- **Networks**: ``actor_network__network__num_cells``, ``critic_network__module__num_cells`` -- **Training**: ``total_frames``, ``clip_norm``, ``num_epochs``, ``optimizer_cfg__lr`` -- **Logging**: ``log_rewards``, ``log_actions``, ``log_observations`` +- **Environment**: ``training_env.create_env_fn.base_env.env_name``, ``training_env.num_workers`` +- **Networks**: ``networks.policy_network.num_cells``, ``networks.value_network.num_cells`` +- **Training**: ``trainer.total_frames``, ``trainer.clip_norm``, ``optimizer.lr`` +- **Data**: ``collector.frames_per_batch``, ``replay_buffer.batch_size``, ``replay_buffer.storage.max_size`` +- **Logging**: ``logger.exp_name``, ``logger.project``, ``trainer.log_interval`` **Working Example:** -The `sota-implementations/ppo_trainer/ `_ -directory contains a complete, working PPO implementation that demonstrates the simplicity and power of the trainer system: +All trainer implementations follow the same simple pattern: .. code-block:: python @@ -283,33 +303,57 @@ directory contains a complete, working PPO implementation that demonstrates the if __name__ == "__main__": main() -*Complete PPO training with full configurability in ~20 lines!* +*Complete algorithm training with full configurability in ~20 lines!* -**Configuration Classes:** +Configuration Classes +^^^^^^^^^^^^^^^^^^^^^ -The PPOTrainer uses a hierarchical configuration system with these main config classes. +The trainer system uses a hierarchical configuration system with shared components. .. note:: The configuration system requires Python 3.10+ due to its use of modern type annotation syntax. 
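The same configuration tree can also be composed and instantiated programmatically rather than through the CLI. The sketch below is illustrative only: it assumes Hydra's ``initialize``/``compose`` API (Hydra >= 1.2) and that it is run next to the ``config/`` directory of ``sota-implementations/sac_trainer``; the override values are arbitrary.

.. code-block:: python

    # Minimal sketch: build the SAC trainer config in-process and instantiate it.
    from hydra import compose, initialize
    from hydra.utils import instantiate

    # Importing the configs package registers the config groups with Hydra's ConfigStore.
    from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403

    with initialize(config_path="config", version_base="1.1"):
        cfg = compose(
            config_name="config",
            overrides=[
                "trainer.total_frames=100000",  # shorter run, for illustration
                "optimizer.lr=0.0001",
            ],
        )

    trainer = instantiate(cfg.trainer)
    trainer.train()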
-- **Trainer**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig` +**Algorithm-Specific Trainers:** + +- **PPO**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig` +- **SAC**: :class:`~torchrl.trainers.algorithms.configs.trainers.SACTrainerConfig` + +**Shared Configuration Components:** + - **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`, :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig` - **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`, :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig` - **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`, :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig` -- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig` +- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`, :class:`~torchrl.trainers.algorithms.configs.objectives.SACLossConfig` - **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`, :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig` - **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`, :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig` +Algorithm-Specific Features +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**PPOTrainer:** + +- On-policy learning with advantage estimation +- Policy clipping and value function optimization +- Configurable number of epochs per batch +- Built-in GAE (Generalized Advantage Estimation) + +**SACTrainer:** + +- Off-policy learning with replay buffer +- Entropy-regularized policy optimization +- Target network soft updates +- Continuous action space optimization + **Future Development:** -This is the first of many planned algorithm-specific trainers. Future releases will include: +The trainer system is actively expanding. Upcoming features include: -- Additional algorithms: SAC, TD3, DQN, A2C, and more -- Full integration of all TorchRL components within the configuration system -- Enhanced configuration validation and error reporting -- Distributed training support for high-level trainers +- Additional algorithms: TD3, DQN, A2C, DDPG, and more +- Enhanced distributed training support +- Advanced configuration validation and error reporting +- Integration with more TorchRL ecosystem components -See the complete `configuration system documentation `_ for all available options. +See the complete `configuration system documentation `_ for all available options and examples. 
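Because the object returned by ``hydra.utils.instantiate(cfg.trainer)`` is a regular :class:`~torchrl.trainers.Trainer`, it can be extended with custom hooks before training starts. The example below mirrors the entry point added in ``sota-implementations/sac_trainer/train.py`` in this patch; the ``print_reward`` hook is illustrative.

.. code-block:: python

    import hydra
    import torchrl
    from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403


    @hydra.main(config_path="config", config_name="config", version_base="1.1")
    def main(cfg):
        # Custom hook: log the mean reward of each collected batch.
        def print_reward(td):
            torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

        trainer = hydra.utils.instantiate(cfg.trainer)
        trainer.register_op(dest="batch_process", op=print_reward)
        trainer.train()


    if __name__ == "__main__":
        main()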
Builders diff --git a/sota-implementations/sac_trainer/config/config.yaml b/sota-implementations/sac_trainer/config/config.yaml new file mode 100644 index 00000000000..2f794c9bfa2 --- /dev/null +++ b/sota-implementations/sac_trainer/config/config.yaml @@ -0,0 +1,146 @@ +# SAC Trainer Configuration for HalfCheetah-v4 +# This configuration uses the new configurable trainer system and matches SOTA SAC implementation + +defaults: + + - transform@transform0: step_counter + - transform@transform1: double_to_float + + - env@training_env: batched_env + - env@training_env.create_env_fn: transformed_env + - env@training_env.create_env_fn.base_env: gym + - transform@training_env.create_env_fn.transform: compose + + - model@models.policy_model: tanh_normal + - model@models.value_model: value + - model@models.qvalue_model: value + + - network@networks.policy_network: mlp + - network@networks.value_network: mlp + - network@networks.qvalue_network: mlp + + - collector@collector: multi_async + + - replay_buffer@replay_buffer: base + - storage@replay_buffer.storage: lazy_tensor + - writer@replay_buffer.writer: round_robin + - sampler@replay_buffer.sampler: random + - trainer@trainer: sac + - optimizer@optimizer: adam + - loss@loss: sac + - target_net_updater@target_net_updater: soft + - logger@logger: wandb + - _self_ + +# Network configurations +networks: + policy_network: + out_features: 12 # HalfCheetah action space is 6-dimensional (loc + scale) + in_features: 17 # HalfCheetah observation space is 17-dimensional + num_cells: [256, 256] + + value_network: + out_features: 1 # Value output + in_features: 17 # HalfCheetah observation space + num_cells: [256, 256] + + qvalue_network: + out_features: 1 # Q-value output + in_features: 23 # HalfCheetah observation space (17) + action space (6) + num_cells: [256, 256] + +# Model configurations +models: + policy_model: + return_log_prob: true + in_keys: ["observation"] + param_keys: ["loc", "scale"] + out_keys: ["action"] + network: ${networks.policy_network} + + qvalue_model: + in_keys: ["observation", "action"] + out_keys: ["state_action_value"] + network: ${networks.qvalue_network} + +transform0: + max_steps: 1000 + step_count_key: "step_count" + +transform1: + # DoubleToFloatTransform - converts double precision to float to fix dtype mismatch + in_keys: null + out_keys: null + +training_env: + num_workers: 4 + create_env_fn: + base_env: + env_name: HalfCheetah-v4 + transform: + transforms: + - ${transform0} + - ${transform1} + _partial_: true + +# Loss configuration +loss: + actor_network: ${models.policy_model} + qvalue_network: ${models.qvalue_model} + target_entropy: "auto" + loss_function: l2 + alpha_init: 1.0 + delay_qvalue: true + num_qvalue_nets: 2 + +target_net_updater: + tau: 0.001 + +# Optimizer configuration +optimizer: + lr: 3.0e-4 + +# Collector configuration +collector: + create_env_fn: ${training_env} + policy: ${models.policy_model} + total_frames: 1_000_000 + frames_per_batch: 1000 + num_workers: 4 + init_random_frames: 25000 + track_policy_version: true + +# Replay buffer configuration +replay_buffer: + storage: + max_size: 1_000_000 + device: cpu + ndim: 1 + sampler: + writer: + compilable: false + batch_size: 256 + +logger: + exp_name: sac_halfcheetah_v4 + offline: false + project: torchrl-sota-implementations + +# Trainer configuration +trainer: + collector: ${collector} + optimizer: ${optimizer} + replay_buffer: ${replay_buffer} + target_net_updater: ${target_net_updater} + loss_module: ${loss} + logger: ${logger} + total_frames: 
1_000_000 + frame_skip: 1 + clip_grad_norm: false # SAC typically doesn't use gradient clipping + clip_norm: null + progress_bar: true + seed: 42 + save_trainer_interval: 25000 # Match SOTA eval_iter + log_interval: 25000 + save_trainer_file: null + optim_steps_per_batch: 64 # Match SOTA utd_ratio diff --git a/sota-implementations/sac_trainer/train.py b/sota-implementations/sac_trainer/train.py new file mode 100644 index 00000000000..2df69106df9 --- /dev/null +++ b/sota-implementations/sac_trainer/train.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import hydra +import torchrl +from torchrl.trainers.algorithms.configs import * # noqa: F401, F403 + + +@hydra.main(config_path="config", config_name="config", version_base="1.1") +def main(cfg): + def print_reward(td): + torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}") + + trainer = hydra.utils.instantiate(cfg.trainer) + trainer.register_op(dest="batch_process", op=print_reward) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 0136da42761..3eb3e7ca75d 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -16,6 +16,7 @@ ReplayBufferTrainer, RewardNormalizer, SelectKeys, + TargetNetUpdaterHook, Trainer, TrainerHookBase, UpdateWeights, @@ -37,4 +38,5 @@ "Trainer", "TrainerHookBase", "UpdateWeights", + "TargetNetUpdaterHook", ] diff --git a/torchrl/trainers/algorithms/__init__.py b/torchrl/trainers/algorithms/__init__.py index d35af17b5ed..4f6b71894d7 100644 --- a/torchrl/trainers/algorithms/__init__.py +++ b/torchrl/trainers/algorithms/__init__.py @@ -6,5 +6,6 @@ from __future__ import annotations from .ppo import PPOTrainer +from .sac import SACTrainer -__all__ = ["PPOTrainer"] +__all__ = ["PPOTrainer", "SACTrainer"] diff --git a/torchrl/trainers/algorithms/configs/__init__.py b/torchrl/trainers/algorithms/configs/__init__.py index 73d859d13b0..52c50ed8e11 100644 --- a/torchrl/trainers/algorithms/configs/__init__.py +++ b/torchrl/trainers/algorithms/configs/__init__.py @@ -84,8 +84,18 @@ TensorDictModuleConfig, ValueModelConfig, ) -from torchrl.trainers.algorithms.configs.objectives import LossConfig, PPOLossConfig -from torchrl.trainers.algorithms.configs.trainers import PPOTrainerConfig, TrainerConfig +from torchrl.trainers.algorithms.configs.objectives import ( + HardUpdateConfig, + LossConfig, + PPOLossConfig, + SACLossConfig, + SoftUpdateConfig, +) +from torchrl.trainers.algorithms.configs.trainers import ( + PPOTrainerConfig, + SACTrainerConfig, + TrainerConfig, +) from torchrl.trainers.algorithms.configs.transforms import ( ActionDiscretizerConfig, ActionMaskConfig, @@ -317,8 +327,10 @@ # Losses "LossConfig", "PPOLossConfig", + "SACLossConfig", # Trainers "PPOTrainerConfig", + "SACTrainerConfig", "TrainerConfig", # Loggers "CSVLoggerConfig", @@ -479,6 +491,10 @@ def _register_configs(): cs.store(group="loss", name="base", node=LossConfig) cs.store(group="loss", name="ppo", node=PPOLossConfig) + cs.store(group="loss", name="sac", node=SACLossConfig) + + cs.store(group="target_net_updater", name="soft", node=SoftUpdateConfig) + cs.store(group="target_net_updater", name="hard", node=HardUpdateConfig) # ============================================================================= # Replay Buffer Configurations @@ -523,6 +539,7 @@ def _register_configs(): 
cs.store(group="trainer", name="base", node=TrainerConfig) cs.store(group="trainer", name="ppo", node=PPOTrainerConfig) + cs.store(group="trainer", name="sac", node=SACTrainerConfig) # ============================================================================= # Optimizer Configurations diff --git a/torchrl/trainers/algorithms/configs/collectors.py b/torchrl/trainers/algorithms/configs/collectors.py index d653ea05c7f..34eb778b9b2 100644 --- a/torchrl/trainers/algorithms/configs/collectors.py +++ b/torchrl/trainers/algorithms/configs/collectors.py @@ -51,6 +51,7 @@ class SyncDataCollectorConfig(DataCollectorConfig): cudagraph_policy: Any = None no_cuda_sync: bool = False weight_updater: Any = None + track_policy_version: bool = False _target_: str = "torchrl.collectors.SyncDataCollector" _partial_: bool = False @@ -93,6 +94,7 @@ class AsyncDataCollectorConfig(DataCollectorConfig): cudagraph_policy: Any = None no_cuda_sync: bool = False weight_updater: Any = None + track_policy_version: bool = False _target_: str = "torchrl.collectors.aSyncDataCollector" def __post_init__(self): @@ -133,6 +135,7 @@ class MultiSyncDataCollectorConfig(DataCollectorConfig): cudagraph_policy: Any = None no_cuda_sync: bool = False weight_updater: Any = None + track_policy_version: bool = False _target_: str = "torchrl.collectors.MultiSyncDataCollector" def __post_init__(self): @@ -174,6 +177,7 @@ class MultiaSyncDataCollectorConfig(DataCollectorConfig): cudagraph_policy: Any = None no_cuda_sync: bool = False weight_updater: Any = None + track_policy_version: bool = False _target_: str = "torchrl.collectors.MultiaSyncDataCollector" def __post_init__(self): diff --git a/torchrl/trainers/algorithms/configs/data.py b/torchrl/trainers/algorithms/configs/data.py index daf11078303..40c90c1a808 100644 --- a/torchrl/trainers/algorithms/configs/data.py +++ b/torchrl/trainers/algorithms/configs/data.py @@ -50,6 +50,7 @@ class RandomSamplerConfig(SamplerConfig): """Configuration for random sampling from replay buffer.""" _target_: str = "torchrl.data.replay_buffers.RandomSampler" + batch_size: int | None = None def __post_init__(self) -> None: """Post-initialization hook for random sampler configurations.""" diff --git a/torchrl/trainers/algorithms/configs/objectives.py b/torchrl/trainers/algorithms/configs/objectives.py index 087091d5f26..a0d1c8fb0d3 100644 --- a/torchrl/trainers/algorithms/configs/objectives.py +++ b/torchrl/trainers/algorithms/configs/objectives.py @@ -8,7 +8,8 @@ from dataclasses import dataclass from typing import Any -from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss +from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss, SACLoss +from torchrl.objectives.sac import DiscreteSACLoss from torchrl.trainers.algorithms.configs.common import ConfigBase @@ -27,12 +28,61 @@ def __post_init__(self) -> None: @dataclass -class PPOLossConfig(LossConfig): - """A class to configure a PPO loss. +class SACLossConfig(LossConfig): + """A class to configure a SAC loss.""" - Args: - loss_type: The type of loss to use. 
- """ + actor_network: Any = None + qvalue_network: Any = None + value_network: Any = None + discrete: bool = False + num_qvalue_nets: int = 2 + loss_function: str = "smooth_l1" + alpha_init: float = 1.0 + min_alpha: float | None = None + max_alpha: float | None = None + action_spec: Any = None + fixed_alpha: bool = False + target_entropy: str | float = "auto" + delay_actor: bool = False + delay_qvalue: bool = True + delay_value: bool = True + gamma: float | None = None + priority_key: str | None = None + separate_losses: bool = False + reduction: str | None = None + skip_done_states: bool = False + deactivate_vmap: bool = False + _target_: str = "torchrl.trainers.algorithms.configs.objectives._make_sac_loss" + + def __post_init__(self) -> None: + """Post-initialization hook for SAC loss configurations.""" + super().__post_init__() + + +def _make_sac_loss(*args, **kwargs) -> SACLoss: + discrete_loss_type = kwargs.pop("discrete", False) + + # Instantiate networks if they are config objects + actor_network = kwargs.get("actor_network") + qvalue_network = kwargs.get("qvalue_network") + value_network = kwargs.get("value_network") + + if actor_network is not None and hasattr(actor_network, "_target_"): + kwargs["actor_network"] = actor_network() + if qvalue_network is not None and hasattr(qvalue_network, "_target_"): + kwargs["qvalue_network"] = qvalue_network() + if value_network is not None and hasattr(value_network, "_target_"): + kwargs["value_network"] = value_network() + + if discrete_loss_type: + return DiscreteSACLoss(*args, **kwargs) + else: + return SACLoss(*args, **kwargs) + + +@dataclass +class PPOLossConfig(LossConfig): + """A class to configure a PPO loss.""" actor_network: Any = None critic_network: Any = None @@ -73,3 +123,28 @@ def _make_ppo_loss(*args, **kwargs) -> PPOLoss: return PPOLoss(*args, **kwargs) else: raise ValueError(f"Invalid loss type: {loss_type}") + + +@dataclass +class TargetNetUpdaterConfig: + """An abstract class to configure target net updaters.""" + + loss_module: Any + _partial_: bool = True + + +@dataclass +class SoftUpdateConfig(TargetNetUpdaterConfig): + """A class for soft update instantiation.""" + + _target_: str = "torchrl.objectives.utils.SoftUpdate" + eps: float | None = None # noqa # type-ignore + tau: float | None = 0.001 # noqa # type-ignore + + +@dataclass +class HardUpdateConfig(TargetNetUpdaterConfig): + """A class for hard update instantiation.""" + + _target_: str = "torchrl.objectives.utils.HardUpdate." + value_network_update_interval: int = 1000 diff --git a/torchrl/trainers/algorithms/configs/trainers.py b/torchrl/trainers/algorithms/configs/trainers.py index fb6a21114bc..cf2f02893bc 100644 --- a/torchrl/trainers/algorithms/configs/trainers.py +++ b/torchrl/trainers/algorithms/configs/trainers.py @@ -12,8 +12,10 @@ from torchrl.collectors import DataCollectorBase from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import TargetNetUpdater from torchrl.trainers.algorithms.configs.common import ConfigBase from torchrl.trainers.algorithms.ppo import PPOTrainer +from torchrl.trainers.algorithms.sac import SACTrainer @dataclass @@ -24,6 +26,121 @@ def __post_init__(self) -> None: """Post-initialization hook for trainer configurations.""" +@dataclass +class SACTrainerConfig(TrainerConfig): + """Configuration class for SAC (Soft Actor Critic) trainer. + + This class defines the configuration parameters for creating a SAC trainer, + including both required and optional fields with sensible defaults. 
+ """ + + collector: Any + total_frames: int + optim_steps_per_batch: int | None + loss_module: Any + optimizer: Any + logger: Any + save_trainer_file: Any + replay_buffer: Any + frame_skip: int = 1 + clip_grad_norm: bool = True + clip_norm: float | None = None + progress_bar: bool = True + seed: int | None = None + save_trainer_interval: int = 10000 + log_interval: int = 10000 + create_env_fn: Any = None + actor_network: Any = None + critic_network: Any = None + target_net_updater: Any = None + + _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer" + + def __post_init__(self) -> None: + """Post-initialization hook for SAC trainer configuration.""" + super().__post_init__() + + +def _make_sac_trainer(*args, **kwargs) -> SACTrainer: + from torchrl.trainers.trainers import Logger + + collector = kwargs.pop("collector") + total_frames = kwargs.pop("total_frames") + if total_frames is None: + total_frames = collector.total_frames + frame_skip = kwargs.pop("frame_skip", 1) + optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1) + loss_module = kwargs.pop("loss_module") + optimizer = kwargs.pop("optimizer") + logger = kwargs.pop("logger") + clip_grad_norm = kwargs.pop("clip_grad_norm", True) + clip_norm = kwargs.pop("clip_norm") + progress_bar = kwargs.pop("progress_bar", True) + replay_buffer = kwargs.pop("replay_buffer") + save_trainer_interval = kwargs.pop("save_trainer_interval", 10000) + log_interval = kwargs.pop("log_interval", 10000) + save_trainer_file = kwargs.pop("save_trainer_file") + seed = kwargs.pop("seed") + actor_network = kwargs.pop("actor_network") + critic_network = kwargs.pop("critic_network") + create_env_fn = kwargs.pop("create_env_fn") + target_net_updater = kwargs.pop("target_net_updater") + + # Instantiate networks first + if actor_network is not None: + actor_network = actor_network() + if critic_network is not None: + critic_network = critic_network() + + if not isinstance(collector, DataCollectorBase): + # then it's a partial config + collector = collector(create_env_fn=create_env_fn, policy=actor_network) + if not isinstance(loss_module, LossModule): + # then it's a partial config + loss_module = loss_module( + actor_network=actor_network, critic_network=critic_network + ) + if not isinstance(target_net_updater, TargetNetUpdater): + # target_net_updater must be a partial taking the loss as input + target_net_updater = target_net_updater(loss_module) + if not isinstance(optimizer, torch.optim.Optimizer): + # then it's a partial config + optimizer = optimizer(params=loss_module.parameters()) + + # Quick instance checks + if not isinstance(collector, DataCollectorBase): + raise ValueError( + f"collector must be a DataCollectorBase, got {type(collector)}" + ) + if not isinstance(loss_module, LossModule): + raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}") + if not isinstance(optimizer, torch.optim.Optimizer): + raise ValueError( + f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}" + ) + if not isinstance(logger, Logger) and logger is not None: + raise ValueError(f"logger must be a Logger, got {type(logger)}") + + return SACTrainer( + collector=collector, + total_frames=total_frames, + frame_skip=frame_skip, + optim_steps_per_batch=optim_steps_per_batch, + loss_module=loss_module, + optimizer=optimizer, + logger=logger, + clip_grad_norm=clip_grad_norm, + clip_norm=clip_norm, + progress_bar=progress_bar, + seed=seed, + save_trainer_interval=save_trainer_interval, + log_interval=log_interval, 
+ save_trainer_file=save_trainer_file, + replay_buffer=replay_buffer, + target_net_updater=target_net_updater, + ) + + @dataclass class PPOTrainerConfig(TrainerConfig): """Configuration class for PPO (Proximal Policy Optimization) trainer. diff --git a/torchrl/trainers/algorithms/sac.py b/torchrl/trainers/algorithms/sac.py new file mode 100644 index 00000000000..caf4180925a --- /dev/null +++ b/torchrl/trainers/algorithms/sac.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import pathlib +import warnings + +from collections.abc import Callable + +from functools import partial + +from tensordict import TensorDict, TensorDictBase +from torch import optim + +from torchrl.collectors import DataCollectorBase + +from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import TargetNetUpdater +from torchrl.record.loggers import Logger +from torchrl.trainers.trainers import ( + LogScalar, + ReplayBufferTrainer, + TargetNetUpdaterHook, + Trainer, + UpdateWeights, +) + + +class SACTrainer(Trainer): + """A trainer class for Soft Actor-Critic (SAC) algorithm. + + This trainer implements the SAC algorithm, an off-policy actor-critic method that + optimizes a stochastic policy in an off-policy way, forming a bridge between + stochastic policy optimization and DDPG-style approaches. SAC incorporates the + entropy measure of the policy into the reward to encourage exploration. + + The trainer handles: + - Replay buffer management for off-policy learning + - Target network updates with configurable update frequency + - Policy weight updates to the data collector + - Comprehensive logging of training metrics + - Gradient clipping and optimization steps + + Args: + collector (DataCollectorBase): The data collector used to gather environment interactions. + total_frames (int): Total number of frames to collect during training. + frame_skip (int): Number of frames to skip between policy updates. + optim_steps_per_batch (int): Number of optimization steps per collected batch. + loss_module (LossModule | Callable): The SAC loss module or a callable that computes losses. + optimizer (optim.Optimizer, optional): The optimizer for training. If None, must be configured elsewhere. + logger (Logger, optional): Logger for recording training metrics. Defaults to None. + clip_grad_norm (bool, optional): Whether to clip gradient norms. Defaults to True. + clip_norm (float, optional): Maximum gradient norm for clipping. Defaults to None. + progress_bar (bool, optional): Whether to show a progress bar during training. Defaults to True. + seed (int, optional): Random seed for reproducibility. Defaults to None. + save_trainer_interval (int, optional): Interval for saving trainer state. Defaults to 10000. + log_interval (int, optional): Interval for logging metrics. Defaults to 10000. + save_trainer_file (str | pathlib.Path, optional): File path for saving trainer state. Defaults to None. + replay_buffer (ReplayBuffer, optional): Replay buffer for storing and sampling experiences. Defaults to None. + batch_size (int, optional): Batch size for sampling from replay buffer. Defaults to None. + enable_logging (bool, optional): Whether to enable metric logging. Defaults to True. + log_rewards (bool, optional): Whether to log reward statistics. 
Defaults to True. + log_actions (bool, optional): Whether to log action statistics. Defaults to True. + log_observations (bool, optional): Whether to log observation statistics. Defaults to False. + target_net_updater (TargetNetUpdater, optional): Target network updater for soft updates. Defaults to None. + + Example: + >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.objectives import SACLoss + >>> from torchrl.data import ReplayBuffer, LazyTensorStorage + >>> from torch import optim + >>> + >>> # Set up collector, loss, and replay buffer + >>> collector = SyncDataCollector(env, policy, frames_per_batch=1000) + >>> loss_module = SACLoss(actor_network, qvalue_network) + >>> optimizer = optim.Adam(loss_module.parameters(), lr=3e-4) + >>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000)) + >>> + >>> # Create and run trainer + >>> trainer = SACTrainer( + ... collector=collector, + ... total_frames=1000000, + ... frame_skip=1, + ... optim_steps_per_batch=100, + ... loss_module=loss_module, + ... optimizer=optimizer, + ... replay_buffer=replay_buffer, + ... ) + >>> trainer.train() + + Note: + This is an experimental/prototype feature. The API may change in future versions. + SAC is particularly effective for continuous control tasks and environments where + exploration is crucial due to its entropy regularization. + + """ + + def __init__( + self, + *, + collector: DataCollectorBase, + total_frames: int, + frame_skip: int, + optim_steps_per_batch: int, + loss_module: LossModule | Callable[[TensorDictBase], TensorDictBase], + optimizer: optim.Optimizer | None = None, + logger: Logger | None = None, + clip_grad_norm: bool = True, + clip_norm: float | None = None, + progress_bar: bool = True, + seed: int | None = None, + save_trainer_interval: int = 10000, + log_interval: int = 10000, + save_trainer_file: str | pathlib.Path | None = None, + replay_buffer: ReplayBuffer | None = None, + batch_size: int | None = None, + enable_logging: bool = True, + log_rewards: bool = True, + log_actions: bool = True, + log_observations: bool = False, + target_net_updater: TargetNetUpdater | None = None, + ) -> None: + warnings.warn( + "SACTrainer is an experimental/prototype feature. The API may change in future versions. 
" + "Please report any issues or feedback to help improve this implementation.", + UserWarning, + stacklevel=2, + ) + # try to get the action spec + self._pass_action_spec_from_collector_to_loss(collector, loss_module) + + super().__init__( + collector=collector, + total_frames=total_frames, + frame_skip=frame_skip, + optim_steps_per_batch=optim_steps_per_batch, + loss_module=loss_module, + optimizer=optimizer, + logger=logger, + clip_grad_norm=clip_grad_norm, + clip_norm=clip_norm, + progress_bar=progress_bar, + seed=seed, + save_trainer_interval=save_trainer_interval, + log_interval=log_interval, + save_trainer_file=save_trainer_file, + ) + self.replay_buffer = replay_buffer + + # Note: SAC can use any sampler type, unlike PPO which requires SamplerWithoutReplacement + + if replay_buffer is not None: + rb_trainer = ReplayBufferTrainer( + replay_buffer, + batch_size=None, + flatten_tensordicts=True, + memmap=False, + device=getattr(replay_buffer.storage, "device", "cpu"), + iterate=True, + ) + + self.register_op("pre_epoch", rb_trainer.extend) + self.register_op("process_optim_batch", rb_trainer.sample) + self.register_op("post_loss", rb_trainer.update_priority) + self.register_op("post_optim", TargetNetUpdaterHook(target_net_updater)) + + policy_weights_getter = partial( + TensorDict.from_module, self.loss_module.actor_network + ) + update_weights = UpdateWeights( + self.collector, 1, policy_weights_getter=policy_weights_getter + ) + self.register_op("post_steps", update_weights) + + # Store logging configuration + self.enable_logging = enable_logging + self.log_rewards = log_rewards + self.log_actions = log_actions + self.log_observations = log_observations + + # Set up comprehensive logging for SAC training + if self.enable_logging: + self._setup_sac_logging() + + def _pass_action_spec_from_collector_to_loss( + self, collector: DataCollectorBase, loss: LossModule + ): + """Pass the action specification from the collector's environment to the loss module. + + This method extracts the action specification from the collector's environment + and assigns it to the loss module if the loss module doesn't already have one. + This is necessary for SAC loss computation which requires knowledge of the + action space bounds for proper entropy calculation and action clipping. + + Args: + collector (DataCollectorBase): The data collector containing the environment. + loss (LossModule): The loss module that needs the action specification. + """ + if hasattr(loss, "_action_spec") and loss._action_spec is None: + action_spec = collector.getattr_env("full_action_spec_unbatched").cpu() + loss._action_spec = action_spec + + def _setup_sac_logging(self): + """Set up logging hooks for SAC-specific metrics. + + This method configures logging for common SAC metrics including: + - Training rewards (mean, max, total, and std) + - Action statistics (action norms) + - Episode completion rates (done percentage) + - Observation statistics (when enabled) + - Q-value and policy loss metrics (handled by loss module) + """ + # Always log done states as percentage (episode completion rate) + log_done_percentage = LogScalar( + key=("next", "done"), + logname="done_percentage", + log_pbar=True, + include_std=False, # No std for binary values + reduction="mean", + ) + self.register_op("pre_steps_log", log_done_percentage) + + # Log rewards if enabled + if self.log_rewards: + # 1. 
Log training rewards (most important metric for SAC) + log_rewards = LogScalar( + key=("next", "reward"), + logname="r_training", + log_pbar=True, # Show in progress bar + include_std=True, + reduction="mean", + ) + self.register_op("pre_steps_log", log_rewards) + + # 2. Log maximum reward in batch (for monitoring best performance) + log_max_reward = LogScalar( + key=("next", "reward"), + logname="r_max", + log_pbar=False, + include_std=False, + reduction="max", + ) + self.register_op("pre_steps_log", log_max_reward) + + # 3. Log total reward in batch (for monitoring cumulative performance) + log_total_reward = LogScalar( + key=("next", "reward"), + logname="r_total", + log_pbar=False, + include_std=False, + reduction="sum", + ) + self.register_op("pre_steps_log", log_total_reward) + + # Log actions if enabled + if self.log_actions: + # 4. Log action norms (useful for monitoring policy behavior) + log_action_norm = LogScalar( + key="action", + logname="action_norm", + log_pbar=False, + include_std=True, + reduction="mean", + ) + self.register_op("pre_steps_log", log_action_norm) + + # Log observations if enabled + if self.log_observations: + # 5. Log observation statistics (for monitoring state distributions) + log_obs_norm = LogScalar( + key="observation", + logname="obs_norm", + log_pbar=False, + include_std=True, + reduction="mean", + ) + self.register_op("pre_steps_log", log_obs_norm) diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index a8cca232674..4f7006b97ff 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -18,6 +18,7 @@ import numpy as np import torch.nn from tensordict import NestedKey, pad, TensorDictBase +from tensordict._tensorcollection import TensorCollection from tensordict.nn import TensorDictModule from tensordict.utils import expand_right from torch import nn, optim @@ -39,6 +40,7 @@ from torchrl.envs.common import EnvBase from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import TargetNetUpdater from torchrl.record.loggers import Logger try: @@ -1134,7 +1136,7 @@ def __call__(self, batch: TensorDictBase) -> dict: return result - def register(self, trainer: Trainer, name: str = None): + def register(self, trainer: Trainer, name: str | None = None): if name is None: name = f"log_{self.logname}" trainer.register_op("pre_steps_log", self) @@ -1719,3 +1721,29 @@ def flatten_dict(d): else: out[key] = item return out + + +class TargetNetUpdaterHook: + """A hook for target parameters update. + + Examples: + >>> # define a loss module + >>> loss_module = SACLoss(actor_network, qvalue_network) + >>> # define a target network updater + >>> target_net_updater = SoftUpdate(loss_module) + >>> # define a target network updater hook + >>> target_net_updater_hook = TargetNetUpdaterHook(target_net_updater) + >>> # register the target network updater hook + >>> trainer.register_op("post_optim", target_net_updater_hook) + """ + + def __init__(self, target_params_updater: TargetNetUpdater): + if not isinstance(target_params_updater, TargetNetUpdater): + raise ValueError( + f"Expected a target network updater, got {type(target_params_updater)=}" + ) + self.target_params_updater = target_params_updater + + def __call__(self, tensordict: TensorCollection): + self.target_params_updater.step() + return tensordict
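The new ``TargetNetUpdaterHook`` can also be wired into a hand-built trainer, outside of :class:`~torchrl.trainers.algorithms.SACTrainer`. The snippet below is a sketch only, mirroring the hook's docstring and the registration performed inside ``SACTrainer.__init__``; ``actor_network``, ``qvalue_network`` and ``trainer`` are assumed to exist already.

.. code-block:: python

    from torchrl.objectives import SACLoss
    from torchrl.objectives.utils import SoftUpdate
    from torchrl.trainers import TargetNetUpdaterHook

    # Assumed to be defined elsewhere: actor_network, qvalue_network, trainer.
    loss_module = SACLoss(actor_network, qvalue_network, delay_qvalue=True)

    # Soft updates with the same tau as the SAC config above.
    target_net_updater = SoftUpdate(loss_module, tau=0.001)

    # After every optimization step, move the target parameters toward the online ones.
    trainer.register_op("post_optim", TargetNetUpdaterHook(target_net_updater))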