From e9756e224e8be18c966aa4e4ff604cb73bd20e55 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 14 Oct 2025 09:41:05 +0100 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- docs/source/reference/collectors.rst | 242 +++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index ca8dca38e4e..b598df3e173 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -169,6 +169,248 @@ transformed, and applied, ensuring seamless integration with their existing infr RPCWeightUpdater DistributedWeightUpdater +Weight Synchronization API +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The weight synchronization API provides a simple, modular approach to updating model weights across +distributed collectors. This system is designed to handle the complexities of modern RL setups where multiple +models may need to be synchronized independently. + +Overview +^^^^^^^^ + +In reinforcement learning, particularly with multi-process data collection, it's essential to keep the inference +policies synchronized with the latest trained weights. The API addresses this challenge through a clean +separation of concerns, where four classes are involved: + +- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For DataCollectors, this is + your main entrypoint to configure the weight synchronization. +- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to workers. +- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in worker processes. +- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC) + +The following diagram shows the different classes involved in the weight synchronization process: + +.. aafig:: + :aspect: 60 + :scale: 130 + :proportional: + + INITIALIZATION PHASE + ==================== + + WeightSyncScheme + +------------------+ + | | + | Configuration: | + | - strategy | + | - transport_type | + | | + +--------+---------+ + | + +------------+-------------+ + | | + creates creates + | | + v v + Main Process Worker Process + +--------------+ +---------------+ + | WeightSender | | WeightReceiver| + | | | | + | - strategy | | - strategy | + | - transports | | - transport | + | - model_ref | | - model_ref | + | | | | + | Registers: | | Registers: | + | - model | | - model | + | - workers | | - transport | + +--------------+ +---------------+ + | | + | Transport Layer | + | +----------------+ | + +-->+ MPTransport |<------+ + | | (pipes) | | + | +----------------+ | + | +----------------+ | + +-->+ SharedMemTrans |<------+ + | | (shared mem) | | + | +----------------+ | + | +----------------+ | + +-->+ RayTransport |<------+ + | (Ray store) | + +----------------+ + + + SYNCHRONIZATION PHASE + ===================== + + Main Process Worker Process + + +-------------------+ +-------------------+ + | WeightSender | | WeightReceiver | + | | | | + | 1. Extract | | 4. Poll transport | + | weights from | | for weights | + | model using | | | + | strategy | | | + | | 2. 
Send via | | + | +-------------+ | Transport | +--------------+ | + | | Strategy | | +------------+ | | Strategy | | + | | extract() | | | | | | apply() | | + | +-------------+ +----+ Transport +-------->+ +--------------+ | + | | | | | | | | + | v | +------------+ | v | + | +-------------+ | | +--------------+ | + | | Model | | | | Model | | + | | (source) | | 3. Ack (optional) | | (dest) | | + | +-------------+ | <-----------------------+ | +--------------+ | + | | | | + +-------------------+ | 5. Apply weights | + | to model using | + | strategy | + +-------------------+ + +Key Challenges Addressed +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Modern RL training often involves multiple models that need independent synchronization: + +1. **Multiple Models Per Collector**: A collector might need to update: + + - The main policy network + - A value network in a Ray actor within the replay buffer + - Models embedded in the environment itself + - Separate world models or auxiliary networks + +2. **Different Update Strategies**: Each model may require different synchronization approaches: + + - Full state_dict transfer vs. TensorDict-based updates + - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.) + - Varied update frequencies + +3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to + specific worker indices, requiring a more flexible update mechanism. + +Architecture +^^^^^^^^^^^^ + +The API follows a scheme-based design where users specify synchronization requirements upfront, +and the collector handles the orchestration transparently: + +.. aafig:: + :aspect: 60 + :scale: 130 + :proportional: + + Main Process Worker Process 1 Worker Process 2 + + +-----------------+ +---------------+ +---------------+ + | Collector | | Collector | | Collector | + | | | | | | + | Models: | | Models: | | Models: | + | +----------+ | | +--------+ | | +--------+ | + | | Policy A | | | |Policy A| | | |Policy A| | + | +----------+ | | +--------+ | | +--------+ | + | +----------+ | | +--------+ | | +--------+ | + | | Model B | | | |Model B| | | |Model B| | + | +----------+ | | +--------+ | | +--------+ | + | | | | | | + | Weight Senders: | | Weight | | Weight | + | +----------+ | | Receivers: | | Receivers: | + | | Sender A +---+------------+->Receiver A | | Receiver A | + | +----------+ | | | | | + | +----------+ | | +--------+ | | +--------+ | + | | Sender B +---+------------+->Receiver B | | Receiver B | + | +----------+ | Pipes | | Pipes | | + +-----------------+ +-------+-------+ +-------+-------+ + ^ ^ ^ + | | | + | update_policy_weights_() | Apply weights | + | | | + +------+-------+ | | + | User Code | | | + | (Training) | | | + +--------------+ +------------------------+ + +The weight synchronization flow: + +1. **Initialization**: User creates ``weight_sync_schemes`` dict mapping model IDs to schemes +2. **Registration**: Collector creates ``WeightSender`` for each model in the main process +3. **Worker Setup**: Each worker creates corresponding ``WeightReceiver`` instances +4. **Synchronization**: Calling ``update_policy_weights_()`` triggers all senders to push weights +5. 
**Application**: Receivers automatically apply weights to their registered models + +Available Classes +^^^^^^^^^^^^^^^^^ + +**Synchronization Schemes** (User-Facing Configuration): + +- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme`: Base class for schemes +- :class:`~torchrl.weight_update.weight_sync_schemes.MultiProcessWeightSyncScheme`: For multiprocessing with pipes +- :class:`~torchrl.weight_update.weight_sync_schemes.SharedMemWeightSyncScheme`: For shared memory synchronization +- :class:`~torchrl.weight_update.weight_sync_schemes.RayWeightSyncScheme`: For Ray-based distribution +- :class:`~torchrl.weight_update.weight_sync_schemes.NoWeightSyncScheme`: Dummy scheme for no synchronization + +**Internal Classes** (Automatically Managed): + +- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender`: Sends weights to all workers for one model +- :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver`: Receives and applies weights in worker +- :class:`~torchrl.weight_update.weight_sync_schemes.TransportBackend`: Communication layer abstraction + +Usage Example +^^^^^^^^^^^^^ + +.. code-block:: python + + from torchrl.collectors import MultiSyncDataCollector + from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme + + # Define synchronization for multiple models + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), + "value_net": MultiProcessWeightSyncScheme(strategy="state_dict"), + } + + collector = MultiSyncDataCollector( + create_env_fn=[make_env] * 4, + policy=policy, + frames_per_batch=1000, + weight_sync_schemes=weight_sync_schemes, # Pass schemes dict + ) + + # Single call updates all registered models across all workers + for i, batch in enumerate(collector): + # Training step + loss = train(batch) + + # Sync all models with one call + collector.update_policy_weights_(policy) + +The collector automatically: + +- Creates ``WeightSender`` instances in the main process for each model +- Creates ``WeightReceiver`` instances in each worker process +- Resolves models by ID (e.g., ``"policy"`` → ``collector.policy``) +- Handles transport setup and communication +- Applies weights using the appropriate strategy (state_dict vs tensordict) + +API Reference +^^^^^^^^^^^^^ + +.. currentmodule:: torchrl.weight_update.weight_sync_schemes + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightSyncScheme + MultiProcessWeightSyncScheme + SharedMemWeightSyncScheme + RayWeightSyncScheme + NoWeightSyncScheme + WeightSender + WeightReceiver + Collectors and replay buffers interoperability ---------------------------------------------- From ba7d03f57b8a65cf7f3a938648c0697dbbc52a36 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 18 Oct 2025 16:15:22 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- docs/source/reference/collectors.rst | 539 +++++++++--------- examples/collectors/weight_sync_collectors.py | 68 +-- examples/collectors/weight_sync_standalone.py | 108 ++-- torchrl/weight_update/weight_sync_schemes.py | 4 +- 4 files changed, 350 insertions(+), 369 deletions(-) diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index a35e57e7362..a54f1aa3592 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -117,14 +117,19 @@ try to limit the cases where a deepcopy will be executed. The following chart sh Policy copy decision tree in Collectors. 
-Weight Synchronization using Weight Update Schemes
---------------------------------------------------
+Weight Synchronization
+----------------------
 
-RL pipelines are typically split in two big computational buckets: training, and inference.
-While the inference pipeline sends data to the training one, the training pipeline needs to occasionally
-synchronize its weights with the inference one.
-In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
-used in both instances. From there, anything can happen:
+In reinforcement learning, the overall pipeline is typically split into two computational phases:
+**inference** (data collection) and **training** (policy optimization). While the inference pipeline
+sends data to the training one, the training pipeline needs to periodically synchronize its weights
+with the inference workers to ensure they collect data using up-to-date policies.
+
+Overview & Motivation
+~~~~~~~~~~~~~~~~~~~~~
+
+In the simplest setting, the same policy weights are used in both training and inference. However,
+real-world RL systems often face additional complexity:
 
 - In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
   `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
@@ -140,15 +145,222 @@ used in both instances. From there, anything can happen:
   asks for new weights, or must it only be the trainer who pushes its weights to the workers?
   An intermediate approach is to store the weights on some intermediary server and let the workers fetch them when
   necessary.
 
-TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight
-transfer:
+Key Challenges
+^^^^^^^^^^^^^^
+
+Modern RL training often involves multiple models that need independent synchronization:
+
+1. **Multiple Models Per Collector**: A collector might need to update:
+
+   - The main policy network
+   - A value network in a Ray actor within the replay buffer
+   - Models embedded in the environment itself
+   - Separate world models or auxiliary networks
+
+2. **Different Update Strategies**: Each model may require different synchronization approaches:
+
+   - Full state_dict transfer vs. TensorDict-based updates
+   - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
+   - Varied update frequencies
+
+3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to
+   specific worker indices, requiring a more flexible update mechanism.
+
+The Solution
+^^^^^^^^^^^^
+
+TorchRL addresses these challenges through a flexible, modular architecture built around four components:
+
+- **WeightSyncScheme**: Defines *what* to synchronize and *how* (user-facing configuration)
+- **WeightSender**: Handles distributing weights from the main process to workers (internal)
+- **WeightReceiver**: Handles applying weights in worker processes (internal)
+- **TransportBackend**: Manages the actual communication layer (internal)
+
+This design allows you to independently configure synchronization for multiple models,
+choose appropriate transport mechanisms, and swap strategies without rewriting your training code.
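+
+As a quick preview, here is a minimal sketch of this configuration surface (assuming a ``make_env``
+factory and a ``policy`` module are defined, as in the usage examples below):
+
+.. code-block:: python
+
+    from torchrl.collectors import MultiSyncDataCollector
+    from torchrl.weight_update import MultiProcessWeightSyncScheme
+
+    # One scheme per model ID; the collector builds the matching
+    # sender/receiver pairs and transports behind the scenes.
+    schemes = {"policy": MultiProcessWeightSyncScheme(strategy="state_dict")}
+
+    collector = MultiSyncDataCollector(
+        create_env_fn=[make_env] * 2,
+        policy=policy,
+        frames_per_batch=200,
+        total_frames=400,
+        weight_sync_schemes=schemes,
+    )
+
+    # After a training step, push the trainer's weights to every worker.
+    collector.update_policy_weights_(policy.state_dict())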
+ +Architecture & Concepts +~~~~~~~~~~~~~~~~~~~~~~~ + +Component Roles +^^^^^^^^^^^^^^^ + +The weight synchronization system separates concerns into four distinct layers: + +1. **WeightSyncScheme** (User-Facing) + + This is your main configuration interface. You create scheme objects that define: + + - The synchronization strategy (``"state_dict"`` or ``"tensordict"``) + - The transport mechanism (multiprocessing pipes, shared memory, Ray, RPC, etc.) + - Additional options like auto-registration and timeout behavior + + When working with collectors, you pass a dictionary mapping model IDs to schemes. + +2. **WeightSender** (Internal) + + Created by the scheme in the main training process. The sender: + + - Holds a reference to the source model + - Manages transport connections to all workers + - Extracts weights using the configured strategy + - Broadcasts weight updates across all transports + +3. **WeightReceiver** (Internal) + + Created by the scheme in each worker process. The receiver: + + - Holds a reference to the destination model + - Polls its transport for weight updates + - Applies received weights using the configured strategy + - Handles model registration and initialization + +4. **TransportBackend** (Internal) + + Implements the actual communication mechanism: + + - ``MPTransport``: Uses multiprocessing pipes + - ``SharedMemTransport``: Uses shared memory buffers (zero-copy) + - ``RayTransport``: Uses Ray's object store + - ``RPCTransport``: Uses PyTorch RPC + - ``DistributedTransport``: Uses collective communication (NCCL, Gloo, MPI) + +Initialization Phase +^^^^^^^^^^^^^^^^^^^^ + +When you create a collector with weight sync schemes, the following initialization occurs: + +.. aafig:: + :aspect: 60 + :scale: 130 + :proportional: + + INITIALIZATION PHASE + ==================== + + WeightSyncScheme + +------------------+ + | | + | Configuration: | + | - strategy | + | - transport_type | + | | + +--------+---------+ + | + +------------+-------------+ + | | + creates creates + | | + v v + Main Process Worker Process + +--------------+ +---------------+ + | WeightSender | | WeightReceiver| + | | | | + | - strategy | | - strategy | + | - transports | | - transport | + | - model_ref | | - model_ref | + | | | | + | Registers: | | Registers: | + | - model | | - model | + | - workers | | - transport | + +--------------+ +---------------+ + | | + | Transport Layer | + | +----------------+ | + +-->+ MPTransport |<------+ + | | (pipes) | | + | +----------------+ | + | +----------------+ | + +-->+ SharedMemTrans |<------+ + | | (shared mem) | | + | +----------------+ | + | +----------------+ | + +-->+ RayTransport |<------+ + | (Ray store) | + +----------------+ + +The scheme creates a sender in the main process and a receiver in each worker, then establishes +transport connections between them. + +Synchronization Phase +^^^^^^^^^^^^^^^^^^^^^ + +When you call ``collector.update_policy_weights_()``, the weight synchronization proceeds as follows: + +.. aafig:: + :aspect: 60 + :scale: 130 + :proportional: + + SYNCHRONIZATION PHASE + ===================== + + Main Process Worker Process + + +-------------------+ +-------------------+ + | WeightSender | | WeightReceiver | + | | | | + | 1. Extract | | 4. Poll transport | + | weights from | | for weights | + | model using | | | + | strategy | | | + | | 2. 
Send via | | + | +-------------+ | Transport | +--------------+ | + | | Strategy | | +------------+ | | Strategy | | + | | extract() | | | | | | apply() | | + | +-------------+ +----+ Transport +-------->+ +--------------+ | + | | | | | | | | + | v | +------------+ | v | + | +-------------+ | | +--------------+ | + | | Model | | | | Model | | + | | (source) | | 3. Ack (optional) | | (dest) | | + | +-------------+ | <-----------------------+ | +--------------+ | + | | | | + +-------------------+ | 5. Apply weights | + | to model using | + | strategy | + +-------------------+ + +1. **Extract**: Sender extracts weights from the source model (state_dict or TensorDict) +2. **Send**: Sender broadcasts weights through all registered transports +3. **Acknowledge** (optional): Some transports send acknowledgment back +4. **Poll**: Receiver checks its transport for new weights +5. **Apply**: Receiver applies weights to the destination model + +Multi-Model Synchronization +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +One of the key features is support for synchronizing multiple models independently: + +.. aafig:: + :aspect: 60 + :scale: 130 + :proportional: -- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; -- A `Receiver` class that casts the weights to the destination module (policy or other utility module); -- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). -- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. + Main Process Worker Process 1 Worker Process 2 + + +-----------------+ +---------------+ +---------------+ + | Collector | | Collector | | Collector | + | | | | | | + | Models: | | Models: | | Models: | + | +----------+ | | +--------+ | | +--------+ | + | | Policy A | | | |Policy A| | | |Policy A| | + | +----------+ | | +--------+ | | +--------+ | + | +----------+ | | +--------+ | | +--------+ | + | | Model B | | | |Model B| | | |Model B| | + | +----------+ | | +--------+ | | +--------+ | + | | | | | | + | Weight Senders: | | Weight | | Weight | + | +----------+ | | Receivers: | | Receivers: | + | | Sender A +---+------------+->Receiver A | | Receiver A | + | +----------+ | | | | | + | +----------+ | | +--------+ | | +--------+ | + | | Sender B +---+------------+->Receiver B | | Receiver B | + | +----------+ | Pipes | | Pipes | | + +-----------------+ +-------+-------+ +-------+-------+ -Each of these classes is detailed below. +Each model gets its own sender/receiver pair, allowing independent synchronization frequencies, +different transport mechanisms per model, and model-specific strategies. Usage Examples ~~~~~~~~~~~~~~ @@ -301,8 +513,17 @@ across multiple inference workers: dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured models and supports advanced features like lazy initialization. -Weight Senders -~~~~~~~~~~~~~~ +API Reference +~~~~~~~~~~~~~ + +The weight synchronization system provides both user-facing configuration classes and internal +implementation classes that are automatically managed by the collectors. + +Schemes (User-Facing Configuration) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +These are the main classes you'll use to configure weight synchronization. Pass them in the +``weight_sync_schemes`` dictionary when creating collectors. .. 
currentmodule:: torchrl.weight_update @@ -310,11 +531,20 @@ Weight Senders :toctree: generated/ :template: rl_template.rst - WeightSender - RayModuleTransformSender + WeightSyncScheme + MultiProcessWeightSyncScheme + SharedMemWeightSyncScheme + NoWeightSyncScheme + RayWeightSyncScheme + RayModuleTransformScheme + RPCWeightSyncScheme + DistributedWeightSyncScheme + +Senders and Receivers (Internal) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Weight Receivers -~~~~~~~~~~~~~~~~ +These classes are automatically created and managed by the schemes. You typically don't need +to interact with them directly. .. currentmodule:: torchrl.weight_update @@ -322,11 +552,16 @@ Weight Receivers :toctree: generated/ :template: rl_template.rst + WeightSender WeightReceiver + RayModuleTransformSender RayModuleTransformReceiver -Transports -~~~~~~~~~~ +Transport Backends (Internal) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Transport classes handle the actual communication between processes. They are automatically +selected and configured by the schemes. .. currentmodule:: torchrl.weight_update @@ -342,24 +577,6 @@ Transports RPCTransport DistributedTransport -Schemes -~~~~~~~ - -.. currentmodule:: torchrl.weight_update - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightSyncScheme - MultiProcessWeightSyncScheme - SharedMemWeightSyncScheme - NoWeightSyncScheme - RayWeightSyncScheme - RayModuleTransformScheme - RPCWeightSyncScheme - DistributedWeightSyncScheme - Legacy: Weight Synchronization in Distributed Environments ---------------------------------------------------------- @@ -417,248 +634,6 @@ transformed, and applied, ensuring seamless integration with their existing infr RPCWeightUpdater DistributedWeightUpdater -Weight Synchronization API -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The weight synchronization API provides a simple, modular approach to updating model weights across -distributed collectors. This system is designed to handle the complexities of modern RL setups where multiple -models may need to be synchronized independently. - -Overview -^^^^^^^^ - -In reinforcement learning, particularly with multi-process data collection, it's essential to keep the inference -policies synchronized with the latest trained weights. The API addresses this challenge through a clean -separation of concerns, where four classes are involved: - -- **Configuration**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` objects define *what* to synchronize and *how*. For DataCollectors, this is - your main entrypoint to configure the weight synchronization. -- **Sending**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender` handles distributing weights from the main process to workers. -- **Receiving**: :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver` handles applying weights in worker processes. -- **Transport**: Backend-specific communication mechanisms (pipes, shared memory, Ray, RPC) - -The following diagram shows the different classes involved in the weight synchronization process: - -.. 
aafig:: - :aspect: 60 - :scale: 130 - :proportional: - - INITIALIZATION PHASE - ==================== - - WeightSyncScheme - +------------------+ - | | - | Configuration: | - | - strategy | - | - transport_type | - | | - +--------+---------+ - | - +------------+-------------+ - | | - creates creates - | | - v v - Main Process Worker Process - +--------------+ +---------------+ - | WeightSender | | WeightReceiver| - | | | | - | - strategy | | - strategy | - | - transports | | - transport | - | - model_ref | | - model_ref | - | | | | - | Registers: | | Registers: | - | - model | | - model | - | - workers | | - transport | - +--------------+ +---------------+ - | | - | Transport Layer | - | +----------------+ | - +-->+ MPTransport |<------+ - | | (pipes) | | - | +----------------+ | - | +----------------+ | - +-->+ SharedMemTrans |<------+ - | | (shared mem) | | - | +----------------+ | - | +----------------+ | - +-->+ RayTransport |<------+ - | (Ray store) | - +----------------+ - - - SYNCHRONIZATION PHASE - ===================== - - Main Process Worker Process - - +-------------------+ +-------------------+ - | WeightSender | | WeightReceiver | - | | | | - | 1. Extract | | 4. Poll transport | - | weights from | | for weights | - | model using | | | - | strategy | | | - | | 2. Send via | | - | +-------------+ | Transport | +--------------+ | - | | Strategy | | +------------+ | | Strategy | | - | | extract() | | | | | | apply() | | - | +-------------+ +----+ Transport +-------->+ +--------------+ | - | | | | | | | | - | v | +------------+ | v | - | +-------------+ | | +--------------+ | - | | Model | | | | Model | | - | | (source) | | 3. Ack (optional) | | (dest) | | - | +-------------+ | <-----------------------+ | +--------------+ | - | | | | - +-------------------+ | 5. Apply weights | - | to model using | - | strategy | - +-------------------+ - -Key Challenges Addressed -^^^^^^^^^^^^^^^^^^^^^^^^^ - -Modern RL training often involves multiple models that need independent synchronization: - -1. **Multiple Models Per Collector**: A collector might need to update: - - - The main policy network - - A value network in a Ray actor within the replay buffer - - Models embedded in the environment itself - - Separate world models or auxiliary networks - -2. **Different Update Strategies**: Each model may require different synchronization approaches: - - - Full state_dict transfer vs. TensorDict-based updates - - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.) - - Varied update frequencies - -3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to - specific worker indices, requiring a more flexible update mechanism. - -Architecture -^^^^^^^^^^^^ - -The API follows a scheme-based design where users specify synchronization requirements upfront, -and the collector handles the orchestration transparently: - -.. 
aafig:: - :aspect: 60 - :scale: 130 - :proportional: - - Main Process Worker Process 1 Worker Process 2 - - +-----------------+ +---------------+ +---------------+ - | Collector | | Collector | | Collector | - | | | | | | - | Models: | | Models: | | Models: | - | +----------+ | | +--------+ | | +--------+ | - | | Policy A | | | |Policy A| | | |Policy A| | - | +----------+ | | +--------+ | | +--------+ | - | +----------+ | | +--------+ | | +--------+ | - | | Model B | | | |Model B| | | |Model B| | - | +----------+ | | +--------+ | | +--------+ | - | | | | | | - | Weight Senders: | | Weight | | Weight | - | +----------+ | | Receivers: | | Receivers: | - | | Sender A +---+------------+->Receiver A | | Receiver A | - | +----------+ | | | | | - | +----------+ | | +--------+ | | +--------+ | - | | Sender B +---+------------+->Receiver B | | Receiver B | - | +----------+ | Pipes | | Pipes | | - +-----------------+ +-------+-------+ +-------+-------+ - ^ ^ ^ - | | | - | update_policy_weights_() | Apply weights | - | | | - +------+-------+ | | - | User Code | | | - | (Training) | | | - +--------------+ +------------------------+ - -The weight synchronization flow: - -1. **Initialization**: User creates ``weight_sync_schemes`` dict mapping model IDs to schemes -2. **Registration**: Collector creates ``WeightSender`` for each model in the main process -3. **Worker Setup**: Each worker creates corresponding ``WeightReceiver`` instances -4. **Synchronization**: Calling ``update_policy_weights_()`` triggers all senders to push weights -5. **Application**: Receivers automatically apply weights to their registered models - -Available Classes -^^^^^^^^^^^^^^^^^ - -**Synchronization Schemes** (User-Facing Configuration): - -- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme`: Base class for schemes -- :class:`~torchrl.weight_update.weight_sync_schemes.MultiProcessWeightSyncScheme`: For multiprocessing with pipes -- :class:`~torchrl.weight_update.weight_sync_schemes.SharedMemWeightSyncScheme`: For shared memory synchronization -- :class:`~torchrl.weight_update.weight_sync_schemes.RayWeightSyncScheme`: For Ray-based distribution -- :class:`~torchrl.weight_update.weight_sync_schemes.NoWeightSyncScheme`: Dummy scheme for no synchronization - -**Internal Classes** (Automatically Managed): - -- :class:`~torchrl.weight_update.weight_sync_schemes.WeightSender`: Sends weights to all workers for one model -- :class:`~torchrl.weight_update.weight_sync_schemes.WeightReceiver`: Receives and applies weights in worker -- :class:`~torchrl.weight_update.weight_sync_schemes.TransportBackend`: Communication layer abstraction - -Usage Example -^^^^^^^^^^^^^ - -.. 
code-block:: python - - from torchrl.collectors import MultiSyncDataCollector - from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme - - # Define synchronization for multiple models - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), - "value_net": MultiProcessWeightSyncScheme(strategy="state_dict"), - } - - collector = MultiSyncDataCollector( - create_env_fn=[make_env] * 4, - policy=policy, - frames_per_batch=1000, - weight_sync_schemes=weight_sync_schemes, # Pass schemes dict - ) - - # Single call updates all registered models across all workers - for i, batch in enumerate(collector): - # Training step - loss = train(batch) - - # Sync all models with one call - collector.update_policy_weights_(policy) - -The collector automatically: - -- Creates ``WeightSender`` instances in the main process for each model -- Creates ``WeightReceiver`` instances in each worker process -- Resolves models by ID (e.g., ``"policy"`` → ``collector.policy``) -- Handles transport setup and communication -- Applies weights using the appropriate strategy (state_dict vs tensordict) - -API Reference -^^^^^^^^^^^^^ - -.. currentmodule:: torchrl.weight_update.weight_sync_schemes - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightSyncScheme - MultiProcessWeightSyncScheme - SharedMemWeightSyncScheme - RayWeightSyncScheme - NoWeightSyncScheme - WeightSender - WeightReceiver - Collectors and replay buffers interoperability ---------------------------------------------- diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py index fbb1a8a1166..a3962966c8c 100644 --- a/examples/collectors/weight_sync_collectors.py +++ b/examples/collectors/weight_sync_collectors.py @@ -17,7 +17,7 @@ import torch.nn as nn from tensordict import TensorDict from tensordict.nn import TensorDictModule -from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector +from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.envs import GymEnv from torchrl.weight_update import ( MultiProcessWeightSyncScheme, @@ -27,25 +27,24 @@ def example_single_collector_multiprocess(): """Example 1: Single collector with multiprocess scheme.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 1: Single Collector with Multiprocess Scheme") - print("="*70) - + print("=" * 70) + # Create environment and policy env = GymEnv("CartPole-v1") policy = TensorDictModule( nn.Linear( - env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1] + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] ), in_keys=["observation"], out_keys=["action"], ) env.close() - + # Create weight sync scheme scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - + print("Creating collector with multiprocess weight sync...") collector = SyncDataCollector( create_env_fn=lambda: GymEnv("CartPole-v1"), @@ -54,46 +53,45 @@ def example_single_collector_multiprocess(): total_frames=200, weight_sync_schemes={"policy": scheme}, ) - + # Collect data and update weights periodically print("Collecting data...") for i, data in enumerate(collector): print(f"Iteration {i}: Collected {data.numel()} transitions") - + # Update policy weights every 2 iterations if i % 2 == 0: new_weights = policy.state_dict() collector.update_policy_weights_(new_weights) print(" → Updated policy weights") - + if i >= 2: # Just run a few iterations for demo break - + 
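+    # Shut down the collector so worker processes and pipes are released cleanly.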
collector.shutdown() print("✓ Single collector example completed!\n") def example_multi_collector_shared_memory(): """Example 2: Multiple collectors with shared memory.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 2: Multiple Collectors with Shared Memory") - print("="*70) - + print("=" * 70) + # Create environment and policy env = GymEnv("CartPole-v1") policy = TensorDictModule( nn.Linear( - env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1] + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] ), in_keys=["observation"], out_keys=["action"], ) env.close() - + # Shared memory is more efficient for frequent updates scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - + print("Creating multi-collector with shared memory...") collector = MultiSyncDataCollector( create_env_fn=[ @@ -106,49 +104,51 @@ def example_multi_collector_shared_memory(): total_frames=400, weight_sync_schemes={"policy": scheme}, ) - + # Workers automatically see weight updates via shared memory print("Collecting data...") for i, data in enumerate(collector): print(f"Iteration {i}: Collected {data.numel()} transitions") - + # Update weights frequently (shared memory makes this very fast) collector.update_policy_weights_(TensorDict.from_module(policy)) print(" → Updated policy weights via shared memory") - + if i >= 1: # Just run a couple iterations for demo break - + collector.shutdown() print("✓ Multi-collector with shared memory example completed!\n") def main(): """Run all examples.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Weight Synchronization Schemes - Collector Integration Examples") - print("="*70) - + print("=" * 70) + # Set multiprocessing start method import torch.multiprocessing as mp + try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except RuntimeError: pass # Already set - + # Run examples example_single_collector_multiprocess() example_multi_collector_shared_memory() - - print("\n" + "="*70) + + print("\n" + "=" * 70) print("All examples completed successfully!") - print("="*70) + print("=" * 70) print("\nKey takeaways:") print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios") - print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers") - print("="*70 + "\n") + print( + " • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers" + ) + print("=" * 70 + "\n") if __name__ == "__main__": main() - diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py index 83492256412..69d9947bdc7 100644 --- a/examples/collectors/weight_sync_standalone.py +++ b/examples/collectors/weight_sync_standalone.py @@ -16,8 +16,8 @@ import torch import torch.nn as nn -from torch import multiprocessing as mp from tensordict import TensorDict +from torch import multiprocessing as mp from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, @@ -27,21 +27,21 @@ def worker_process_mp(child_pipe, model_state): """Worker process that receives weights via multiprocessing pipe.""" print("Worker: Starting...") - + # Create a policy on the worker side policy = nn.Linear(4, 2) with torch.no_grad(): policy.weight.fill_(0.0) policy.bias.fill_(0.0) - + # Create receiver and register the policy scheme = MultiProcessWeightSyncScheme(strategy="state_dict") receiver = scheme.create_receiver() receiver.register_model(policy) receiver.register_worker_transport(child_pipe) 
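+    # The pipe is now registered as this worker's transport; weights sent
+    # by the main process are polled from it below.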
- + print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}") - + # Receive and apply weights result = receiver._transport.receive_weights(timeout=5.0) if result is not None: @@ -50,19 +50,19 @@ def worker_process_mp(child_pipe, model_state): print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}") else: print("Worker: No weights received") - + # Store final state for verification - model_state['weight_sum'] = policy.weight.sum().item() - model_state['bias_sum'] = policy.bias.sum().item() + model_state["weight_sum"] = policy.weight.sum().item() + model_state["bias_sum"] = policy.bias.sum().item() def worker_process_shared_mem(child_pipe, model_state): """Worker process that receives shared memory buffer reference.""" print("SharedMem Worker: Starting...") - + # Create a policy on the worker side policy = nn.Linear(4, 2) - + # Wait for shared memory buffer registration if child_pipe.poll(timeout=10.0): data, msg = child_pipe.recv() @@ -73,129 +73,135 @@ def worker_process_shared_mem(child_pipe, model_state): shared_weights.to_module(policy) # Send acknowledgment child_pipe.send((None, "registered")) - + # Small delay to ensure main process updates shared memory import time + time.sleep(0.5) - + print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}") - + # Store final state for verification - model_state['weight_sum'] = policy.weight.sum().item() - model_state['bias_sum'] = policy.bias.sum().item() + model_state["weight_sum"] = policy.weight.sum().item() + model_state["bias_sum"] = policy.bias.sum().item() def example_multiprocess_sync(): """Example 1: Multiprocess weight synchronization with state_dict.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 1: Multiprocess Weight Synchronization") - print("="*70) - + print("=" * 70) + # Create a simple policy on main process policy = nn.Linear(4, 2) with torch.no_grad(): policy.weight.fill_(1.0) policy.bias.fill_(0.5) - + print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}") - + # Create scheme and sender scheme = MultiProcessWeightSyncScheme(strategy="state_dict") sender = scheme.create_sender() - + # Create pipe for communication parent_pipe, child_pipe = mp.Pipe() sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - + # Start worker process manager = mp.Manager() model_state = manager.dict() process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state)) process.start() - + # Send weights to worker weights = policy.state_dict() print("Main: Sending weights to worker...") sender.update_weights(weights) - + # Wait for worker to complete process.join(timeout=10.0) - + if process.is_alive(): print("Warning: Worker process did not terminate in time") process.terminate() else: - print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}") - print(f"✓ Weight synchronization successful!") + print( + f"Main: Worker completed. 
Worker's weight sum: {model_state['weight_sum']:.4f}" + ) + print("✓ Weight synchronization successful!") def example_shared_memory_sync(): """Example 2: Shared memory weight synchronization.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 2: Shared Memory Weight Synchronization") - print("="*70) - + print("=" * 70) + # Create a simple policy policy = nn.Linear(4, 2) - + # Create shared memory scheme with auto-registration scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) sender = scheme.create_sender() - + # Create pipe for lazy registration parent_pipe, child_pipe = mp.Pipe() sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - + # Start worker process manager = mp.Manager() model_state = manager.dict() - process = mp.Process(target=worker_process_shared_mem, args=(child_pipe, model_state)) + process = mp.Process( + target=worker_process_shared_mem, args=(child_pipe, model_state) + ) process.start() - + # Send weights (automatically creates shared buffer on first send) weights_td = TensorDict.from_module(policy) with torch.no_grad(): weights_td["weight"].fill_(2.0) weights_td["bias"].fill_(1.0) - - print(f"Main: Sending weights via shared memory...") + + print("Main: Sending weights via shared memory...") sender.update_weights(weights_td) - + # Workers automatically see updates via shared memory! print("Main: Weights are now in shared memory, workers can access them") - + # Wait for worker to complete process.join(timeout=10.0) - + if process.is_alive(): print("Warning: Worker process did not terminate in time") process.terminate() else: - print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}") - print(f"✓ Shared memory synchronization successful!") + print( + f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}" + ) + print("✓ Shared memory synchronization successful!") def main(): """Run all examples.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Weight Synchronization Schemes - Standalone Usage Examples") - print("="*70) - + print("=" * 70) + # Set multiprocessing start method try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except RuntimeError: pass # Already set - + # Run examples example_multiprocess_sync() example_shared_memory_sync() - - print("\n" + "="*70) + + print("\n" + "=" * 70) print("All examples completed successfully!") - print("="*70 + "\n") + print("=" * 70 + "\n") if __name__ == "__main__": main() - diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 763753896b2..244b7c204f4 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -1313,7 +1313,7 @@ def _resolve_model(context: Any, model_id: str) -> Any: obj = getattr(obj, key) except AttributeError: raise AttributeError( - f"Attribute {key} from {parts[:i+1]} not found in {'.'.join(parts[:i])}={obj}" + f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" ) for index in indices: obj = obj[index] @@ -1322,6 +1322,6 @@ def _resolve_model(context: Any, model_id: str) -> Any: obj = getattr(obj, part) except AttributeError: raise AttributeError( - f"Attribute {part} from {parts[:i+1]} not found in {'.'.join(parts[:i])}={obj}" + f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" ) return obj