297 changes: 257 additions & 40 deletions docs/source/reference/collectors.rst
@@ -117,14 +117,19 @@ try to limit the cases where a deepcopy will be executed. The following chart shows the decision tree:

Policy copy decision tree in Collectors.

Weight Synchronization
----------------------

In reinforcement learning, the pipeline is typically split into two computational phases:
**inference** (data collection) and **training** (policy optimization). While the inference pipeline
sends data to the training one, the training pipeline needs to periodically synchronize its weights
with the inference workers to ensure they collect data using up-to-date policies.

Overview & Motivation
~~~~~~~~~~~~~~~~~~~~~

In the simplest setting, the same policy weights are used in both training and inference. However,
real-world RL systems often face additional complexity:

- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
`DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
@@ -140,15 +145,222 @@
asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach
is to store the weights on some intermediary server and let the workers fetch them when necessary.

Key Challenges
^^^^^^^^^^^^^^

Modern RL training often involves multiple models that need independent synchronization:

1. **Multiple Models Per Collector**: A collector might need to update:

- The main policy network
- A value network in a Ray actor within the replay buffer
- Models embedded in the environment itself
- Separate world models or auxiliary networks

2. **Different Update Strategies**: Each model may require different synchronization approaches:

- Full state_dict transfer vs. TensorDict-based updates
- Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
- Varied update frequencies

3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to
specific worker indices, requiring a more flexible update mechanism.

The Solution
^^^^^^^^^^^^

TorchRL addresses these challenges through a flexible, modular architecture built around four components:

- **WeightSyncScheme**: Defines *what* to synchronize and *how* (user-facing configuration)
- **WeightSender**: Handles distributing weights from the main process to workers (internal)
- **WeightReceiver**: Handles applying weights in worker processes (internal)
- **TransportBackend**: Manages the actual communication layer (internal)

This design allows you to independently configure synchronization for multiple models,
choose appropriate transport mechanisms, and swap strategies without rewriting your training code.
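
As a first orientation, here is a minimal configuration sketch. The ``weight_sync_schemes`` argument and
the ``"policy"`` model ID follow the description above, while the exact constructor signatures (notably the
``strategy`` keyword) are assumptions to check against the API reference below.

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import MultiProcessWeightSyncScheme

    # A toy policy; any nn.Module-based policy is handled the same way.
    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )

    collector = MultiSyncDataCollector(
        create_env_fn=[lambda: GymEnv("Pendulum-v1")] * 2,
        policy=policy,
        frames_per_batch=64,
        total_frames=1024,
        # one scheme per model ID; swapping the scheme class changes the
        # transport without touching the rest of the pipeline
        weight_sync_schemes={"policy": MultiProcessWeightSyncScheme(strategy="state_dict")},
    )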

Architecture & Concepts
~~~~~~~~~~~~~~~~~~~~~~~

Component Roles
^^^^^^^^^^^^^^^

The weight synchronization system separates concerns into four distinct layers:

1. **WeightSyncScheme** (User-Facing)

This is your main configuration interface. You create scheme objects that define:

- The synchronization strategy (``"state_dict"`` or ``"tensordict"``)
- The transport mechanism (multiprocessing pipes, shared memory, Ray, RPC, etc.)
- Additional options like auto-registration and timeout behavior

When working with collectors, you pass a dictionary mapping model IDs to schemes.

2. **WeightSender** (Internal)

Created by the scheme in the main training process. The sender:

- Holds a reference to the source model
- Manages transport connections to all workers
- Extracts weights using the configured strategy
- Broadcasts weight updates across all transports

3. **WeightReceiver** (Internal)

Created by the scheme in each worker process. The receiver:

- Holds a reference to the destination model
- Polls its transport for weight updates
- Applies received weights using the configured strategy
- Handles model registration and initialization

4. **TransportBackend** (Internal)

Implements the actual communication mechanism:

- ``MPTransport``: Uses multiprocessing pipes
- ``SharedMemTransport``: Uses shared memory buffers (zero-copy)
- ``RayTransport``: Uses Ray's object store
- ``RPCTransport``: Uses PyTorch RPC
- ``DistributedTransport``: Uses collective communication (NCCL, Gloo, MPI)
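
The scheme classes listed in the API reference below select these transports. A short sketch, assuming
each scheme accepts the ``strategy`` keyword described in point 1:

.. code-block:: python

    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,  # MPTransport: multiprocessing pipes
        SharedMemWeightSyncScheme,     # SharedMemTransport: zero-copy shared buffers
        RayWeightSyncScheme,           # RayTransport: Ray object store
    )

    # Choosing a scheme selects the transport; the strategy selects the
    # weight format that travels over it.
    pipe_scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    shmem_scheme = SharedMemWeightSyncScheme(strategy="tensordict")
    ray_scheme = RayWeightSyncScheme(strategy="tensordict")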

Initialization Phase
^^^^^^^^^^^^^^^^^^^^

When you create a collector with weight sync schemes, the following initialization occurs:

.. aafig::
:aspect: 60
:scale: 130
:proportional:

INITIALIZATION PHASE
====================

WeightSyncScheme
+------------------+
| |
| Configuration: |
| - strategy |
| - transport_type |
| |
+--------+---------+
|
+------------+-------------+
| |
creates creates
| |
v v
Main Process Worker Process
+--------------+ +---------------+
| WeightSender | | WeightReceiver|
| | | |
| - strategy | | - strategy |
| - transports | | - transport |
| - model_ref | | - model_ref |
| | | |
| Registers: | | Registers: |
| - model | | - model |
| - workers | | - transport |
+--------------+ +---------------+
| |
| Transport Layer |
| +----------------+ |
+-->+ MPTransport |<------+
| | (pipes) | |
| +----------------+ |
| +----------------+ |
+-->+ SharedMemTrans |<------+
| | (shared mem) | |
| +----------------+ |
| +----------------+ |
+-->+ RayTransport |<------+
| (Ray store) |
+----------------+

The scheme creates a sender in the main process and a receiver in each worker, then establishes
transport connections between them.

Synchronization Phase
^^^^^^^^^^^^^^^^^^^^^

When you call ``collector.update_policy_weights_()``, the weight synchronization proceeds as follows:

.. aafig::
:aspect: 60
:scale: 130
:proportional:

SYNCHRONIZATION PHASE
=====================

Main Process Worker Process

+-------------------+ +-------------------+
| WeightSender | | WeightReceiver |
| | | |
| 1. Extract | | 4. Poll transport |
| weights from | | for weights |
| model using | | |
| strategy | | |
| | 2. Send via | |
| +-------------+ | Transport | +--------------+ |
| | Strategy | | +------------+ | | Strategy | |
| | extract() | | | | | | apply() | |
| +-------------+ +----+ Transport +-------->+ +--------------+ |
| | | | | | | |
| v | +------------+ | v |
| +-------------+ | | +--------------+ |
| | Model | | | | Model | |
| | (source) | | 3. Ack (optional) | | (dest) | |
| +-------------+ | <-----------------------+ | +--------------+ |
| | | |
+-------------------+ | 5. Apply weights |
| to model using |
| strategy |
+-------------------+

1. **Extract**: Sender extracts weights from the source model (state_dict or TensorDict)
2. **Send**: Sender broadcasts weights through all registered transports
3. **Acknowledge** (optional): Some transports send acknowledgment back
4. **Poll**: Receiver checks its transport for new weights
5. **Apply**: Receiver applies weights to the destination model
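
From the user's perspective, all five steps are triggered by a single call. Below is a minimal
training-loop sketch; it reuses the collector and policy from the configuration sketch earlier and
substitutes a placeholder objective for a real RL loss.

.. code-block:: python

    import torch

    # ``collector`` and ``policy`` are assumed to be built as in the
    # configuration sketch above.
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

    for i, batch in enumerate(collector):
        # placeholder objective for illustration only; use a proper RL loss
        loss = policy(batch)["action"].pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # steps 1-5 above happen inside this call, for every registered model
        if i % 10 == 0:
            collector.update_policy_weights_()

    collector.shutdown()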

Multi-Model Synchronization
^^^^^^^^^^^^^^^^^^^^^^^^^^^

One of the key features is support for synchronizing multiple models independently:

.. aafig::
:aspect: 60
:scale: 130
:proportional:

Main Process Worker Process 1 Worker Process 2

+-----------------+ +---------------+ +---------------+
| Collector | | Collector | | Collector |
| | | | | |
| Models: | | Models: | | Models: |
| +----------+ | | +--------+ | | +--------+ |
| | Policy A | | | |Policy A| | | |Policy A| |
| +----------+ | | +--------+ | | +--------+ |
| +----------+ | | +--------+ | | +--------+ |
| | Model B | | | |Model B| | | |Model B| |
| +----------+ | | +--------+ | | +--------+ |
| | | | | |
| Weight Senders: | | Weight | | Weight |
| +----------+ | | Receivers: | | Receivers: |
| | Sender A +---+------------+->Receiver A | | Receiver A |
| +----------+ | | | | |
| +----------+ | | +--------+ | | +--------+ |
| | Sender B +---+------------+->Receiver B | | Receiver B |
| +----------+ | Pipes | | Pipes | |
+-----------------+ +-------+-------+ +-------+-------+

Each model gets its own sender/receiver pair, allowing independent synchronization frequencies,
different transport mechanisms per model, and model-specific strategies.
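
A configuration sketch matching the figure above, assuming the collector accepts one scheme per model ID
and that the second model is exposed under a hypothetical ``"model_b"`` identifier:

.. code-block:: python

    from torchrl.weight_update import (
        MultiProcessWeightSyncScheme,
        SharedMemWeightSyncScheme,
    )

    weight_sync_schemes = {
        # Policy A: shared-memory transport with TensorDict-formatted weights
        "policy": SharedMemWeightSyncScheme(strategy="tensordict"),
        # Model B: pipe-based transport with state_dict-formatted weights
        "model_b": MultiProcessWeightSyncScheme(strategy="state_dict"),
    }

    # Passed as ``weight_sync_schemes=weight_sync_schemes`` when building the
    # collector, each entry gets its own sender/receiver pair and can be
    # refreshed on its own schedule.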

Usage Examples
~~~~~~~~~~~~~~
@@ -301,32 +513,55 @@ across multiple inference workers:
dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured
models and supports advanced features like lazy initialization.
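
For reference, the two weight formats can be produced directly with plain PyTorch and the ``tensordict``
package; the strategy only determines which representation travels over the transport:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

    # "state_dict" strategy: a flat mapping from parameter names to tensors
    state_dict_weights = model.state_dict()

    # "tensordict" strategy: a nested structure mirroring the module tree
    tensordict_weights = TensorDict.from_module(model)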

API Reference
~~~~~~~~~~~~~

The weight synchronization system provides both user-facing configuration classes and internal
implementation classes that are automatically managed by the collectors.

Schemes (User-Facing Configuration)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

These are the main classes you'll use to configure weight synchronization. Pass them in the
``weight_sync_schemes`` dictionary when creating collectors.

.. currentmodule:: torchrl.weight_update

.. autosummary::
:toctree: generated/
:template: rl_template.rst

WeightSyncScheme
MultiProcessWeightSyncScheme
SharedMemWeightSyncScheme
NoWeightSyncScheme
RayWeightSyncScheme
RayModuleTransformScheme
RPCWeightSyncScheme
DistributedWeightSyncScheme

Senders and Receivers (Internal)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

These classes are automatically created and managed by the schemes. You typically don't need
to interact with them directly.

.. currentmodule:: torchrl.weight_update

.. autosummary::
:toctree: generated/
:template: rl_template.rst

WeightSender
WeightReceiver
RayModuleTransformSender
RayModuleTransformReceiver

Transport Backends (Internal)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Transport classes handle the actual communication between processes. They are automatically
selected and configured by the schemes.

.. currentmodule:: torchrl.weight_update

@@ -342,24 +577,6 @@
RPCTransport
DistributedTransport


Legacy: Weight Synchronization in Distributed Environments
----------------------------------------------------------
