From 217df91c4f7f3de194c120fc532a5163b35f2c63 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 14 Oct 2025 09:40:44 +0100 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- pyproject.toml | 4 +- test/test_weightsync.py | 574 ++++++++ torchrl/collectors/weight_update.py | 16 + torchrl/weight_update/__init__.py | 50 + torchrl/weight_update/weight_sync_schemes.py | 1327 ++++++++++++++++++ 5 files changed, 1970 insertions(+), 1 deletion(-) create mode 100644 test/test_weightsync.py create mode 100644 torchrl/weight_update/__init__.py create mode 100644 torchrl/weight_update/weight_sync_schemes.py diff --git a/pyproject.toml b/pyproject.toml index 9159f05f9a0..f156381d153 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ maintainers = [ ] keywords = ["reinforcement-learning", "pytorch", "rl", "machine-learning"] classifiers = [ + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -53,6 +54,7 @@ tests = [ "pytest-benchmark", "pytest-rerunfailures", "pytest-error-for-skips", + "pytest-timeout", ] utils = [ "tensorboard", @@ -75,7 +77,7 @@ offline-data = [ ] marl = ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot"] open_spiel = ["open_spiel>=1.5"] -brax = ["jax[cuda12]>=0.7.0", "brax"] +brax = ["jax>=0.7.0", "brax"] llm = [ "transformers", "vllm", diff --git a/test/test_weightsync.py b/test/test_weightsync.py new file mode 100644 index 00000000000..91203c209f0 --- /dev/null +++ b/test/test_weightsync.py @@ -0,0 +1,574 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import argparse + +import pytest +import torch +import torch.nn as nn +from tensordict import TensorDict +from tensordict.nn import TensorDictModule +from torch import multiprocessing as mp +from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector +from torchrl.envs import GymEnv + +from torchrl.weight_update.weight_sync_schemes import ( + _resolve_model, + MPTransport, + MultiProcessWeightSyncScheme, + NoWeightSyncScheme, + SharedMemTransport, + SharedMemWeightSyncScheme, + WeightStrategy, +) + + +def worker_update_policy(pipe, timeout=5.0): + policy = nn.Linear(4, 2) + with torch.no_grad(): + policy.weight.fill_(0.0) + policy.bias.fill_(0.0) + + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + receiver = scheme.create_receiver() + receiver.register_model(policy) + receiver.register_worker_transport(pipe) + + if receiver._transport.pipe.poll(timeout): + data, msg = receiver._transport.pipe.recv() + if msg == "update_weights": + model_id, weights = data + receiver.apply_weights(weights) + + return policy.weight.sum().item(), policy.bias.sum().item() + + +def worker_update_policy_tensordict(pipe, timeout=5.0): + policy = nn.Linear(4, 2) + with torch.no_grad(): + policy.weight.fill_(0.0) + policy.bias.fill_(0.0) + + scheme = MultiProcessWeightSyncScheme(strategy="tensordict") + receiver = scheme.create_receiver() + receiver.register_model(policy) + receiver.register_worker_transport(pipe) + + if receiver._transport.pipe.poll(timeout): + data, msg = receiver._transport.pipe.recv() + if msg == "update_weights": + model_id, weights = data + receiver.apply_weights(weights) + + return policy.weight.sum().item(), policy.bias.sum().item() + + +def worker_shared_mem(pipe, timeout=10.0): + policy = 
nn.Linear(4, 2) + + if pipe.poll(timeout): + data, msg = pipe.recv() + if msg == "register_shared_weights": + model_id, shared_weights = data + shared_weights.to_module(policy) + pipe.send((None, "registered")) + + import time + + time.sleep(0.5) + + return policy.weight.sum().item(), policy.bias.sum().item() + + +class TestTransportBackends: + def test_mp_transport_basic(self): + parent_pipe, child_pipe = mp.Pipe() + transport = MPTransport(parent_pipe) + + assert transport.check_connection() + + proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) + proc.start() + + test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} + transport.send_weights("policy", test_weights) + + proc.join(timeout=10.0) + assert not proc.is_alive() + + def test_mp_transport_async(self): + parent_pipe, child_pipe = mp.Pipe() + transport = MPTransport(parent_pipe) + + proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) + proc.start() + + test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} + transport.send_weights_async("policy", test_weights) + transport.wait_ack() + + proc.join(timeout=10.0) + assert not proc.is_alive() + + def test_shared_mem_transport(self): + shared_buffer = TensorDict( + {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] + ).share_memory_() + + transport = SharedMemTransport({"policy": shared_buffer}) + + new_weights = TensorDict( + {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] + ) + + transport.send_weights("policy", new_weights) + + assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) + assert torch.allclose(shared_buffer["bias"], torch.ones(2)) + + +class TestWeightStrategies: + def test_state_dict_strategy(self): + strategy = WeightStrategy(extract_as="state_dict") + + policy = nn.Linear(3, 4) + weights = strategy.extract_weights(policy) + assert isinstance(weights, dict) + assert "weight" in weights + assert "bias" in weights + + target_policy = nn.Linear(3, 4) + with torch.no_grad(): + target_policy.weight.fill_(0.0) + target_policy.bias.fill_(0.0) + + strategy.apply_weights(target_policy, weights) + + assert torch.allclose(policy.weight, target_policy.weight) + assert torch.allclose(policy.bias, target_policy.bias) + + def test_tensordict_strategy(self): + strategy = WeightStrategy(extract_as="tensordict") + + policy = nn.Linear(3, 4) + weights = strategy.extract_weights(policy) + assert isinstance(weights, TensorDict) + + target_policy = nn.Linear(3, 4) + with torch.no_grad(): + target_policy.weight.fill_(0.0) + target_policy.bias.fill_(0.0) + + strategy.apply_weights(target_policy, weights) + + assert torch.allclose(policy.weight, target_policy.weight) + assert torch.allclose(policy.bias, target_policy.bias) + + def test_cross_format_conversion(self): + policy = nn.Linear(3, 4) + + state_dict_strategy = WeightStrategy(extract_as="state_dict") + tensordict_strategy = WeightStrategy(extract_as="tensordict") + + state_dict_weights = state_dict_strategy.extract_weights(policy) + tensordict_weights = tensordict_strategy.extract_weights(policy) + + target_policy_1 = nn.Linear(3, 4) + target_policy_2 = nn.Linear(3, 4) + + with torch.no_grad(): + target_policy_1.weight.fill_(0.0) + target_policy_1.bias.fill_(0.0) + target_policy_2.weight.fill_(0.0) + target_policy_2.bias.fill_(0.0) + + state_dict_strategy.apply_weights(target_policy_1, tensordict_weights) + tensordict_strategy.apply_weights(target_policy_2, state_dict_weights) + + assert torch.allclose(policy.weight, 
target_policy_1.weight) + assert torch.allclose(policy.weight, target_policy_2.weight) + + +class TestWeightSyncSchemes: + def test_multiprocess_scheme_state_dict(self): + parent_pipe, child_pipe = mp.Pipe() + + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + sender = scheme.create_sender() + sender.register_worker(0, parent_pipe) + + proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) + proc.start() + + weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} + sender.update_weights(weights) + + proc.join(timeout=10.0) + assert not proc.is_alive() + + def test_multiprocess_scheme_tensordict(self): + parent_pipe, child_pipe = mp.Pipe() + + scheme = MultiProcessWeightSyncScheme(strategy="tensordict") + sender = scheme.create_sender() + sender.register_worker(0, parent_pipe) + + proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,)) + proc.start() + + weights = TensorDict( + {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] + ) + sender.update_weights(weights) + + proc.join(timeout=10.0) + assert not proc.is_alive() + + def test_shared_mem_scheme(self): + shared_buffer = TensorDict( + {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] + ).share_memory_() + + scheme = SharedMemWeightSyncScheme( + policy_weights={"policy": shared_buffer}, + strategy="tensordict", + auto_register=False, + ) + + transport = scheme.create_transport(None) + + new_weights = TensorDict( + {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] + ) + + transport.send_weights("policy", new_weights) + + assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) + assert torch.allclose(shared_buffer["bias"], torch.ones(2)) + + def test_shared_mem_scheme_auto_register(self): + scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + transport = scheme.create_transport(None) + + weights = TensorDict( + {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] + ) + + transport.send_weights("policy", weights) + + assert "policy" in scheme.policy_weights + assert torch.allclose( + scheme.policy_weights["policy"]["weight"], torch.ones(2, 4) + ) + + def test_no_weight_sync_scheme(self): + scheme = NoWeightSyncScheme() + transport = scheme.create_transport(None) + + weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} + transport.send_weights("policy", weights) + + +class TestCollectorIntegration: + @pytest.fixture + def simple_env(self): + return GymEnv("CartPole-v1") + + @pytest.fixture + def simple_policy(self, simple_env): + return TensorDictModule( + nn.Linear( + simple_env.observation_spec["observation"].shape[-1], + simple_env.action_spec.shape[-1], + ), + in_keys=["observation"], + out_keys=["action"], + ) + + def test_syncdatacollector_multiprocess_scheme(self, simple_policy): + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": scheme}, + ) + + new_weights = simple_policy.state_dict() + with torch.no_grad(): + for key in new_weights: + new_weights[key].fill_(1.0) + + collector.update_policy_weights_(new_weights) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + collector = MultiSyncDataCollector( + 
create_env_fn=[ + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + ], + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": scheme}, + ) + + new_weights = simple_policy.state_dict() + with torch.no_grad(): + for key in new_weights: + new_weights[key].fill_(1.0) + + collector.update_policy_weights_(new_weights) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): + scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + collector = MultiSyncDataCollector( + create_env_fn=[ + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + ], + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": scheme}, + ) + + new_weights = TensorDict.from_module(simple_policy) + with torch.no_grad(): + new_weights["module"]["weight"].fill_(1.0) + new_weights["module"]["bias"].fill_(1.0) + + collector.update_policy_weights_(new_weights) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + def test_collector_no_weight_sync(self, simple_policy): + scheme = NoWeightSyncScheme() + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=simple_policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes={"policy": scheme}, + ) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + + +class TestMultiModelUpdates: + def test_multi_model_state_dict_updates(self): + env = GymEnv("CartPole-v1") + + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + + value = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], 1), + in_keys=["observation"], + out_keys=["value"], + ) + + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), + "value": MultiProcessWeightSyncScheme(strategy="state_dict"), + } + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes=weight_sync_schemes, + ) + + policy_weights = policy.state_dict() + value_weights = value.state_dict() + + with torch.no_grad(): + for key in policy_weights: + policy_weights[key].fill_(1.0) + for key in value_weights: + value_weights[key].fill_(2.0) + + collector.update_policy_weights_( + weights_dict={ + "policy": policy_weights, + "value": value_weights, + } + ) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + env.close() + + def test_multi_model_tensordict_updates(self): + env = GymEnv("CartPole-v1") + + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + + value = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], 1), + in_keys=["observation"], + out_keys=["value"], + ) + + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), + "value": MultiProcessWeightSyncScheme(strategy="tensordict"), + } + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=128, + weight_sync_schemes=weight_sync_schemes, + ) + + policy_weights = 
TensorDict.from_module(policy) + value_weights = TensorDict.from_module(value) + + with torch.no_grad(): + policy_weights["module"]["weight"].fill_(1.0) + policy_weights["module"]["bias"].fill_(1.0) + value_weights["module"]["weight"].fill_(2.0) + value_weights["module"]["bias"].fill_(2.0) + + collector.update_policy_weights_( + weights_dict={ + "policy": policy_weights, + "value": value_weights, + } + ) + + for data in collector: + assert data.numel() > 0 + break + + collector.shutdown() + env.close() + + +class TestHelpers: + def test_resolve_model_simple(self): + class Context: + def __init__(self): + self.policy = nn.Linear(2, 3) + + context = Context() + resolved = _resolve_model(context, "policy") + assert resolved is context.policy + + def test_resolve_model_nested(self): + class Inner: + def __init__(self): + self.value_net = nn.Linear(2, 3) + + class Context: + def __init__(self): + self.env = Inner() + + context = Context() + resolved = _resolve_model(context, "env.value_net") + assert resolved is context.env.value_net + + def test_resolve_model_with_index(self): + class Context: + def __init__(self): + self.transform = [nn.Linear(2, 3), nn.Linear(3, 4)] + + context = Context() + resolved = _resolve_model(context, "transform[0]") + assert resolved is context.transform[0] + + resolved = _resolve_model(context, "transform[1]") + assert resolved is context.transform[1] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestDeviceHandling: + def test_weight_update_cpu_to_cpu(self): + policy = nn.Linear(3, 4) + strategy = WeightStrategy(extract_as="state_dict") + + weights = strategy.extract_weights(policy) + target = nn.Linear(3, 4) + strategy.apply_weights(target, weights) + + assert torch.allclose(policy.weight, target.weight) + + def test_weight_update_cuda_to_cuda(self): + policy = nn.Linear(3, 4).cuda() + strategy = WeightStrategy(extract_as="tensordict") + + weights = strategy.extract_weights(policy) + target = nn.Linear(3, 4).cuda() + strategy.apply_weights(target, weights) + + assert torch.allclose(policy.weight, target.weight) + + +@pytest.mark.parametrize("strategy", ["state_dict", "tensordict"]) +def test_weight_strategy_parametrized(strategy): + weight_strategy = WeightStrategy(extract_as=strategy) + + policy = nn.Linear(3, 4) + weights = weight_strategy.extract_weights(policy) + + target = nn.Linear(3, 4) + with torch.no_grad(): + target.weight.fill_(0.0) + target.bias.fill_(0.0) + + weight_strategy.apply_weights(target, weights) + + assert torch.allclose(policy.weight, target.weight) + assert torch.allclose(policy.bias, target.bias) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst", "-v"] + unknown) diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index 933c5096c64..97fa62d6a2b 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -20,6 +20,10 @@ class WeightUpdaterBase(metaclass=abc.ABCMeta): """A base class for updating remote policy weights on inference workers. + .. deprecated:: + WeightUpdaterBase is deprecated and will be removed in a future version. + Please use WeightSyncScheme from torchrl.weight_update.weight_sync_schemes instead. 
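+
+    A minimal migration sketch (mirroring ``test/test_weightsync.py`` in this patch;
+    ``make_env`` and ``policy`` are assumed to be defined by the user)::
+
+        >>> from torchrl.weight_update import MultiProcessWeightSyncScheme
+        >>> scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+        >>> collector = MultiSyncDataCollector(
+        ...     create_env_fn=[make_env, make_env],
+        ...     policy=policy,
+        ...     frames_per_batch=64,
+        ...     weight_sync_schemes={"policy": scheme},
+        ... )
+        >>> collector.update_policy_weights_(policy.state_dict())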
+ The weight updater is the central piece of the weight update scheme: - In leaf collector nodes, it is responsible for sending the weights to the policy, which can be as simple as @@ -71,6 +75,18 @@ class WeightUpdaterBase(metaclass=abc.ABCMeta): _collector_wrs: list[Any] = None _post_hooks: list[Callable[[], Any]] | None = None + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + import warnings + + warnings.warn( + f"Creating {cls.__name__} which inherits from WeightUpdaterBase is deprecated. " + "Please use WeightSyncScheme from torchrl.weight_update.weight_sync_schemes instead. " + "This will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + @property def post_hooks(self) -> list[Callable[[], None]]: """The list of post-hooks registered to the weight updater.""" diff --git a/torchrl/weight_update/__init__.py b/torchrl/weight_update/__init__.py new file mode 100644 index 00000000000..556064a6113 --- /dev/null +++ b/torchrl/weight_update/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .weight_sync_schemes import ( + DistributedTransport, + DistributedWeightSyncScheme, + MPTransport, + MultiProcessWeightSyncScheme, + NoWeightSyncScheme, + RayActorTransport, + RayModuleTransformReceiver, + RayModuleTransformScheme, + RayModuleTransformSender, + RayTransport, + RayWeightSyncScheme, + RPCTransport, + RPCWeightSyncScheme, + SharedMemTransport, + SharedMemWeightSyncScheme, + TransportBackend, + WeightReceiver, + WeightSender, + WeightStrategy, + WeightSyncScheme, +) + +__all__ = [ + "TransportBackend", + "MPTransport", + "SharedMemTransport", + "RayTransport", + "RayActorTransport", + "RPCTransport", + "DistributedTransport", + "WeightStrategy", + "WeightSender", + "WeightReceiver", + "RayModuleTransformSender", + "RayModuleTransformReceiver", + "WeightSyncScheme", + "MultiProcessWeightSyncScheme", + "SharedMemWeightSyncScheme", + "NoWeightSyncScheme", + "RayWeightSyncScheme", + "RayModuleTransformScheme", + "RPCWeightSyncScheme", + "DistributedWeightSyncScheme", +] diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py new file mode 100644 index 00000000000..763753896b2 --- /dev/null +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -0,0 +1,1327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
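+"""Weight-synchronization schemes: transports, senders and receivers.
+
+A ``WeightSyncScheme`` is a small configuration object that builds the transport
+(pipes, shared memory, Ray, RPC or torch.distributed) together with a
+``WeightSender`` for the trainer process and a ``WeightReceiver`` for each worker.
+
+Sketch of the multiprocessing case (mirroring ``test/test_weightsync.py``;
+``parent_pipe``/``child_pipe`` come from ``mp.Pipe()`` and ``policy`` lives in the
+worker process)::
+
+    # trainer side
+    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+    sender = scheme.create_sender()
+    sender.register_worker(0, parent_pipe)
+    sender.update_weights(policy.state_dict())
+
+    # worker side
+    receiver = scheme.create_receiver()
+    receiver.register_model(policy)
+    receiver.register_worker_transport(child_pipe)
+"""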
+from __future__ import annotations + +import abc +import weakref +from collections.abc import Iterator +from typing import Any, Literal, Protocol + +from tensordict import TensorDict, TensorDictBase + +from torch import nn + +__all__ = [ + "TransportBackend", + "MPTransport", + "SharedMemTransport", + "RayTransport", + "RayActorTransport", + "RPCTransport", + "DistributedTransport", + "WeightStrategy", + "WeightSender", + "WeightReceiver", + "RayModuleTransformSender", + "RayModuleTransformReceiver", + "WeightSyncScheme", + "MultiProcessWeightSyncScheme", + "SharedMemWeightSyncScheme", + "NoWeightSyncScheme", + "RayWeightSyncScheme", + "RayModuleTransformScheme", + "RPCWeightSyncScheme", + "DistributedWeightSyncScheme", +] + +# ============================================================================ +# Transport Layer Abstraction +# ============================================================================ + + +class TransportBackend(Protocol): + """Abstract interface for different communication mechanisms.""" + + def send_weights(self, model_id: str, weights: Any) -> None: + """Send weights to the receiver.""" + ... + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from the sender. Returns (model_id, weights) or None if timeout.""" + ... + + def check_connection(self) -> bool: + """Check if the connection is still alive.""" + ... + + +class MPTransport: + """Multiprocessing transport using pipes. + + Args: + pipe_connection (mp.Pipe): The pipe connection to use for communication. + timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. + """ + + def __init__(self, pipe_connection, timeout: float = 10.0): + self.timeout = timeout + self.pipe = pipe_connection + + def send_weights(self, model_id: str, weights: Any) -> None: + """Send weights through the pipe. + + Sends weights and waits for acknowledgment to ensure delivery. + """ + self.pipe.send(((model_id, weights), "update_weights")) + # Wait for acknowledgment + self.check_ack("updated") + + def send_weights_async(self, model_id: str, weights: Any) -> None: + """Send weights through the pipe without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. + """ + self.pipe.send(((model_id, weights), "update_weights")) + + def wait_ack(self) -> None: + """Wait for acknowledgment from worker.""" + self.check_ack("updated") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from the pipe (used in worker process).""" + if self.pipe.poll(timeout): + data_in, msg = self.pipe.recv() + if msg == "update_weights": + model_id, weights = data_in + return model_id, weights + return None + + def send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender.""" + self.pipe.send((None, message)) + + def check_ack(self, message: str = "updated") -> None: + """Check for acknowledgment.""" + _, msg = self.pipe.recv() + if msg != message: + raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") + + def check_connection(self) -> bool: + return not self.pipe.closed + + +class SharedMemTransport: + """Shared memory transport for in-place weight updates. + + This transport updates shared memory tensors directly without message passing. + Workers automatically see weight updates without explicit communication. 
+ + The transport supports lazy registration with pipe-based buffer distribution: + - On first weight send for a model, creates shared memory and sends buffer via pipes + - Workers receive the buffer reference and update their local references + - Subsequent updates are pure in-place shared memory (zero-copy) + + This hybrid approach solves the chicken-and-egg problem: workers can start before + weights are available, and they'll receive the shared buffer references when ready. + + Args: + policy_weights: Dictionary mapping model_id to shared TensorDict weights. + Can be empty if using lazy registration. + auto_register: Whether to automatically register models on first weight send. + Default is True. Set to False to require explicit registration via + register_weights(). + """ + + def __init__( + self, + policy_weights: dict[str, TensorDictBase] | None = None, + auto_register: bool = True, + ): + self._policy_weights = policy_weights if policy_weights is not None else {} + self._auto_register = auto_register + self._pipes = [] # List of pipes to send initial buffer references + self._registered_with_workers = ( + set() + ) # Track which model_ids have been sent to workers + + def register_pipe(self, pipe: Any) -> None: + """Register a pipe for sending buffer references on first weight send. + + Args: + pipe: Pipe connection to a worker process. + """ + if pipe not in self._pipes: + self._pipes.append(pipe) + + def register_weights(self, model_id: str, weights: TensorDictBase) -> None: + """Register a shared memory weights TensorDict for a model. + + This method allows explicit registration of shared weights. It's optional + when auto_register=True (the default), but required when auto_register=False. + + If pipes are registered and this model hasn't been sent to workers yet, + this will trigger sending the buffer reference to all workers. + """ + if not isinstance(weights, TensorDictBase): + raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}") + + is_new_registration = model_id not in self._policy_weights + self._policy_weights[model_id] = weights + + # If this is a new registration and we have pipes, send buffer to workers + if ( + is_new_registration + and self._pipes + and model_id not in self._registered_with_workers + ): + self._send_buffer_to_workers(model_id, weights) + + def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None: + """Send shared memory buffer reference to all workers via pipes. + + This is called once per model_id when lazy registration occurs. + Workers receive the buffer and update their local references. + + Note: We send buffer.data to avoid gradient tracking issues when crossing + process boundaries. The .data attribute gives us the underlying tensors + without autograd metadata. + """ + for pipe in self._pipes: + # Send special registration message with the shared buffer + # Use .data to strip gradient information (can't serialize non-leaf tensors with requires_grad) + pipe.send(((model_id, buffer.data), "register_shared_weights")) + + # Wait for acknowledgments from all workers + for pipe in self._pipes: + _, msg = pipe.recv() + if msg != "registered": + raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'") + + self._registered_with_workers.add(model_id) + + def send_weights(self, model_id: str, weights: Any) -> None: + """Update weights in-place in shared memory. 
+ + If the model is not registered and auto_register=True, it will be automatically + registered by creating a shared memory copy of the provided weights. The shared + buffer reference is sent to all workers via pipes on first registration, then + subsequent updates are pure in-place shared memory. + + Args: + model_id: Identifier for the model whose weights to update. + weights: New weights to send. Can be a TensorDictBase or dict. + + Raises: + KeyError: If model is not registered and auto_register=False. + ValueError: If weights type is unsupported for auto-registration. + """ + if model_id not in self._policy_weights: + if not self._auto_register: + raise KeyError( + f"Model '{model_id}' not registered in SharedMemTransport. " + f"Available models: {list(self._policy_weights.keys())}. " + f"Either register the model using register_weights() or enable auto_register." + ) + + # Auto-register on first send + if isinstance(weights, TensorDictBase): + # Create shared memory copy + # Unflatten keys if they're flat (e.g., 'module.0.weight' -> nested structure) + # This is necessary for to_module() to work properly + weights_to_share = weights.clone() + # Check if keys are flattened by looking for dots in key names + if any("." in key for key in weights_to_share.keys()): + weights_to_share = weights_to_share.unflatten_keys(".") + shared_buffer = weights_to_share.share_memory_() + elif isinstance(weights, dict): + # Convert dict to TensorDict and share + # Unflatten if keys are flat + weights_td = TensorDict(weights, batch_size=[]) + if any("." in key for key in weights_td.keys()): + weights_td = weights_td.unflatten_keys(".") + shared_buffer = weights_td.share_memory_() + else: + raise ValueError( + f"Cannot auto-register model '{model_id}' with weights type: {type(weights)}. " + f"Supported types for auto-registration: TensorDictBase, dict. " + f"Please manually register shared weights using register_weights()." + ) + + self._policy_weights[model_id] = shared_buffer + + # Send buffer reference to all workers if we have pipes + if self._pipes and model_id not in self._registered_with_workers: + self._send_buffer_to_workers(model_id, shared_buffer) + + shared_weights = self._policy_weights[model_id] + + # Update shared memory in-place (workers see this automatically) + if isinstance(weights, TensorDictBase): + # Unflatten if needed to match shared buffer structure + weights_to_update = weights + if any("." in key for key in weights.keys()): + weights_to_update = weights.unflatten_keys(".") + shared_weights.data.update_( + weights_to_update.data + if hasattr(weights_to_update, "data") + else weights_to_update + ) + elif isinstance(weights, dict): + # For dict updates, check if we need to unflatten keys + if any("." 
in key for key in weights.keys()): + # Convert to TensorDict, unflatten, then update + weights_td = TensorDict(weights, batch_size=[]) + weights_td = weights_td.unflatten_keys(".") + shared_weights.data.update_(weights_td.data) + else: + # Direct key-by-key update for non-flattened dict + for key, value in weights.items(): + if key in shared_weights.keys(True, True): + shared_weights.set(key, value) + else: + raise ValueError(f"Unsupported weights type: {type(weights)}") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """No-op for shared memory - weights are already visible.""" + return None + + def send_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_connection(self) -> bool: + """Shared memory is always 'connected'.""" + return True + + +class RayTransport: + """Ray transport for communicating with a single Ray collector actor. + + This transport handles weight updates for ONE specific remote collector. + Multiple transports are created for multiple collectors, following the + same pattern as multiprocess collectors. + """ + + def __init__(self, remote_collector=None): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayTransport") + self._remote_collector = remote_collector + + def send_weights(self, model_id: str, weights: Any) -> None: + """Send weights to the remote collector via Ray. + + Note: We don't pass model_id to the remote collector because remote + collectors don't have weight senders - they apply weights directly to + their local policy. + """ + if self._remote_collector is None: + return + + # Put weights in Ray's object store for efficient distribution + # Ray will automatically deduplicate if the same weights are sent to multiple actors + weights_ref = self.ray.put(weights) + + # Send to the remote collector and wait for completion + # This ensures weights are applied before we continue + future = self._remote_collector.update_policy_weights_.remote( + policy_or_weights=weights_ref + ) + self.ray.wait([future], num_returns=1) + + def send_weights_async(self, model_id: str, weights: Any) -> None: + """Send weights to remote collector without waiting for completion. + + Use wait_ack() to wait for completion after sending to all workers. + """ + if self._remote_collector is None: + return + + weights_ref = self.ray.put(weights) + self._pending_future = self._remote_collector.update_policy_weights_.remote( + policy_or_weights=weights_ref + ) + + def wait_ack(self) -> None: + """Wait for the remote collector to finish applying weights.""" + if hasattr(self, "_pending_future"): + self.ray.wait([self._pending_future], num_returns=1) + del self._pending_future + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Ray workers typically don't receive weights through this transport.""" + return None + + def check_connection(self) -> bool: + """Check if Ray is initialized.""" + return self.ray.is_initialized() + + +class RayActorTransport: + """Ray transport for communicating with Ray actors (not collectors). + + This transport is designed for updating models hosted within Ray actors, + such as RayModuleTransform instances. It directly calls the actor's + update_weights method rather than going through collector update methods. 
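+
+    A hedged sketch (``transform`` is assumed to be a ``RayModuleTransform`` that
+    exposes its Ray actor as ``transform._actor`` and ``module`` the local copy of
+    the hosted model)::
+
+        >>> transport = RayActorTransport(transform._actor, update_method="tensordict")
+        >>> transport.send_weights("module", TensorDict.from_module(module))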
+ """ + + def __init__(self, actor_ref=None, update_method: str = "tensordict"): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayActorTransport") + + self._actor_ref = actor_ref + self._update_method = update_method + + def set_actor(self, actor_ref): + """Set the Ray actor reference to communicate with.""" + self._actor_ref = actor_ref + + def send_weights(self, model_id: str, weights: Any) -> None: + """Send weights to the Ray actor.""" + if self._actor_ref is None: + return + + weights_ref = self.ray.put(weights) + + if self._update_method == "tensordict": + self.ray.get( + self._actor_ref._update_weights_tensordict.remote(params=weights_ref) + ) + elif self._update_method == "state_dict": + self.ray.get( + self._actor_ref._update_weights_state_dict.remote( + state_dict=weights_ref + ) + ) + else: + raise ValueError(f"Unknown update method: {self._update_method}") + + def send_weights_async(self, model_id: str, weights: Any) -> None: + """Send weights to Ray actor without waiting for completion. + + Use wait_ack() to wait for completion after sending to all actors. + """ + if self._actor_ref is None: + return + + weights_ref = self.ray.put(weights) + + if self._update_method == "tensordict": + self._pending_future = self._actor_ref._update_weights_tensordict.remote( + params=weights_ref + ) + elif self._update_method == "state_dict": + self._pending_future = self._actor_ref._update_weights_state_dict.remote( + state_dict=weights_ref + ) + else: + raise ValueError(f"Unknown update method: {self._update_method}") + + def wait_ack(self) -> None: + """Wait for Ray actor to finish applying weights.""" + if hasattr(self, "_pending_future"): + self.ray.get(self._pending_future) + del self._pending_future + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Ray actor workers receive weights through direct method calls.""" + return None + + def send_ack(self, message: str = "updated") -> None: + """No acknowledgment needed for Ray actors.""" + + def check_ack(self, message: str = "updated") -> None: + """No acknowledgment needed for Ray actors.""" + + def check_connection(self) -> bool: + """Check if Ray is initialized and actor exists.""" + if not self.ray.is_initialized(): + return False + if self._actor_ref is None: + return False + return True + + +class RPCTransport: + """RPC transport for communicating with a single RPC remote collector. + + This transport handles weight updates for ONE specific remote collector via + torch.distributed.rpc. Multiple transports are created for multiple collectors, + following the same pattern as multiprocess collectors. + """ + + def __init__(self, collector_info=None, collector_rref=None, collector_class=None): + self._collector_info = collector_info + self._collector_rref = collector_rref + self._collector_class = collector_class + + def send_weights(self, model_id: str, weights: Any) -> None: + """Send weights to the remote collector via RPC. + + Note: We don't pass model_id to the remote collector because remote + collectors don't have weight senders - they apply weights directly to + their local policy. 
+ """ + if self._collector_info is None or self._collector_rref is None: + return + + from torch.distributed import rpc + + # Send weights to the remote collector and wait for completion + rpc.rpc_sync( + self._collector_info, + self._collector_class.update_policy_weights_, + args=(self._collector_rref, weights), + ) + + def send_weights_async(self, model_id: str, weights: Any) -> None: + """Send weights to remote collector without waiting for completion. + + Use wait_ack() to wait for completion after sending to all workers. + """ + if self._collector_info is None or self._collector_rref is None: + return + + from torch.distributed import rpc + + # Send weights asynchronously + self._pending_future = rpc.rpc_async( + self._collector_info, + self._collector_class.update_policy_weights_, + args=(self._collector_rref, weights), + ) + + def wait_ack(self) -> None: + """Wait for the RPC call to complete.""" + if hasattr(self, "_pending_future"): + self._pending_future.wait() + del self._pending_future + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """RPC workers typically don't receive weights through this transport.""" + return None + + def check_connection(self) -> bool: + """Check if RPC is initialized.""" + from torch.distributed import rpc + + return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + + +class DistributedTransport: + """torch.distributed transport for communicating with a single distributed worker. + + This transport handles weight updates for ONE specific distributed worker via + torch.distributed send/recv. Multiple transports are created for multiple workers, + following the same pattern as multiprocess collectors. + """ + + def __init__(self, store=None, rank=None, sync=True): + """Initialize the DistributedTransport. + + Args: + store: TCPStore for communication. + rank: Worker rank (1-indexed). + sync: Whether to use synchronous weight updates. + """ + self._store = store + self._rank = rank + self._sync = sync + + def send_weights(self, model_id: str, weights: Any) -> None: + """Send weights to the distributed worker. + + Note: We don't pass model_id to the remote collector because remote + collectors don't have weight senders - they apply weights directly to + their local policy. + """ + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + # Wait for acknowledgment + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def send_weights_async(self, model_id: str, weights: Any) -> None: + """Send weights to distributed worker without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. 
+ """ + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + def wait_ack(self) -> None: + """Wait for acknowledgment from distributed worker.""" + if self._store is None or self._rank is None: + return + + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Distributed workers receive weights through torch.distributed primitives.""" + return None + + def check_connection(self) -> bool: + """Check if torch.distributed is initialized.""" + import torch + + return torch.distributed.is_initialized() + + +# ============================================================================ +# Weight Strategies +# ============================================================================ + + +class WeightStrategy: + """Unified strategy for weight transmission. + + This strategy handles both extraction and application of weights, supporting + both TensorDict and state_dict formats. + + Args: + extract_as (str): Format for extracting weights. Can be: + - "tensordict" (default): Extract weights as TensorDict + - "state_dict": Extract weights as PyTorch state_dict + + The application format is automatically detected based on the type of weights + received (dict -> state_dict, TensorDict -> tensordict). + """ + + def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict"): + if extract_as not in ("tensordict", "state_dict"): + raise ValueError( + f"extract_as must be 'tensordict' or 'state_dict', got {extract_as}" + ) + self.extract_as = extract_as + + def extract_weights(self, source: Any) -> Any: + """Extract weights from source model in the specified format. + + Args: + source: The model to extract weights from. Can be: + - nn.Module: PyTorch module + - TensorDictBase: TensorDict + - dict: State dictionary + + Returns: + Weights in the format specified by `extract_as` constructor argument. + """ + if self.extract_as == "tensordict": + # Extract as TensorDict + if isinstance(source, nn.Module): + return TensorDict.from_module(source) + elif isinstance(source, TensorDictBase): + return source + elif isinstance(source, dict): + # Convert state_dict to TensorDict + return TensorDict(source, batch_size=[]) + else: + raise ValueError( + f"Unsupported source type for TensorDict extraction: {type(source)}" + ) + else: # state_dict + # Extract as state_dict + if isinstance(source, nn.Module): + return source.state_dict() + elif isinstance(source, dict): + return source + elif isinstance(source, TensorDictBase): + # Convert TensorDict to state_dict + return source.to_dict() + else: + raise ValueError( + f"Unsupported source type for state_dict extraction: {type(source)}" + ) + + def apply_weights(self, destination: Any, weights: Any) -> None: + """Apply weights to destination model. + + The format is automatically detected from the weights type: + - dict -> state_dict format + - TensorDictBase -> tensordict format + + Args: + destination: The model to apply weights to. Can be: + - nn.Module: PyTorch module + - TensorDictBase: TensorDict + - dict: State dictionary + weights: The weights to apply (dict or TensorDictBase). 
+ """ + if weights is None: + return + + # Auto-detect format from weights type + if isinstance(weights, dict): + # Apply state_dict format + if isinstance(destination, nn.Module): + destination.load_state_dict(weights) + elif isinstance(destination, dict): + destination = TensorDict(destination) + weights = TensorDict(weights) + destination.data.update_(weights.data) + elif isinstance(destination, TensorDictBase): + weights_td = TensorDict(weights) + if (dest_keys := sorted(destination.keys(True, True))) != sorted( + weights.keys(True, True) + ): + weights_td = weights_td.unflatten_keys(".") + weights_keys = sorted(weights_td.keys(True, True)) + if dest_keys != weights_keys: + raise ValueError( + f"The keys of the weights and destination do not match: {dest_keys} != {weights_keys}" + ) + destination.data.update_(weights_td.data) + else: + raise ValueError( + f"Unsupported destination type for state_dict: {type(destination)}" + ) + elif isinstance(weights, TensorDictBase): + # Apply TensorDict format + if isinstance(destination, nn.Module): + weights.to_module(destination) + elif isinstance(destination, TensorDictBase): + destination.data.update_(weights.data) + elif isinstance(destination, dict): + destination_td = TensorDict(destination) + if (dest_keys := sorted(destination_td.keys(True, True))) != sorted( + weights.keys(True, True) + ): + weights = weights.unflatten_keys(".") + weights_keys = sorted(weights.keys(True, True)) + if dest_keys != weights_keys: + raise ValueError( + f"The keys of the weights and destination do not match: {dest_keys} != {weights_keys}" + ) + destination_td.data.update_(weights.data) + else: + raise ValueError( + f"Unsupported destination type for TensorDict: {type(destination)}" + ) + else: + raise ValueError( + f"Unsupported weights type: {type(weights)}. Expected dict or TensorDictBase." + ) + + +def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy: + """Get strategy object from string name. + + Args: + strategy: Either "tensordict" or "state_dict". + + Returns: + WeightStrategy: Strategy configured with the specified extraction format. + """ + if strategy not in ("tensordict", "state_dict"): + raise ValueError( + f"Unknown strategy: {strategy}. Must be 'tensordict' or 'state_dict'." + ) + return WeightStrategy(extract_as=strategy) + + +# ============================================================================ +# Sender (Trainer/Main Process Side) +# ============================================================================ + + +class WeightSender: + """Sends weights for ONE model to ALL workers. + + This class handles sending weights to all workers via their transports. + Weight extraction is the responsibility of the caller. + """ + + _transport: TransportBackend | None + _transports: dict[int, TransportBackend] + + def __init__(self, scheme: WeightSyncScheme): + self._scheme = scheme + self._transports: dict[int, TransportBackend] = {} # worker_idx -> transport + self._transport: TransportBackend = None + self._model_id = "policy" # Default model ID + self._strategy = _get_strategy(scheme.strategy) + + def register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """Register a worker's communication pipe. + + Args: + worker_idx: The worker index. + pipe_or_context: The pipe connection for this worker. 
+ """ + if worker_idx not in self._transports: + self._transports[worker_idx] = self._scheme.create_transport( + pipe_or_context + ) + + def _iterate_transports(self) -> Iterator[TransportBackend]: + if not self._transports: + yield self._transport + else: + yield from self._transports.values() + + def update_weights(self, weights: Any) -> None: + """Send weights to ALL workers for this model. + + Args: + weights: Weights to send. + + Note: + This method sends weights to all workers in parallel (non-blocking), + then waits for all acknowledgments. This is much faster than sending + sequentially when there are many workers. + """ + model_id = getattr(self, "_model_id", "policy") + transports = list(self._iterate_transports()) + + # Send to all workers first (non-blocking if transport supports it) + for transport in transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(model_id, weights) + else: + # Fallback for transports that don't support async send + transport.send_weights(model_id, weights) + + # Wait for all acknowledgments + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + +# ============================================================================ +# Receiver (Worker Process Side) +# ============================================================================ + + +class WeightReceiver: + """Receives weights for ONE model in ONE worker. + + This class handles receiving weights via transport and applying them to + a registered model in the worker process. + """ + + def __init__(self, scheme: WeightSyncScheme): + self._scheme = scheme + self._context_ref = None # weakref to inner_collector + self._transport = None # lazy + self._model_ref = None + self._strategy = _get_strategy(scheme.strategy) + + def set_context(self, context: Any) -> None: + """Set the context object (inner_collector) for resolving references. + + Args: + context: The inner collector instance in the worker process. + """ + self._context_ref = weakref.ref(context) + + def register_model(self, model_ref: Any) -> None: + """Register the model to apply weights to. + + Args: + model_ref: Either a direct object reference or a string path like 'policy' or 'env.value_net'. + """ + self._model_ref = model_ref + + def register_worker_transport(self, pipe: Any) -> None: + """Register this worker's communication pipe. + + Args: + pipe: The pipe connection for this worker. + """ + self._transport = self._scheme.create_transport(pipe) + + def apply_weights(self, weights: Any) -> None: + """Apply received weights to registered model. + + Args: + weights: The weights to apply. + """ + if self._model_ref is None: + raise ValueError("No model registered") + + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights) + + # Send acknowledgment if transport supports it + if hasattr(self._transport, "send_ack"): + self._transport.send_ack("updated") + + def _resolve_model_ref(self) -> Any: + """Resolve model reference to actual object.""" + if isinstance(self._model_ref, str): + if self._context_ref is None: + raise ValueError("Context is required to resolve string references") + context = self._context_ref() + if context is None: + raise ValueError("Context has been garbage collected") + return _resolve_model(context, self._model_ref) + return self._model_ref + + +class RayModuleTransformSender(WeightSender): + """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. 
+ + This sender handles weight updates for models hosted within Ray actors. + Unlike the base WeightSender which uses pipes for multiprocessing, + this sender directly communicates with Ray actors via their remote methods. + + For Ray actors, there is typically only one shared actor instance, so we + store a single transport rather than per-worker transports. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + self._actor_ref = None + self._single_transport = None + self._context_ref = None + self._model_id_str = None + + def set_context(self, context: Any, model_id: str) -> None: + """Set context for lazy actor resolution. + + Args: + context: The collector instance. + model_id: String path to the Ray actor (e.g., "env.transform[0]"). + """ + self._context_ref = weakref.ref(context) + self._model_id_str = model_id + + def register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """For Ray actors, worker registration is a no-op. + + Ray actors are shared across all workers, so we don't need per-worker + transports. The actor reference is resolved lazily on first use. + """ + + def update_weights(self, weights: Any) -> None: + """Send weights to the Ray actor. + + Args: + weights: Weights to send. + """ + if self._single_transport is None: + self._initialize_transport() + + if self._single_transport is not None: + model_id = getattr(self, "_model_id", "policy") + self._single_transport.send_weights(model_id, weights) + + def _initialize_transport(self) -> None: + """Lazily initialize the transport by resolving the actor reference.""" + if self._context_ref is None or self._model_id_str is None: + return + + context = self._context_ref() + if context is None: + return + + model = _resolve_model(context, self._model_id_str) + if hasattr(model, "_actor"): + self._actor_ref = model._actor + self._single_transport = self._scheme.create_transport(model) + elif type(model).__name__ == "ActorHandle": + self._actor_ref = model + self._single_transport = self._scheme.create_transport(model) + + +class RayModuleTransformReceiver(WeightReceiver): + """Specialized receiver for RayModuleTransform actors. + + This receiver handles weight updates within Ray actors. + Since Ray actors receive weights through direct method calls, + this receiver primarily validates and applies weights locally. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + + def register_worker_transport(self, actor_or_context: Any) -> None: + """Register the Ray actor's transport. + + Args: + actor_or_context: Either a Ray actor reference or a context object. + """ + self._transport = self._scheme.create_transport(actor_or_context) + + def apply_weights(self, weights: Any) -> None: + """Apply received weights to registered model. + + For Ray actors, weights are applied directly to the module + within the actor's process space. + + Args: + weights: The weights to apply. + """ + if self._model_ref is None: + raise ValueError("No model registered") + + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights) + + +# ============================================================================ +# Weight Synchronization Schemes +# ============================================================================ + + +class WeightSyncScheme(metaclass=abc.ABCMeta): + """Configuration for how to synchronize ONE model across workers. 
+ + A scheme is a pure configuration object that specifies: + - The transmission strategy (state_dict vs tensordict) + - How to create the transport for communication + + Each scheme is responsible for one model type but is shared across all workers. + """ + + def __init__(self, strategy: Literal["state_dict", "tensordict"] = "state_dict"): + self.strategy = strategy + + @abc.abstractmethod + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create transport for communication. + + Args: + pipe_or_context: Either a pipe connection or context object to extract pipe from. + + Returns: + A transport backend instance. + """ + ... + + def create_sender(self) -> WeightSender: + """Create a sender for this scheme. + + Returns: + WeightSender instance configured for this scheme. + """ + return WeightSender(self) + + def create_receiver(self) -> WeightReceiver: + """Create a receiver for this scheme. + + Returns: + WeightReceiver instance configured for this scheme. + """ + return WeightReceiver(self) + + +class MultiProcessWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for multiprocess operations using pipes. + + This scheme creates transports that communicate via multiprocessing pipes. + """ + + def create_transport(self, pipe: Any) -> TransportBackend: + """Create an MPTransport using the provided pipe.""" + return MPTransport(pipe) + + +class SharedMemWeightSyncScheme(WeightSyncScheme): + """Weight synchronization using shared memory. + + This scheme mimics the old WeightUpdater behavior by using shared memory + for in-place weight updates. Workers automatically see weight updates + without explicit message passing. + + By default, this scheme uses lazy registration: models are automatically + registered on the first weight send. This makes it seamless to use with + configuration systems like Hydra where schemes are created before models + are available. + + Args: + policy_weights: Dictionary mapping model_id to shared TensorDict weights. + Can be empty if using lazy registration (auto_register=True). + strategy: The weight transmission strategy (default: "tensordict"). + auto_register: Whether to automatically register models on first weight send. + Default is True. Set to False to require explicit registration via + register_shared_weights(). + + Example: + >>> # With auto-registration (default) - works with Hydra configs + >>> scheme = SharedMemWeightSyncScheme() + >>> # Models are auto-registered on first weight send + + >>> # With explicit registration + >>> scheme = SharedMemWeightSyncScheme(auto_register=False) + >>> shared_weights = TensorDict.from_module(model).share_memory_() + >>> scheme.register_shared_weights("policy", shared_weights) + """ + + def __init__( + self, + policy_weights: dict[str, TensorDictBase] | None = None, + strategy: str = "tensordict", + auto_register: bool = True, + ): + super().__init__(strategy) + self.policy_weights = policy_weights if policy_weights is not None else {} + self.auto_register = auto_register + # Create a single shared transport for all workers + self._shared_transport = SharedMemTransport( + self.policy_weights, auto_register=auto_register + ) + + def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None: + """Register shared memory weights for a model. + + This method allows explicit registration of shared weights. It's optional + when auto_register=True (the default), but required when auto_register=False. + + Args: + model_id: Identifier for the model. 
+ weights: Shared memory TensorDict containing the model's weights. + """ + self.policy_weights[model_id] = weights + self._shared_transport.register_weights(model_id, weights) + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create shared memory transport and register pipe for lazy buffer distribution. + + For lazy registration to work, we register each worker's pipe with the transport. + On first weight send, the transport will send buffer references via these pipes. + + Returns the shared transport instance that all workers will use. + Since this is shared memory, there's only one transport shared by all workers. + """ + # Register the pipe for lazy buffer distribution + if pipe_or_context is not None: + self._shared_transport.register_pipe(pipe_or_context) + return self._shared_transport + + +class NoWeightSyncScheme(WeightSyncScheme): + """No-op weight synchronization scheme. + + This scheme disables weight synchronization entirely. + """ + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Returns None as no transport is needed.""" + # Return a dummy transport that does nothing + class NoOpTransport: + def send_weights(self, model_id: str, weights: Any) -> None: + pass + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + return None + + def check_connection(self) -> bool: + return True + + return NoOpTransport() + + +class RayWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for Ray distributed computing. + + This scheme uses Ray's object store and remote calls to synchronize weights + across distributed workers (Ray actors). + + Each remote collector gets its own transport, following the same pattern + as multiprocess collectors. + """ + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create Ray-based transport for a specific remote collector. + + Args: + pipe_or_context: The Ray actor handle for the remote collector. + + Returns: + RayTransport configured for this specific remote collector. + """ + return RayTransport(remote_collector=pipe_or_context) + + +class RayModuleTransformScheme(WeightSyncScheme): + """Weight synchronization for RayModuleTransform actors. + + This scheme is designed specifically for updating models hosted within + Ray actors, such as RayModuleTransform instances. It creates a transport + that directly calls the actor's weight update methods. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Default is "tensordict". + """ + + def __init__(self, strategy: str = "tensordict"): + super().__init__(strategy) + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create RayActorTransport for the given actor. + + Args: + pipe_or_context: Either a Ray actor reference or a context object + from which to extract the actor reference. + + Returns: + RayActorTransport configured with the actor reference. + """ + actor_ref = self._extract_actor_ref(pipe_or_context) + return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) + + def _extract_actor_ref(self, pipe_or_context: Any) -> Any: + """Extract the Ray actor reference from the context. + + Args: + pipe_or_context: Either a direct actor reference or an object + with an `_actor` attribute. + + Returns: + The Ray actor reference. 
+ """ + if hasattr(pipe_or_context, "_actor"): + return pipe_or_context._actor + return pipe_or_context + + def create_sender(self) -> RayModuleTransformSender: + """Create a specialized sender for Ray actor communication.""" + return RayModuleTransformSender(self) + + def create_receiver(self) -> RayModuleTransformReceiver: + """Create a specialized receiver for Ray actor communication.""" + return RayModuleTransformReceiver(self) + + +class RPCWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed.rpc. + + This scheme uses RPC calls to synchronize weights across distributed + workers. Each remote collector gets its own transport, following the + same pattern as multiprocess collectors. + """ + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create RPC-based transport for a specific remote collector. + + Args: + pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) + for the remote collector. + + Returns: + RPCTransport configured for this specific remote collector. + """ + if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: + collector_info, collector_rref, collector_class = pipe_or_context + return RPCTransport( + collector_info=collector_info, + collector_rref=collector_rref, + collector_class=collector_class, + ) + # If just passed the info directly + return RPCTransport(collector_info=pipe_or_context) + + +class DistributedWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed. + + This scheme uses torch.distributed primitives (send/recv) to synchronize + weights across distributed workers. Each worker gets its own transport, + following the same pattern as multiprocess collectors. + + Args: + backend (str): The distributed backend ("gloo", "nccl", etc.) + sync (bool): Whether to use synchronous weight updates + """ + + def __init__(self, backend: str = "gloo", sync: bool = True): + super().__init__() + self.backend = backend + self.sync = sync + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create distributed transport for a specific worker. + + Args: + pipe_or_context: A tuple of (store, rank) for the worker. + + Returns: + DistributedTransport configured for this specific worker. + """ + if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: + store, rank = pipe_or_context + return DistributedTransport(store=store, rank=rank, sync=self.sync) + # Fallback - shouldn't normally happen + return DistributedTransport() + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def _resolve_model(context: Any, model_id: str) -> Any: + """Resolve model_id like 'policy' or 'env.value_net' to actual object. + + Also processes getitem notation like 'env.transform[0]' to actual object. + + Args: + context: The context object (collector or inner_collector). + model_id: A string address like "policy" or "env.value_net". + + Returns: + The object at the specified address. 
+
+
    Examples:
        _resolve_model(collector, "policy")  # -> collector.policy
        _resolve_model(collector, "env.value_net")  # -> collector.env.value_net
    """
    parts = model_id.split(".")
    obj = context
    for i, part in enumerate(parts):
        if "[" in part:
            key, *indices = part.split("[")
            indices = [int(index[:-1]) for index in indices]
            try:
                obj = getattr(obj, key)
            except AttributeError:
                raise AttributeError(
                    f"Attribute {key} from {parts[:i+1]} not found in {'.'.join(parts[:i])}={obj}"
                )
            for index in indices:
                obj = obj[index]
        else:
            try:
                obj = getattr(obj, part)
            except AttributeError:
                raise AttributeError(
                    f"Attribute {part} from {parts[:i+1]} not found in {'.'.join(parts[:i])}={obj}"
                )
    return obj

From a37298fd23ba28ae58c3f60195e6122882c76c7e Mon Sep 17 00:00:00 2001
From: vmoens
Date: Sat, 18 Oct 2025 10:53:02 -0700
Subject: [PATCH 2/6] Update

[ghstack-poisoned]
---
 docs/source/reference/collectors.rst          | 292 +++++++++++++++++-
 examples/collectors/weight_sync_collectors.py | 269 ++++++++++++++++
 examples/collectors/weight_sync_standalone.py | 201 ++++++++++++
 3 files changed, 761 insertions(+), 1 deletion(-)
 create mode 100644 examples/collectors/weight_sync_collectors.py
 create mode 100644 examples/collectors/weight_sync_standalone.py

diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst
index ca8dca38e4e..aadeb6644d5 100644
--- a/docs/source/reference/collectors.rst
+++ b/docs/source/reference/collectors.rst
@@ -117,9 +117,299 @@ try to limit the cases where a deepcopy will be executed. The following chart sh

   Policy copy decision tree in Collectors.

-Weight Synchronization in Distributed Environments
+Weight Synchronization using Weight Update Schemes
 --------------------------------------------------

+RL pipelines are typically split into two big computational buckets: training and inference.
+While the inference pipeline sends data to the training one, the training pipeline needs to occasionally
+synchronize its weights with the inference one.
+In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
+used in both instances. From there, anything can happen:
+
+- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
+  `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
+  for its instance of the policy.
+- In some cases, the environment or the postprocessing hooks can rely on a model which itself needs
+  synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond
+  policy-to-policy weight synchronization strategies (see the sketch after this list).
+- In the LLM world, the inference engine and the training one are very different: they will use different libraries,
+  kernels and calling APIs (e.g., `generate` vs. `forward`). The weight format can also be drastically different (quantized
+  vs non-quantized).
+  This makes the weight synchronization much more complex, as one cannot simply dump and load a state dict on both ends.
+- One typically also has to choose who initiates a transfer: should this come from the inference engine, which actively
+  asks for new weights, or must it be the trainer that pushes its weights to the workers? An intermediate approach
+  is to store the weights on some intermediary server and let the workers fetch them when necessary.
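
The snippet below illustrates the multi-model case mentioned in the list above. It is a minimal sketch that mirrors
the multi-model collector example used elsewhere in this patch series; the ``"policy"`` and ``"value"`` model ids,
the network sizes, and the environment are illustrative placeholders rather than a prescribed setup.

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import MultiProcessWeightSyncScheme

    # Two distinct "ends" that may need fresh weights: the policy and a value head.
    policy = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])
    value = TensorDictModule(nn.Linear(4, 1), in_keys=["observation"], out_keys=["value"])

    # One scheme per model id: each model gets its own sender/receiver pair.
    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes={
            "policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
            "value": MultiProcessWeightSyncScheme(strategy="state_dict"),
        },
    )

    # A single call can then push fresh weights to every registered model.
    collector.update_policy_weights_(
        {"policy": policy.state_dict(), "value": value.state_dict()}
    )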
+ +TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight +transfer: + +- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; +- A `Receiver` class that casts the weights to the destination module (policy or other utility module); +- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). +- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. + +Each of these classes is detailed below. + +Usage Examples +~~~~~~~~~~~~~~ + +.. note:: + **Runnable versions** of these examples are available in the repository: + + - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization + - `examples/collectors/weight_sync_collectors.py `_: Collector integration + +Using Weight Update Schemes Independently +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example: + +.. code-block:: python + + import torch + import torch.nn as nn + from torch import multiprocessing as mp + from tensordict import TensorDict + from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + ) + + # Create a simple policy + policy = nn.Linear(4, 2) + + # Example 1: Multiprocess weight synchronization with state_dict + # -------------------------------------------------------------- + # On the main process side (trainer): + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + sender = scheme.create_sender() + + # Register worker pipes + parent_pipe, child_pipe = mp.Pipe() + sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) + + # Send weights to workers + weights = policy.state_dict() + sender.update_weights(weights) + + # On the worker process side: + # receiver = scheme.create_receiver() + # receiver.register_model(policy) + # receiver.register_worker_transport(child_pipe) + # # Receive and apply weights + # result = receiver._transport.receive_weights(timeout=5.0) + # if result is not None: + # model_id, weights = result + # receiver.apply_weights(weights) + + # Example 2: Shared memory weight synchronization + # ------------------------------------------------ + # Create shared memory scheme with auto-registration + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + shared_sender = shared_scheme.create_sender() + + # Register worker pipes for lazy registration + parent_pipe2, child_pipe2 = mp.Pipe() + shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2) + + # Send weights (automatically creates shared buffer on first send) + weights_td = TensorDict.from_module(policy) + shared_sender.update_weights(weights_td) + + # Workers automatically see updates via shared memory! + +Using Weight Update Schemes with Collectors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization +across multiple inference workers: + +.. 
code-block:: python + + import torch.nn as nn + from tensordict.nn import TensorDictModule + from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector + from torchrl.envs import GymEnv + from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + NoWeightSyncScheme, + ) + + # Create environment and policy + env = GymEnv("CartPole-v1") + policy = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], + env.action_spec.shape[-1]), + in_keys=["observation"], + out_keys=["action"], + ) + + # Example 1: Single collector with multiprocess scheme + # ----------------------------------------------------- + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=1000, + weight_sync_schemes={"policy": scheme}, + ) + + # Collect data and update weights periodically + for i, data in enumerate(collector): + # ... training step with data ... + + # Update policy weights every N iterations + if i % 10 == 0: + new_weights = policy.state_dict() + collector.update_policy_weights_(new_weights) + + collector.shutdown() + + # Example 2: Multiple collectors with shared memory + # -------------------------------------------------- + # Shared memory is more efficient for frequent updates + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + collector = MultiSyncDataCollector( + create_env_fn=[ + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + ], + policy=policy, + frames_per_batch=192, + total_frames=10000, + weight_sync_schemes={"policy": shared_scheme}, + ) + + # Workers automatically see weight updates via shared memory + for data in collector: + # ... training ... + collector.update_policy_weights_(TensorDict.from_module(policy)) + + collector.shutdown() + + # Example 3: Multiple models (policy + value network) + # --------------------------------------------------- + value_net = TensorDictModule( + nn.Linear(env.observation_spec["observation"].shape[-1], 1), + in_keys=["observation"], + out_keys=["value"], + ) + + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), + "value": MultiProcessWeightSyncScheme(strategy="state_dict"), + } + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=1000, + weight_sync_schemes=weight_sync_schemes, + ) + + # Update multiple models independently + collector.update_policy_weights_( + {"policy": policy.state_dict(), "value": value_net.state_dict()} + ) + + collector.shutdown() + + # Example 4: Disable weight synchronization + # ------------------------------------------ + # Useful for debugging or when using a shared policy reference + no_sync_scheme = NoWeightSyncScheme() + + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=1000, + weight_sync_schemes={"policy": no_sync_scheme}, + ) + +.. note:: + When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all + processes share the same memory buffers. This is ideal for frequent weight updates but requires all + processes to be on the same machine. + +.. 
note:: + The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state + dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured + models and supports advanced features like lazy initialization. + +Weight Senders +~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.weight_update + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightSender + RayModuleTransformSender + +Weight Receivers +~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.weight_update + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightReceiver + RayModuleTransformReceiver + +Transports +~~~~~~~~~~ + +.. currentmodule:: torchrl.weight_update + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + TransportBackend + MPTransport + SharedMemTransport + RayTransport + RayActorTransport + RPCTransport + DistributedTransport + +Schemes +~~~~~~~ + +.. currentmodule:: torchrl.weight_update + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + WeightSyncScheme + MultiProcessWeightSyncScheme + SharedMemWeightSyncScheme + NoWeightSyncScheme + RayWeightSyncScheme + RayModuleTransformScheme + RPCWeightSyncScheme + DistributedWeightSyncScheme + +Legacy: Weight Synchronization in Distributed Environments +---------------------------------------------------------- + +.. warning:: + The `WeightUpdater` is considered legacy as per the 0.11 release and will be deprecated soon. + The Weight update schemes, which provides more flexibility and a better compatibility with heavy + weight transfers (e.g., LLMs) is to be preferred. + In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py new file mode 100644 index 00000000000..d18365e9725 --- /dev/null +++ b/examples/collectors/weight_sync_collectors.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Weight Synchronization Schemes - Collector Integration +======================================================= + +This example demonstrates how to use weight synchronization schemes with TorchRL +collectors for efficient weight updates across multiple inference workers. + +The examples show different synchronization strategies and use cases including +single collectors, multiple collectors, multiple models, and no synchronization. 
+""" + +import torch +import torch.nn as nn +from tensordict import TensorDict +from tensordict.nn import TensorDictModule +from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector +from torchrl.envs import GymEnv +from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + NoWeightSyncScheme, +) + + +def example_single_collector_multiprocess(): + """Example 1: Single collector with multiprocess scheme.""" + print("\n" + "="*70) + print("Example 1: Single Collector with Multiprocess Scheme") + print("="*70) + + # Create environment and policy + env = GymEnv("CartPole-v1") + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], + env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + env.close() + + # Create weight sync scheme + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + print("Creating collector with multiprocess weight sync...") + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=200, + weight_sync_schemes={"policy": scheme}, + ) + + # Collect data and update weights periodically + print("Collecting data...") + for i, data in enumerate(collector): + print(f"Iteration {i}: Collected {data.numel()} transitions") + + # Update policy weights every 2 iterations + if i % 2 == 0: + new_weights = policy.state_dict() + collector.update_policy_weights_(new_weights) + print(f" → Updated policy weights") + + if i >= 2: # Just run a few iterations for demo + break + + collector.shutdown() + print("✓ Single collector example completed!\n") + + +def example_multi_collector_shared_memory(): + """Example 2: Multiple collectors with shared memory.""" + print("\n" + "="*70) + print("Example 2: Multiple Collectors with Shared Memory") + print("="*70) + + # Create environment and policy + env = GymEnv("CartPole-v1") + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], + env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + env.close() + + # Shared memory is more efficient for frequent updates + scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + print("Creating multi-collector with shared memory...") + collector = MultiSyncDataCollector( + create_env_fn=[ + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + lambda: GymEnv("CartPole-v1"), + ], + policy=policy, + frames_per_batch=192, + total_frames=400, + weight_sync_schemes={"policy": scheme}, + ) + + # Workers automatically see weight updates via shared memory + print("Collecting data...") + for i, data in enumerate(collector): + print(f"Iteration {i}: Collected {data.numel()} transitions") + + # Update weights frequently (shared memory makes this very fast) + collector.update_policy_weights_(TensorDict.from_module(policy)) + print(f" → Updated policy weights via shared memory") + + if i >= 1: # Just run a couple iterations for demo + break + + collector.shutdown() + print("✓ Multi-collector with shared memory example completed!\n") + + +def example_multiple_models(): + """Example 3: Multiple models (policy + value network).""" + print("\n" + "="*70) + print("Example 3: Multiple Models (Policy + Value Network)") + print("="*70) + + # Create environment + env = GymEnv("CartPole-v1") + + # Create policy and value network + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], + 
env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + + value_net = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], + 1 + ), + in_keys=["observation"], + out_keys=["value"], + ) + env.close() + + # Create separate schemes for each model + weight_sync_schemes = { + "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), + "value": MultiProcessWeightSyncScheme(strategy="state_dict"), + } + + print("Creating collector with multiple models...") + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=200, + weight_sync_schemes=weight_sync_schemes, + ) + + print("Collecting data...") + for i, data in enumerate(collector): + print(f"Iteration {i}: Collected {data.numel()} transitions") + + # Update both models independently + collector.update_policy_weights_( + { + "policy": policy.state_dict(), + "value": value_net.state_dict() + } + ) + print(f" → Updated both policy and value network weights") + + if i >= 1: + break + + collector.shutdown() + print("✓ Multiple models example completed!\n") + + +def example_no_weight_sync(): + """Example 4: Disable weight synchronization.""" + print("\n" + "="*70) + print("Example 4: Disable Weight Synchronization") + print("="*70) + + # Create environment and policy + env = GymEnv("CartPole-v1") + policy = TensorDictModule( + nn.Linear( + env.observation_spec["observation"].shape[-1], + env.action_spec.shape[-1] + ), + in_keys=["observation"], + out_keys=["action"], + ) + env.close() + + # Useful for debugging or when using a shared policy reference + scheme = NoWeightSyncScheme() + + print("Creating collector with no weight synchronization...") + collector = SyncDataCollector( + create_env_fn=lambda: GymEnv("CartPole-v1"), + policy=policy, + frames_per_batch=64, + total_frames=200, + weight_sync_schemes={"policy": scheme}, + ) + + print("Collecting data (no weight updates)...") + for i, data in enumerate(collector): + print(f"Iteration {i}: Collected {data.numel()} transitions") + + # Weight updates are no-ops with NoWeightSyncScheme + collector.update_policy_weights_(policy.state_dict()) + print(f" → Weight update call was a no-op") + + if i >= 1: + break + + collector.shutdown() + print("✓ No weight sync example completed!\n") + + +def main(): + """Run all examples.""" + print("\n" + "="*70) + print("Weight Synchronization Schemes - Collector Integration Examples") + print("="*70) + + # Set multiprocessing start method + import torch.multiprocessing as mp + try: + mp.set_start_method('spawn') + except RuntimeError: + pass # Already set + + # Run examples + example_single_collector_multiprocess() + example_multi_collector_shared_memory() + example_multiple_models() + example_no_weight_sync() + + print("\n" + "="*70) + print("All examples completed successfully!") + print("="*70) + print("\nKey takeaways:") + print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios") + print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers") + print(" • Multiple models: Each model can have its own sync scheme") + print(" • NoWeightSyncScheme: Useful for debugging or shared policy references") + print("="*70 + "\n") + + +if __name__ == "__main__": + main() + diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py new file mode 100644 index 00000000000..83492256412 --- /dev/null +++ 
b/examples/collectors/weight_sync_standalone.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Weight Synchronization Schemes - Standalone Usage +================================================== + +This example demonstrates how to use weight synchronization schemes independently +of collectors for custom synchronization scenarios. + +The weight synchronization infrastructure provides flexible sender/receiver patterns +that can be used for various multiprocessing scenarios. +""" + +import torch +import torch.nn as nn +from torch import multiprocessing as mp +from tensordict import TensorDict +from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, +) + + +def worker_process_mp(child_pipe, model_state): + """Worker process that receives weights via multiprocessing pipe.""" + print("Worker: Starting...") + + # Create a policy on the worker side + policy = nn.Linear(4, 2) + with torch.no_grad(): + policy.weight.fill_(0.0) + policy.bias.fill_(0.0) + + # Create receiver and register the policy + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + receiver = scheme.create_receiver() + receiver.register_model(policy) + receiver.register_worker_transport(child_pipe) + + print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}") + + # Receive and apply weights + result = receiver._transport.receive_weights(timeout=5.0) + if result is not None: + model_id, weights = result + receiver.apply_weights(weights) + print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}") + else: + print("Worker: No weights received") + + # Store final state for verification + model_state['weight_sum'] = policy.weight.sum().item() + model_state['bias_sum'] = policy.bias.sum().item() + + +def worker_process_shared_mem(child_pipe, model_state): + """Worker process that receives shared memory buffer reference.""" + print("SharedMem Worker: Starting...") + + # Create a policy on the worker side + policy = nn.Linear(4, 2) + + # Wait for shared memory buffer registration + if child_pipe.poll(timeout=10.0): + data, msg = child_pipe.recv() + if msg == "register_shared_weights": + model_id, shared_weights = data + print(f"SharedMem Worker: Received shared buffer for model '{model_id}'") + # Apply shared weights to policy + shared_weights.to_module(policy) + # Send acknowledgment + child_pipe.send((None, "registered")) + + # Small delay to ensure main process updates shared memory + import time + time.sleep(0.5) + + print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}") + + # Store final state for verification + model_state['weight_sum'] = policy.weight.sum().item() + model_state['bias_sum'] = policy.bias.sum().item() + + +def example_multiprocess_sync(): + """Example 1: Multiprocess weight synchronization with state_dict.""" + print("\n" + "="*70) + print("Example 1: Multiprocess Weight Synchronization") + print("="*70) + + # Create a simple policy on main process + policy = nn.Linear(4, 2) + with torch.no_grad(): + policy.weight.fill_(1.0) + policy.bias.fill_(0.5) + + print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}") + + # Create scheme and sender + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + sender = scheme.create_sender() + + # Create pipe for communication + parent_pipe, child_pipe = mp.Pipe() + 
sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) + + # Start worker process + manager = mp.Manager() + model_state = manager.dict() + process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state)) + process.start() + + # Send weights to worker + weights = policy.state_dict() + print("Main: Sending weights to worker...") + sender.update_weights(weights) + + # Wait for worker to complete + process.join(timeout=10.0) + + if process.is_alive(): + print("Warning: Worker process did not terminate in time") + process.terminate() + else: + print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}") + print(f"✓ Weight synchronization successful!") + + +def example_shared_memory_sync(): + """Example 2: Shared memory weight synchronization.""" + print("\n" + "="*70) + print("Example 2: Shared Memory Weight Synchronization") + print("="*70) + + # Create a simple policy + policy = nn.Linear(4, 2) + + # Create shared memory scheme with auto-registration + scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + sender = scheme.create_sender() + + # Create pipe for lazy registration + parent_pipe, child_pipe = mp.Pipe() + sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) + + # Start worker process + manager = mp.Manager() + model_state = manager.dict() + process = mp.Process(target=worker_process_shared_mem, args=(child_pipe, model_state)) + process.start() + + # Send weights (automatically creates shared buffer on first send) + weights_td = TensorDict.from_module(policy) + with torch.no_grad(): + weights_td["weight"].fill_(2.0) + weights_td["bias"].fill_(1.0) + + print(f"Main: Sending weights via shared memory...") + sender.update_weights(weights_td) + + # Workers automatically see updates via shared memory! + print("Main: Weights are now in shared memory, workers can access them") + + # Wait for worker to complete + process.join(timeout=10.0) + + if process.is_alive(): + print("Warning: Worker process did not terminate in time") + process.terminate() + else: + print(f"Main: Worker completed. 
Worker's weight sum: {model_state['weight_sum']:.4f}") + print(f"✓ Shared memory synchronization successful!") + + +def main(): + """Run all examples.""" + print("\n" + "="*70) + print("Weight Synchronization Schemes - Standalone Usage Examples") + print("="*70) + + # Set multiprocessing start method + try: + mp.set_start_method('spawn') + except RuntimeError: + pass # Already set + + # Run examples + example_multiprocess_sync() + example_shared_memory_sync() + + print("\n" + "="*70) + print("All examples completed successfully!") + print("="*70 + "\n") + + +if __name__ == "__main__": + main() + From 5fb1d9a32f1318a72400306b0e6990d984a0cdf5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 18 Oct 2025 11:04:55 -0700 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- docs/source/reference/collectors.rst | 42 ------- examples/collectors/weight_sync_collectors.py | 119 +----------------- versions.html | 51 ++++++++ 3 files changed, 53 insertions(+), 159 deletions(-) create mode 100755 versions.html diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index aadeb6644d5..2431eceee35 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -233,7 +233,6 @@ across multiple inference workers: from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, - NoWeightSyncScheme, ) # Create environment and policy @@ -292,47 +291,6 @@ across multiple inference workers: collector.shutdown() - # Example 3: Multiple models (policy + value network) - # --------------------------------------------------- - value_net = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), - "value": MultiProcessWeightSyncScheme(strategy="state_dict"), - } - - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=1000, - weight_sync_schemes=weight_sync_schemes, - ) - - # Update multiple models independently - collector.update_policy_weights_( - {"policy": policy.state_dict(), "value": value_net.state_dict()} - ) - - collector.shutdown() - - # Example 4: Disable weight synchronization - # ------------------------------------------ - # Useful for debugging or when using a shared policy reference - no_sync_scheme = NoWeightSyncScheme() - - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=1000, - weight_sync_schemes={"policy": no_sync_scheme}, - ) - .. note:: When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all processes share the same memory buffers. This is ideal for frequent weight updates but requires all diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py index d18365e9725..fbb1a8a1166 100644 --- a/examples/collectors/weight_sync_collectors.py +++ b/examples/collectors/weight_sync_collectors.py @@ -14,7 +14,6 @@ single collectors, multiple collectors, multiple models, and no synchronization. 
""" -import torch import torch.nn as nn from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -23,7 +22,6 @@ from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, - NoWeightSyncScheme, ) @@ -66,7 +64,7 @@ def example_single_collector_multiprocess(): if i % 2 == 0: new_weights = policy.state_dict() collector.update_policy_weights_(new_weights) - print(f" → Updated policy weights") + print(" → Updated policy weights") if i >= 2: # Just run a few iterations for demo break @@ -116,7 +114,7 @@ def example_multi_collector_shared_memory(): # Update weights frequently (shared memory makes this very fast) collector.update_policy_weights_(TensorDict.from_module(policy)) - print(f" → Updated policy weights via shared memory") + print(" → Updated policy weights via shared memory") if i >= 1: # Just run a couple iterations for demo break @@ -125,115 +123,6 @@ def example_multi_collector_shared_memory(): print("✓ Multi-collector with shared memory example completed!\n") -def example_multiple_models(): - """Example 3: Multiple models (policy + value network).""" - print("\n" + "="*70) - print("Example 3: Multiple Models (Policy + Value Network)") - print("="*70) - - # Create environment - env = GymEnv("CartPole-v1") - - # Create policy and value network - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value_net = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], - 1 - ), - in_keys=["observation"], - out_keys=["value"], - ) - env.close() - - # Create separate schemes for each model - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), - "value": MultiProcessWeightSyncScheme(strategy="state_dict"), - } - - print("Creating collector with multiple models...") - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=200, - weight_sync_schemes=weight_sync_schemes, - ) - - print("Collecting data...") - for i, data in enumerate(collector): - print(f"Iteration {i}: Collected {data.numel()} transitions") - - # Update both models independently - collector.update_policy_weights_( - { - "policy": policy.state_dict(), - "value": value_net.state_dict() - } - ) - print(f" → Updated both policy and value network weights") - - if i >= 1: - break - - collector.shutdown() - print("✓ Multiple models example completed!\n") - - -def example_no_weight_sync(): - """Example 4: Disable weight synchronization.""" - print("\n" + "="*70) - print("Example 4: Disable Weight Synchronization") - print("="*70) - - # Create environment and policy - env = GymEnv("CartPole-v1") - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - env.close() - - # Useful for debugging or when using a shared policy reference - scheme = NoWeightSyncScheme() - - print("Creating collector with no weight synchronization...") - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=200, - weight_sync_schemes={"policy": scheme}, - ) - - print("Collecting data (no weight updates)...") - for i, data in enumerate(collector): - print(f"Iteration {i}: Collected {data.numel()} transitions") - - # Weight updates are no-ops with 
NoWeightSyncScheme - collector.update_policy_weights_(policy.state_dict()) - print(f" → Weight update call was a no-op") - - if i >= 1: - break - - collector.shutdown() - print("✓ No weight sync example completed!\n") - - def main(): """Run all examples.""" print("\n" + "="*70) @@ -250,8 +139,6 @@ def main(): # Run examples example_single_collector_multiprocess() example_multi_collector_shared_memory() - example_multiple_models() - example_no_weight_sync() print("\n" + "="*70) print("All examples completed successfully!") @@ -259,8 +146,6 @@ def main(): print("\nKey takeaways:") print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios") print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers") - print(" • Multiple models: Each model can have its own sync scheme") - print(" • NoWeightSyncScheme: Useful for debugging or shared policy references") print("="*70 + "\n") diff --git a/versions.html b/versions.html new file mode 100755 index 00000000000..23e9c8f61ed --- /dev/null +++ b/versions.html @@ -0,0 +1,51 @@ + + + + + + + + + + + + + +
+    [versions.html: "PyTorch Documentation" / "Pick a version" / "You can view previous versions of the torchrl documentation here."]
+ + From 64ab805be8482d1fb8c3bc3ff83706c26805698f Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 18 Oct 2025 19:44:38 -0700 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- examples/collectors/weight_sync_collectors.py | 68 +++++------ examples/collectors/weight_sync_standalone.py | 102 ++++++++-------- test/test_weightsync.py | 114 +----------------- torchrl/weight_update/weight_sync_schemes.py | 4 +- 4 files changed, 91 insertions(+), 197 deletions(-) diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py index fbb1a8a1166..a3962966c8c 100644 --- a/examples/collectors/weight_sync_collectors.py +++ b/examples/collectors/weight_sync_collectors.py @@ -17,7 +17,7 @@ import torch.nn as nn from tensordict import TensorDict from tensordict.nn import TensorDictModule -from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector +from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.envs import GymEnv from torchrl.weight_update import ( MultiProcessWeightSyncScheme, @@ -27,25 +27,24 @@ def example_single_collector_multiprocess(): """Example 1: Single collector with multiprocess scheme.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 1: Single Collector with Multiprocess Scheme") - print("="*70) - + print("=" * 70) + # Create environment and policy env = GymEnv("CartPole-v1") policy = TensorDictModule( nn.Linear( - env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1] + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] ), in_keys=["observation"], out_keys=["action"], ) env.close() - + # Create weight sync scheme scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - + print("Creating collector with multiprocess weight sync...") collector = SyncDataCollector( create_env_fn=lambda: GymEnv("CartPole-v1"), @@ -54,46 +53,45 @@ def example_single_collector_multiprocess(): total_frames=200, weight_sync_schemes={"policy": scheme}, ) - + # Collect data and update weights periodically print("Collecting data...") for i, data in enumerate(collector): print(f"Iteration {i}: Collected {data.numel()} transitions") - + # Update policy weights every 2 iterations if i % 2 == 0: new_weights = policy.state_dict() collector.update_policy_weights_(new_weights) print(" → Updated policy weights") - + if i >= 2: # Just run a few iterations for demo break - + collector.shutdown() print("✓ Single collector example completed!\n") def example_multi_collector_shared_memory(): """Example 2: Multiple collectors with shared memory.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 2: Multiple Collectors with Shared Memory") - print("="*70) - + print("=" * 70) + # Create environment and policy env = GymEnv("CartPole-v1") policy = TensorDictModule( nn.Linear( - env.observation_spec["observation"].shape[-1], - env.action_spec.shape[-1] + env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] ), in_keys=["observation"], out_keys=["action"], ) env.close() - + # Shared memory is more efficient for frequent updates scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - + print("Creating multi-collector with shared memory...") collector = MultiSyncDataCollector( create_env_fn=[ @@ -106,49 +104,51 @@ def example_multi_collector_shared_memory(): total_frames=400, weight_sync_schemes={"policy": scheme}, ) - + # Workers automatically see weight updates via shared memory print("Collecting data...") for i, data in 
enumerate(collector): print(f"Iteration {i}: Collected {data.numel()} transitions") - + # Update weights frequently (shared memory makes this very fast) collector.update_policy_weights_(TensorDict.from_module(policy)) print(" → Updated policy weights via shared memory") - + if i >= 1: # Just run a couple iterations for demo break - + collector.shutdown() print("✓ Multi-collector with shared memory example completed!\n") def main(): """Run all examples.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Weight Synchronization Schemes - Collector Integration Examples") - print("="*70) - + print("=" * 70) + # Set multiprocessing start method import torch.multiprocessing as mp + try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except RuntimeError: pass # Already set - + # Run examples example_single_collector_multiprocess() example_multi_collector_shared_memory() - - print("\n" + "="*70) + + print("\n" + "=" * 70) print("All examples completed successfully!") - print("="*70) + print("=" * 70) print("\nKey takeaways:") print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios") - print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers") - print("="*70 + "\n") + print( + " • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers" + ) + print("=" * 70 + "\n") if __name__ == "__main__": main() - diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py index 83492256412..455dca1431c 100644 --- a/examples/collectors/weight_sync_standalone.py +++ b/examples/collectors/weight_sync_standalone.py @@ -16,8 +16,8 @@ import torch import torch.nn as nn -from torch import multiprocessing as mp from tensordict import TensorDict +from torch import multiprocessing as mp from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, @@ -27,21 +27,21 @@ def worker_process_mp(child_pipe, model_state): """Worker process that receives weights via multiprocessing pipe.""" print("Worker: Starting...") - + # Create a policy on the worker side policy = nn.Linear(4, 2) with torch.no_grad(): policy.weight.fill_(0.0) policy.bias.fill_(0.0) - + # Create receiver and register the policy scheme = MultiProcessWeightSyncScheme(strategy="state_dict") receiver = scheme.create_receiver() receiver.register_model(policy) receiver.register_worker_transport(child_pipe) - + print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}") - + # Receive and apply weights result = receiver._transport.receive_weights(timeout=5.0) if result is not None: @@ -50,19 +50,19 @@ def worker_process_mp(child_pipe, model_state): print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}") else: print("Worker: No weights received") - + # Store final state for verification - model_state['weight_sum'] = policy.weight.sum().item() - model_state['bias_sum'] = policy.bias.sum().item() + model_state["weight_sum"] = policy.weight.sum().item() + model_state["bias_sum"] = policy.bias.sum().item() def worker_process_shared_mem(child_pipe, model_state): """Worker process that receives shared memory buffer reference.""" print("SharedMem Worker: Starting...") - + # Create a policy on the worker side policy = nn.Linear(4, 2) - + # Wait for shared memory buffer registration if child_pipe.poll(timeout=10.0): data, msg = child_pipe.recv() @@ -73,129 +73,135 @@ def worker_process_shared_mem(child_pipe, model_state): shared_weights.to_module(policy) # Send 
acknowledgment child_pipe.send((None, "registered")) - + # Small delay to ensure main process updates shared memory import time + time.sleep(0.5) - + print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}") - + # Store final state for verification - model_state['weight_sum'] = policy.weight.sum().item() - model_state['bias_sum'] = policy.bias.sum().item() + model_state["weight_sum"] = policy.weight.sum().item() + model_state["bias_sum"] = policy.bias.sum().item() def example_multiprocess_sync(): """Example 1: Multiprocess weight synchronization with state_dict.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 1: Multiprocess Weight Synchronization") - print("="*70) - + print("=" * 70) + # Create a simple policy on main process policy = nn.Linear(4, 2) with torch.no_grad(): policy.weight.fill_(1.0) policy.bias.fill_(0.5) - + print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}") - + # Create scheme and sender scheme = MultiProcessWeightSyncScheme(strategy="state_dict") sender = scheme.create_sender() - + # Create pipe for communication parent_pipe, child_pipe = mp.Pipe() sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - + # Start worker process manager = mp.Manager() model_state = manager.dict() process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state)) process.start() - + # Send weights to worker weights = policy.state_dict() print("Main: Sending weights to worker...") sender.update_weights(weights) - + # Wait for worker to complete process.join(timeout=10.0) - + if process.is_alive(): print("Warning: Worker process did not terminate in time") process.terminate() else: - print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}") + print( + f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}" + ) print(f"✓ Weight synchronization successful!") def example_shared_memory_sync(): """Example 2: Shared memory weight synchronization.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Example 2: Shared Memory Weight Synchronization") - print("="*70) - + print("=" * 70) + # Create a simple policy policy = nn.Linear(4, 2) - + # Create shared memory scheme with auto-registration scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) sender = scheme.create_sender() - + # Create pipe for lazy registration parent_pipe, child_pipe = mp.Pipe() sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - + # Start worker process manager = mp.Manager() model_state = manager.dict() - process = mp.Process(target=worker_process_shared_mem, args=(child_pipe, model_state)) + process = mp.Process( + target=worker_process_shared_mem, args=(child_pipe, model_state) + ) process.start() - + # Send weights (automatically creates shared buffer on first send) weights_td = TensorDict.from_module(policy) with torch.no_grad(): weights_td["weight"].fill_(2.0) weights_td["bias"].fill_(1.0) - + print(f"Main: Sending weights via shared memory...") sender.update_weights(weights_td) - + # Workers automatically see updates via shared memory! print("Main: Weights are now in shared memory, workers can access them") - + # Wait for worker to complete process.join(timeout=10.0) - + if process.is_alive(): print("Warning: Worker process did not terminate in time") process.terminate() else: - print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}") + print( + f"Main: Worker completed. 
Worker's weight sum: {model_state['weight_sum']:.4f}" + ) print(f"✓ Shared memory synchronization successful!") def main(): """Run all examples.""" - print("\n" + "="*70) + print("\n" + "=" * 70) print("Weight Synchronization Schemes - Standalone Usage Examples") - print("="*70) - + print("=" * 70) + # Set multiprocessing start method try: - mp.set_start_method('spawn') + mp.set_start_method("spawn") except RuntimeError: pass # Already set - + # Run examples example_multiprocess_sync() example_shared_memory_sync() - - print("\n" + "="*70) + + print("\n" + "=" * 70) print("All examples completed successfully!") - print("="*70 + "\n") + print("=" * 70 + "\n") if __name__ == "__main__": main() - diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 91203c209f0..8f28d82a48a 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -12,7 +12,7 @@ from tensordict import TensorDict from tensordict.nn import TensorDictModule from torch import multiprocessing as mp -from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.envs import GymEnv from torchrl.weight_update.weight_sync_schemes import ( @@ -272,118 +272,6 @@ def test_no_weight_sync_scheme(self): transport.send_weights("policy", weights) -class TestCollectorIntegration: - @pytest.fixture - def simple_env(self): - return GymEnv("CartPole-v1") - - @pytest.fixture - def simple_policy(self, simple_env): - return TensorDictModule( - nn.Linear( - simple_env.observation_spec["observation"].shape[-1], - simple_env.action_spec.shape[-1], - ), - in_keys=["observation"], - out_keys=["action"], - ) - - def test_syncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) - - collector.update_policy_weights_(new_weights) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - collector = MultiSyncDataCollector( - create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) - - collector.update_policy_weights_(new_weights) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - collector = MultiSyncDataCollector( - create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - new_weights = TensorDict.from_module(simple_policy) - with torch.no_grad(): - new_weights["module"]["weight"].fill_(1.0) - new_weights["module"]["bias"].fill_(1.0) - - collector.update_policy_weights_(new_weights) - - for data in collector: - 
assert data.numel() > 0 - break - - collector.shutdown() - - def test_collector_no_weight_sync(self, simple_policy): - scheme = NoWeightSyncScheme() - - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - class TestMultiModelUpdates: def test_multi_model_state_dict_updates(self): env = GymEnv("CartPole-v1") diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 763753896b2..244b7c204f4 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -1313,7 +1313,7 @@ def _resolve_model(context: Any, model_id: str) -> Any: obj = getattr(obj, key) except AttributeError: raise AttributeError( - f"Attribute {key} from {parts[:i+1]} not found in {'.'.join(parts[:i])}={obj}" + f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" ) for index in indices: obj = obj[index] @@ -1322,6 +1322,6 @@ def _resolve_model(context: Any, model_id: str) -> Any: obj = getattr(obj, part) except AttributeError: raise AttributeError( - f"Attribute {part} from {parts[:i+1]} not found in {'.'.join(parts[:i])}={obj}" + f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" ) return obj From 3ce8cdc01a7d15715f359d327050bea1242456db Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 19 Oct 2025 15:49:16 -0700 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- test/test_weightsync.py | 108 ---------------------------------------- 1 file changed, 108 deletions(-) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 8f28d82a48a..0e48d564fb0 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -272,114 +272,6 @@ def test_no_weight_sync_scheme(self): transport.send_weights("policy", weights) -class TestMultiModelUpdates: - def test_multi_model_state_dict_updates(self): - env = GymEnv("CartPole-v1") - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), - "value": MultiProcessWeightSyncScheme(strategy="state_dict"), - } - - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = policy.state_dict() - value_weights = value.state_dict() - - with torch.no_grad(): - for key in policy_weights: - policy_weights[key].fill_(1.0) - for key in value_weights: - value_weights[key].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - def test_multi_model_tensordict_updates(self): - env = GymEnv("CartPole-v1") - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - 
nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), - "value": MultiProcessWeightSyncScheme(strategy="tensordict"), - } - - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = TensorDict.from_module(policy) - value_weights = TensorDict.from_module(value) - - with torch.no_grad(): - policy_weights["module"]["weight"].fill_(1.0) - policy_weights["module"]["bias"].fill_(1.0) - value_weights["module"]["weight"].fill_(2.0) - value_weights["module"]["bias"].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - class TestHelpers: def test_resolve_model_simple(self): class Context: From 0516176595988bb440efcb267fe9aa5d283b6f23 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 19 Oct 2025 17:03:22 -0700 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- examples/collectors/weight_sync_standalone.py | 6 +++--- test/test_weightsync.py | 3 --- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py index 455dca1431c..2d918cb10a2 100644 --- a/examples/collectors/weight_sync_standalone.py +++ b/examples/collectors/weight_sync_standalone.py @@ -129,7 +129,7 @@ def example_multiprocess_sync(): print( f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}" ) - print(f"✓ Weight synchronization successful!") + print("Weight synchronization successful!") def example_shared_memory_sync(): @@ -163,7 +163,7 @@ def example_shared_memory_sync(): weights_td["weight"].fill_(2.0) weights_td["bias"].fill_(1.0) - print(f"Main: Sending weights via shared memory...") + print("Main: Sending weights via shared memory...") sender.update_weights(weights_td) # Workers automatically see updates via shared memory! @@ -179,7 +179,7 @@ def example_shared_memory_sync(): print( f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}" ) - print(f"✓ Shared memory synchronization successful!") + print("Shared memory synchronization successful!") def main(): diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 0e48d564fb0..9c2d2025087 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -10,10 +10,7 @@ import torch import torch.nn as nn from tensordict import TensorDict -from tensordict.nn import TensorDictModule from torch import multiprocessing as mp -from torchrl.collectors import SyncDataCollector -from torchrl.envs import GymEnv from torchrl.weight_update.weight_sync_schemes import ( _resolve_model,