Merged
52 changes: 29 additions & 23 deletions docs/source/reference/collectors.rst
@@ -162,7 +162,15 @@ Usage Examples
Using Weight Update Schemes Independently
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example:
Weight update schemes can be used outside of collectors for custom synchronization scenarios.
The new simplified API provides four core methods for weight synchronization:

- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side
- ``init_on_worker(model_id, **kwargs)`` - Initialize on the worker process side
- ``get_sender()`` - Get the configured sender instance
- ``get_receiver()`` - Get the configured receiver instance

Here's a basic example:

.. code-block:: python

@@ -182,39 +190,37 @@ Weight update schemes can be used outside of collectors for custom synchronizati
# --------------------------------------------------------------
# On the main process side (trainer):
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
sender = scheme.create_sender()

# Register worker pipes

# Initialize scheme with pipes
parent_pipe, child_pipe = mp.Pipe()
sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)

# Send weights to workers
scheme.init_on_sender(model_id="policy", pipes=[parent_pipe])

# Get the sender and send weights
sender = scheme.get_sender()
weights = policy.state_dict()
sender.update_weights(weights)
sender.send(weights) # Synchronous send
# or sender.send_async(weights); sender.wait_async() # Asynchronous send

# On the worker process side:
# receiver = scheme.create_receiver()
# receiver.register_model(policy)
# receiver.register_worker_transport(child_pipe)
# # Receive and apply weights
# result = receiver._transport.receive_weights(timeout=5.0)
# if result is not None:
# model_id, weights = result
# receiver.apply_weights(weights)
# scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
# receiver = scheme.get_receiver()
# # Non-blocking check for new weights
# if receiver.receive(timeout=0.001):
#     # Weights were received and applied
#     pass

# Example 2: Shared memory weight synchronization
# ------------------------------------------------
# Create shared memory scheme with auto-registration
shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
shared_sender = shared_scheme.create_sender()

# Register worker pipes for lazy registration

# Initialize with pipes for lazy registration
parent_pipe2, child_pipe2 = mp.Pipe()
shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2)

# Send weights (automatically creates shared buffer on first send)
shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2])

# Get sender and send weights (automatically creates shared buffer on first send)
shared_sender = shared_scheme.get_sender()
weights_td = TensorDict.from_module(policy)
shared_sender.update_weights(weights_td)
shared_sender.send(weights_td)

# Workers automatically see updates via shared memory!
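The pipe-based flow in Example 1 can be illustrated without TorchRL at all. The following is a minimal sketch using only the standard library; ``ToySender``, ``ToyReceiver``, and ``demo`` are hypothetical stand-ins invented for illustration and are not the classes returned by ``get_sender()`` or ``get_receiver()``. It assumes the model is a plain dict of parameters and only mirrors the ``send`` / ``receive(timeout=...)`` call pattern shown in the docs.

```python
import multiprocessing as mp


class ToySender:
    """Hypothetical stand-in for a weight sender; not TorchRL's class."""

    def __init__(self, pipes):
        self._pipes = list(pipes)

    def send(self, weights):
        # Push the same weights to every registered worker pipe.
        for pipe in self._pipes:
            pipe.send(weights)


class ToyReceiver:
    """Hypothetical stand-in for a weight receiver; not TorchRL's class."""

    def __init__(self, pipe, model):
        self._pipe = pipe
        self._model = model  # a plain dict: parameter name -> value

    def receive(self, timeout=0.001):
        # Wait at most `timeout` seconds; return False if nothing arrived.
        if not self._pipe.poll(timeout):
            return False
        # Apply the received weights to the local model copy.
        self._model.update(self._pipe.recv())
        return True


def demo():
    # Both pipe ends live in one process here to keep the sketch short;
    # in practice the child end would be handed to a worker process.
    parent_pipe, child_pipe = mp.Pipe()
    worker_model = {"weight": 0.0, "bias": 0.0}

    sender = ToySender([parent_pipe])
    receiver = ToyReceiver(child_pipe, worker_model)

    first = receiver.receive()  # nothing has been sent yet
    sender.send({"weight": 1.5, "bias": -0.5})
    second = receiver.receive(timeout=1.0)
    return first, second, worker_model
```

The receiver returns a boolean rather than the weights themselves, which matches the ``if receiver.receive(timeout=0.001):`` usage above: the caller only needs to know whether an update was applied.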
