122 changes: 83 additions & 39 deletions docs/source/reference/trainers.rst
@@ -183,6 +183,7 @@ Trainer and hooks
Trainer
TrainerHookBase
UpdateWeights
TargetNetUpdaterHook


Algorithm-specific trainers (Experimental)
@@ -202,37 +203,54 @@ into complete training solutions with sensible defaults and comprehensive config
:template: rl_template.rst

PPOTrainer
SACTrainer

PPOTrainer
~~~~~~~~~~
Algorithm Trainers
~~~~~~~~~~~~~~~~~~

The :class:`~torchrl.trainers.algorithms.PPOTrainer` provides a complete PPO training solution
with configurable defaults and a comprehensive configuration system built on Hydra.
TorchRL provides high-level algorithm trainers that offer complete training solutions with minimal code.
These trainers feature comprehensive configuration systems built on Hydra, enabling both simple usage
and sophisticated customization.

**Currently Available:**

- :class:`~torchrl.trainers.algorithms.PPOTrainer` - Proximal Policy Optimization
- :class:`~torchrl.trainers.algorithms.SACTrainer` - Soft Actor-Critic

**Key Features:**

- Complete training pipeline with environment setup, data collection, and optimization
- Extensive configuration system using dataclasses and Hydra
- Built-in logging for rewards, actions, and training statistics
- Modular design built on existing TorchRL components
- **Minimal code**: Complete SOTA implementation in just ~20 lines!
- **Complete pipeline**: Environment setup, data collection, and optimization
- **Hydra configuration**: Extensive dataclass-based configuration system
- **Built-in logging**: Rewards, actions, and algorithm-specific metrics
- **Modular design**: Built on existing TorchRL components
- **Minimal code**: Complete SOTA implementations in ~20 lines!

.. warning::
This is an experimental feature. The API may change in future versions.
We welcome feedback and contributions to help improve this implementation!
Algorithm trainers are experimental features. The API may change in future versions.
We welcome feedback and contributions to help improve these implementations!

**Quick Start - Command Line Interface:**
Quick Start Examples
^^^^^^^^^^^^^^^^^^^^

**PPO Training:**

.. code-block:: bash

# Basic usage - train PPO on Pendulum-v1 with default settings
# Train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py

**SAC Training:**

.. code-block:: bash

# Train SAC on a continuous control task
python sota-implementations/sac_trainer/train.py

**Custom Configuration:**

.. code-block:: bash

# Override specific parameters via command line
# Override parameters for any algorithm
python sota-implementations/ppo_trainer/train.py \
trainer.total_frames=2000000 \
training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
@@ -243,32 +261,34 @@ with configurable defaults and a comprehensive configuration system built on Hydra

.. code-block:: bash

# Switch to a different environment and logger
python sota-implementations/ppo_trainer/train.py \
env=gym \
# Switch environment and logger for any trainer
python sota-implementations/sac_trainer/train.py \
training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
logger=tensorboard
logger=tensorboard \
logger.exp_name=sac_walker2d

**See All Options:**
**View Configuration Options:**

.. code-block:: bash

# View all available configuration options
# See all available options for any trainer
python sota-implementations/ppo_trainer/train.py --help
python sota-implementations/sac_trainer/train.py --help

**Configuration Groups:**
Universal Configuration System
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The PPOTrainer configuration is organized into logical groups:
All algorithm trainers share a unified configuration architecture organized into logical groups, listed below and illustrated by the example after the list:

- **Environment**: ``env_cfg__env_name``, ``env_cfg__backend``, ``env_cfg__device``
- **Networks**: ``actor_network__network__num_cells``, ``critic_network__module__num_cells``
- **Training**: ``total_frames``, ``clip_norm``, ``num_epochs``, ``optimizer_cfg__lr``
- **Logging**: ``log_rewards``, ``log_actions``, ``log_observations``
- **Environment**: ``training_env.create_env_fn.base_env.env_name``, ``training_env.num_workers``
- **Networks**: ``networks.policy_network.num_cells``, ``networks.value_network.num_cells``
- **Training**: ``trainer.total_frames``, ``trainer.clip_norm``, ``optimizer.lr``
- **Data**: ``collector.frames_per_batch``, ``replay_buffer.batch_size``, ``replay_buffer.storage.max_size``
- **Logging**: ``logger.exp_name``, ``logger.project``, ``trainer.log_interval``
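
These dotted paths can be overridden on the command line, as in the examples above, or composed
programmatically through Hydra's ``compose`` API. The sketch below is illustrative only: the relative
``config_path`` and the working directory are assumptions, and the override keys are the ones listed
above.

.. code-block:: python

    from hydra import compose, initialize
    from hydra.utils import instantiate

    # The star import registers the trainer config groups with Hydra's ConfigStore,
    # as done in the training scripts.
    from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403

    # Assumed path to the PPO trainer configs, relative to this file.
    with initialize(version_base="1.1", config_path="sota-implementations/ppo_trainer/config"):
        cfg = compose(
            config_name="config",
            overrides=[
                "trainer.total_frames=1000000",
                "optimizer.lr=0.0003",
                "training_env.num_workers=2",
            ],
        )

    # Instantiate the trainer (and every nested component) and launch training.
    trainer = instantiate(cfg.trainer)
    trainer.train()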

**Working Example:**

The `sota-implementations/ppo_trainer/ <https://github.com/pytorch/rl/tree/main/sota-implementations/ppo_trainer>`_
directory contains a complete, working PPO implementation that demonstrates the simplicity and power of the trainer system:
All trainer implementations follow the same simple pattern:

.. code-block:: python

@@ -283,33 +303,57 @@ directory contains a complete, working PPO implementation that demonstrates the
    if __name__ == "__main__":
main()

*Complete PPO training with full configurability in ~20 lines!*
*Complete algorithm training with full configurability in ~20 lines!*
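
The collapsed block above follows the same structure as the SAC script added in this PR
(``sota-implementations/sac_trainer/train.py``). Spelled out, the pattern with a small custom
logging hook registered on the ``batch_process`` stage looks like this:

.. code-block:: python

    import hydra
    import torchrl
    from torchrl.trainers.algorithms.configs import *  # noqa: F401, F403


    @hydra.main(config_path="config", config_name="config", version_base="1.1")
    def main(cfg):
        # Hook run on each processed batch: log the mean collected reward.
        def print_reward(td):
            torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

        # Build the trainer from the Hydra config, attach the hook, and train.
        trainer = hydra.utils.instantiate(cfg.trainer)
        trainer.register_op(dest="batch_process", op=print_reward)
        trainer.train()


    if __name__ == "__main__":
        main()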

**Configuration Classes:**
Configuration Classes
^^^^^^^^^^^^^^^^^^^^^

The PPOTrainer uses a hierarchical configuration system with these main config classes.
The trainer system uses a hierarchical configuration architecture with shared components; a short inspection example follows the class lists below.

.. note::
The configuration system requires Python 3.10+ due to its use of modern type annotation syntax.

- **Trainer**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
**Algorithm-Specific Trainers:**

- **PPO**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
- **SAC**: :class:`~torchrl.trainers.algorithms.configs.trainers.SACTrainerConfig`

**Shared Configuration Components:**

- **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`, :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig`
- **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`, :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig`
- **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`, :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig`
- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`
- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`, :class:`~torchrl.trainers.algorithms.configs.objectives.SACLossConfig`
- **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`, :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig`
- **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`, :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig`
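
Because these configs are dataclass-based structured configs, they can also be inspected directly.
A minimal sketch, assuming the dataclasses are OmegaConf-compatible (as structured configs
registered with Hydra are):

.. code-block:: python

    from omegaconf import OmegaConf

    from torchrl.trainers.algorithms.configs.trainers import SACTrainerConfig

    # Print the declared fields and defaults of the SAC trainer config as YAML.
    print(OmegaConf.to_yaml(OmegaConf.structured(SACTrainerConfig)))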

Algorithm-Specific Features
^^^^^^^^^^^^^^^^^^^^^^^^^^^

**PPOTrainer:**

- On-policy learning with advantage estimation
- Policy clipping and value function optimization
- Configurable number of epochs per batch
- Built-in GAE (Generalized Advantage Estimation); see the formula below
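
For reference, GAE combines one-step TD errors into an exponentially weighted advantage estimate
(standard formulation):

.. math::

    \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
    \hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}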

**SACTrainer:**

- Off-policy learning with replay buffer
- Entropy-regularized policy optimization
- Target network soft updates (see the update rule below)
- Continuous action space optimization
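
The soft target update follows the usual Polyak averaging rule, with :math:`\tau` taken from the
``target_net_updater`` config (e.g. ``tau: 0.001`` in the SAC example configuration):

.. math::

    \theta_{\text{target}} \leftarrow \tau \, \theta + (1 - \tau) \, \theta_{\text{target}}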

**Future Development:**

This is the first of many planned algorithm-specific trainers. Future releases will include:
The trainer system is actively expanding. Upcoming features include:

- Additional algorithms: SAC, TD3, DQN, A2C, and more
- Full integration of all TorchRL components within the configuration system
- Enhanced configuration validation and error reporting
- Distributed training support for high-level trainers
- Additional algorithms: TD3, DQN, A2C, DDPG, and more
- Enhanced distributed training support
- Advanced configuration validation and error reporting
- Integration with more TorchRL ecosystem components

See the complete `configuration system documentation <https://github.com/pytorch/rl/tree/main/torchrl/trainers/algorithms/configs>`_ for all available options.
See the complete `configuration system documentation <https://github.com/pytorch/rl/tree/main/torchrl/trainers/algorithms/configs>`_ for all available options and examples.


Builders
146 changes: 146 additions & 0 deletions sota-implementations/sac_trainer/config/config.yaml
@@ -0,0 +1,146 @@
# SAC Trainer Configuration for HalfCheetah-v4
# This configuration uses the new configurable trainer system and matches the SOTA SAC implementation

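# Hydra defaults list: each "group@package: option" entry selects a config from a config group
# and places it at the given package path in the composed configuration.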
defaults:

- transform@transform0: step_counter
- transform@transform1: double_to_float

- env@training_env: batched_env
- env@training_env.create_env_fn: transformed_env
- env@training_env.create_env_fn.base_env: gym
- transform@training_env.create_env_fn.transform: compose

- model@models.policy_model: tanh_normal
- model@models.value_model: value
- model@models.qvalue_model: value

- network@networks.policy_network: mlp
- network@networks.value_network: mlp
- network@networks.qvalue_network: mlp

- collector@collector: multi_async

- replay_buffer@replay_buffer: base
- storage@replay_buffer.storage: lazy_tensor
- writer@replay_buffer.writer: round_robin
- sampler@replay_buffer.sampler: random
- trainer@trainer: sac
- optimizer@optimizer: adam
- loss@loss: sac
- target_net_updater@target_net_updater: soft
- logger@logger: wandb
- _self_

# Network configurations
networks:
  policy_network:
    out_features: 12  # 2 x 6: loc and scale for the 6-dimensional HalfCheetah action space
    in_features: 17  # HalfCheetah observation space is 17-dimensional
    num_cells: [256, 256]

  value_network:
    out_features: 1  # Value output
    in_features: 17  # HalfCheetah observation space
    num_cells: [256, 256]

  qvalue_network:
    out_features: 1  # Q-value output
    in_features: 23  # HalfCheetah observation space (17) + action space (6)
    num_cells: [256, 256]

# Model configurations
models:
  policy_model:
    return_log_prob: true
    in_keys: ["observation"]
    param_keys: ["loc", "scale"]
    out_keys: ["action"]
    network: ${networks.policy_network}

  qvalue_model:
    in_keys: ["observation", "action"]
    out_keys: ["state_action_value"]
    network: ${networks.qvalue_network}

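# Environment transform configurations (composed into training_env below)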
transform0:
  max_steps: 1000
  step_count_key: "step_count"

transform1:
  # DoubleToFloatTransform - converts double precision to float to fix dtype mismatch
  in_keys: null
  out_keys: null

training_env:
  num_workers: 4
  create_env_fn:
    base_env:
      env_name: HalfCheetah-v4
    transform:
      transforms:
        - ${transform0}
        - ${transform1}
    _partial_: true  # instantiate as a partial so each collector worker can build its own env

# Loss configuration
loss:
  actor_network: ${models.policy_model}
  qvalue_network: ${models.qvalue_model}
  target_entropy: "auto"
  loss_function: l2
  alpha_init: 1.0
  delay_qvalue: true
  num_qvalue_nets: 2

target_net_updater:
  tau: 0.001

# Optimizer configuration
optimizer:
  lr: 3.0e-4

# Collector configuration
collector:
  create_env_fn: ${training_env}
  policy: ${models.policy_model}
  total_frames: 1_000_000
  frames_per_batch: 1000
  num_workers: 4
  init_random_frames: 25000
  track_policy_version: true

# Replay buffer configuration
replay_buffer:
  storage:
    max_size: 1_000_000
    device: cpu
    ndim: 1
  # sampler and writer use the options selected in the defaults list above
  # (random sampler, round-robin writer)
  sampler:
  writer:
    compilable: false
  batch_size: 256

logger:
  exp_name: sac_halfcheetah_v4
  offline: false
  project: torchrl-sota-implementations

# Trainer configuration
trainer:
  collector: ${collector}
  optimizer: ${optimizer}
  replay_buffer: ${replay_buffer}
  target_net_updater: ${target_net_updater}
  loss_module: ${loss}
  logger: ${logger}
  total_frames: 1_000_000
  frame_skip: 1
  clip_grad_norm: false  # SAC typically doesn't use gradient clipping
  clip_norm: null
  progress_bar: true
  seed: 42
  save_trainer_interval: 25000  # Match SOTA eval_iter
  log_interval: 25000
  save_trainer_file: null
  optim_steps_per_batch: 64  # Match SOTA utd_ratio
21 changes: 21 additions & 0 deletions sota-implementations/sac_trainer/train.py
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import hydra
import torchrl
from torchrl.trainers.algorithms.configs import * # noqa: F401, F403


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    # Hook run on each processed batch: log the mean reward of the collected data.
    def print_reward(td):
        torchrl.logger.info(f"reward: {td['next', 'reward'].mean(): 4.4f}")

    # Build the trainer (and every nested component) from the Hydra config,
    # register the logging hook, and start training.
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.register_op(dest="batch_process", op=print_reward)
    trainer.train()


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions torchrl/trainers/__init__.py
@@ -16,6 +16,7 @@
    ReplayBufferTrainer,
    RewardNormalizer,
    SelectKeys,
    TargetNetUpdaterHook,
    Trainer,
    TrainerHookBase,
    UpdateWeights,
@@ -37,4 +38,5 @@
    "Trainer",
    "TrainerHookBase",
    "UpdateWeights",
    "TargetNetUpdaterHook",
]
3 changes: 2 additions & 1 deletion torchrl/trainers/algorithms/__init__.py
@@ -6,5 +6,6 @@
from __future__ import annotations

from .ppo import PPOTrainer
from .sac import SACTrainer

__all__ = ["PPOTrainer"]
__all__ = ["PPOTrainer", "SACTrainer"]