diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index 2431eceee35..315e17e082f 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -162,7 +162,15 @@ Usage Examples Using Weight Update Schemes Independently ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Weight update schemes can be used outside of collectors for custom synchronization scenarios. Here's a basic example: +Weight update schemes can be used outside of collectors for custom synchronization scenarios. +The new simplified API provides four core methods for weight synchronization: + +- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side +- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side +- ``get_sender()`` - Get the configured sender instance +- ``get_receiver()`` - Get the configured receiver instance + +Here's a basic example: .. code-block:: python @@ -182,39 +190,37 @@ Weight update schemes can be used outside of collectors for custom synchronizati # -------------------------------------------------------------- # On the main process side (trainer): scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - sender = scheme.create_sender() - - # Register worker pipes + + # Initialize scheme with pipes parent_pipe, child_pipe = mp.Pipe() - sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - - # Send weights to workers + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + + # Get the sender and send weights + sender = scheme.get_sender() weights = policy.state_dict() - sender.update_weights(weights) + sender.send(weights) # Synchronous send + # or sender.send_async(weights); sender.wait_async() # Asynchronous send # On the worker process side: - # receiver = scheme.create_receiver() - # receiver.register_model(policy) - # receiver.register_worker_transport(child_pipe) - # # Receive and apply weights - # result = receiver._transport.receive_weights(timeout=5.0) - # if result is not None: - # model_id, weights = result - # receiver.apply_weights(weights) + # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) + # receiver = scheme.get_receiver() + # # Non-blocking check for new weights + # if receiver.receive(timeout=0.001): + # # Weights were received and applied # Example 2: Shared memory weight synchronization # ------------------------------------------------ # Create shared memory scheme with auto-registration shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - shared_sender = shared_scheme.create_sender() - - # Register worker pipes for lazy registration + + # Initialize with pipes for lazy registration parent_pipe2, child_pipe2 = mp.Pipe() - shared_sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe2) - - # Send weights (automatically creates shared buffer on first send) + shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) + + # Get sender and send weights (automatically creates shared buffer on first send) + shared_sender = shared_scheme.get_sender() weights_td = TensorDict.from_module(policy) - shared_sender.update_weights(weights_td) + shared_sender.send(weights_td) # Workers automatically see updates via shared memory! 
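A minimal sketch of the asynchronous path mentioned in Example 1's comment, reusing ``scheme``, ``sender`` and ``policy`` from that example (``do_other_work`` is a hypothetical placeholder):

.. code-block:: python

    # send_async() dispatches to all registered workers and returns immediately;
    # the transfer overlaps with whatever runs next.
    sender.send_async(policy.state_dict())

    do_other_work()  # hypothetical: e.g. an optimizer step on the trainer side

    # wait_async() blocks until every worker has acknowledged the update.
    # It must be called before the next send()/send_async(), which would
    # otherwise raise a RuntimeError while a send is still pending.
    sender.wait_async()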
diff --git a/test/test_weightsync.py b/test/test_weightsync.py index f5a4515f224..2ccd4308ccf 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -5,6 +5,9 @@ from __future__ import annotations import argparse +import importlib.util +import pickle +import time import pytest import torch @@ -16,14 +19,20 @@ from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.weight_update.weight_sync_schemes import ( _resolve_model, + DistributedWeightSyncScheme, MPTransport, MultiProcessWeightSyncScheme, NoWeightSyncScheme, + RayModuleTransformScheme, + RayWeightSyncScheme, + RPCWeightSyncScheme, SharedMemTransport, SharedMemWeightSyncScheme, WeightStrategy, ) +_has_ray = importlib.util.find_spec("ray") is not None + def worker_update_policy(pipe, timeout=5.0): policy = nn.Linear(4, 2) @@ -32,9 +41,8 @@ def worker_update_policy(pipe, timeout=5.0): policy.bias.fill_(0.0) scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - receiver = scheme.create_receiver() - receiver.register_model(policy) - receiver.register_worker_transport(pipe) + scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + receiver = scheme.get_receiver() if receiver._transport.pipe.poll(timeout): data, msg = receiver._transport.pipe.recv() @@ -52,9 +60,8 @@ def worker_update_policy_tensordict(pipe, timeout=5.0): policy.bias.fill_(0.0) scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - receiver = scheme.create_receiver() - receiver.register_model(policy) - receiver.register_worker_transport(pipe) + scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + receiver = scheme.get_receiver() if receiver._transport.pipe.poll(timeout): data, msg = receiver._transport.pipe.recv() @@ -75,8 +82,6 @@ def worker_shared_mem(pipe, timeout=10.0): shared_weights.to_module(policy) pipe.send((None, "registered")) - import time - time.sleep(0.5) return policy.weight.sum().item(), policy.bias.sum().item() @@ -192,39 +197,46 @@ def test_cross_format_conversion(self): class TestWeightSyncSchemes: + """Tests for weight sync schemes using the new simplified API. + + Lower-level transport and legacy API tests are in TestTransportBackends. 
+ """ + def test_multiprocess_scheme_state_dict(self): parent_pipe, child_pipe = mp.Pipe() scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - sender = scheme.create_sender() - sender.register_worker(0, parent_pipe) + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + sender = scheme.get_sender() proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - proc.start() + try: + proc.start() - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - sender.update_weights(weights) - - proc.join(timeout=10.0) - assert not proc.is_alive() + weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} + sender.send(weights) + finally: + proc.join(timeout=10.0) + assert not proc.is_alive() def test_multiprocess_scheme_tensordict(self): parent_pipe, child_pipe = mp.Pipe() scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - sender = scheme.create_sender() - sender.register_worker(0, parent_pipe) + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + sender = scheme.get_sender() proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,)) - proc.start() - - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - sender.update_weights(weights) + try: + proc.start() - proc.join(timeout=10.0) - assert not proc.is_alive() + weights = TensorDict( + {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] + ) + sender.send(weights) + finally: + proc.join(timeout=10.0) + assert not proc.is_alive() def test_shared_mem_scheme(self): shared_buffer = TensorDict( @@ -270,6 +282,51 @@ def test_no_weight_sync_scheme(self): weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} transport.send_weights("policy", weights) + @classmethod + def _worker_with_receive(cls, pipe, scheme): + policy = nn.Linear(4, 2) + with torch.no_grad(): + policy.weight.fill_(0.0) + policy.bias.fill_(0.0) + + scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + receiver = scheme.get_receiver() + + # Non-blocking receive should return False when no data + result = receiver.receive(timeout=0.001) + assert result is False + + # Now actually receive the weights + result = receiver.receive(timeout=5.0) + assert result is True + + # Check weights were applied + return policy.weight.sum().item(), policy.bias.sum().item() + + def test_receiver_receive_method(self): + """Test the new non-blocking receive() method.""" + + parent_pipe, child_pipe = mp.Pipe() + + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + sender = scheme.get_sender() + + proc = mp.Process(target=self._worker_with_receive, args=(child_pipe, scheme)) + try: + proc.start() + + # Give worker time to call receive with no data + + time.sleep(0.1) + + weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} + sender.send(weights) + + finally: + proc.join(timeout=10.0) + assert not proc.is_alive() + class TestCollectorIntegration: @pytest.fixture @@ -568,6 +625,239 @@ def test_weight_strategy_parametrized(strategy): assert torch.allclose(policy.bias, target.bias) +class TestSerializeScheme: + """Test that WeightSyncScheme instances can be serialized after initialization. + + This is critical for multiprocessing and Ray, where schemes may be pickled + and sent across process boundaries. The _sender and _receiver attributes + contain non-serializable objects (pipes, weak references, etc.) and must + be excluded from serialization. 
+ """ + + def test_multiprocess_scheme_serialize_before_init(self): + """Test that uninitialized scheme can be pickled.""" + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved + assert restored.strategy == "state_dict" + assert restored._sender is None + assert restored._receiver is None + assert not restored._initialized_on_sender + assert not restored._initialized_on_worker + + def test_multiprocess_scheme_serialize_after_sender_init(self): + """Test that initialized sender can be pickled (excluding runtime state).""" + parent_pipe, child_pipe = mp.Pipe() + + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + + # Scheme now has _sender with non-serializable pipes + assert scheme._sender is not None + assert scheme._initialized_on_sender + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved but runtime state is cleared + assert restored.strategy == "state_dict" + assert restored._sender is None # Runtime state excluded + assert restored._receiver is None + assert not restored._initialized_on_sender # Reset + assert not restored._initialized_on_worker + + # Clean up + parent_pipe.close() + child_pipe.close() + + def test_shared_mem_scheme_serialize_before_init(self): + """Test that uninitialized SharedMemWeightSyncScheme can be pickled.""" + scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved + assert restored.strategy == "tensordict" + assert restored._sender is None + assert restored._receiver is None + + def test_shared_mem_scheme_serialize_after_init(self): + """Test that initialized SharedMemWeightSyncScheme can be pickled.""" + parent_pipe, child_pipe = mp.Pipe() + + # Create shared buffer + shared_buffer = TensorDict( + {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] + ).share_memory_() + + scheme = SharedMemWeightSyncScheme( + policy_weights={"policy": shared_buffer}, + strategy="tensordict", + auto_register=False, + ) + + def init_on_sender(scheme, child_pipe): + (model_id, data), msg = child_pipe.recv() + if msg == "register_shared_weights": + child_pipe.send((None, "registered")) + else: + raise ValueError(f"Expected 'register_shared_weights' but got {msg}") + + # Initialize the scheme with the pipes, in 2 separate threads because init requires acknowledgement from the worker + import threading + + future_sender = threading.Thread( + target=scheme.init_on_sender, + kwargs={"model_id": "policy", "pipes": [parent_pipe]}, + ) + future_receiver = threading.Thread( + target=init_on_sender, + kwargs={"scheme": scheme, "child_pipe": child_pipe}, + ) + future_receiver.start() + future_sender.start() + future_receiver.join() + future_sender.join() + + # Scheme now has _sender with non-serializable state + assert scheme._sender is not None + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved but runtime state is cleared + assert restored.strategy == "tensordict" + assert restored._sender is None + assert not restored._initialized_on_sender + + # Note: policy_weights dict is preserved (but may need 
re-sharing) + assert "policy" in restored.policy_weights + + # Clean up + parent_pipe.close() + child_pipe.close() + + def test_no_weight_sync_scheme_serialize(self): + """Test that NoWeightSyncScheme can be pickled.""" + scheme = NoWeightSyncScheme() + scheme.init_on_sender(model_id="policy") + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that it's still a no-op scheme + assert restored._sender is None + assert restored._receiver is None + + @pytest.mark.skipif( + not torch.distributed.is_available(), reason="torch.distributed not available" + ) + def test_distributed_scheme_serialize_before_init(self): + """Test that uninitialized DistributedWeightSyncScheme can be pickled.""" + + scheme = DistributedWeightSyncScheme(backend="gloo", sync=True) + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved + assert restored.backend == "gloo" + assert restored.sync is True + assert restored._sender is None + assert restored._receiver is None + + @pytest.mark.skipif(not _has_ray, reason="Ray not available") + def test_ray_weight_sync_scheme_serialize_before_init(self): + """Test that uninitialized RayWeightSyncScheme can be pickled.""" + scheme = RayWeightSyncScheme(strategy="state_dict") + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved + assert restored.strategy == "state_dict" + assert restored._sender is None + assert restored._receiver is None + + @pytest.mark.skipif(not _has_ray, reason="Ray not available") + def test_ray_module_transform_scheme_serialize_before_init(self): + """Test that uninitialized RayModuleTransformScheme can be pickled.""" + + scheme = RayModuleTransformScheme(strategy="tensordict") + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved + assert restored.strategy == "tensordict" + assert restored._sender is None + assert restored._receiver is None + + @pytest.mark.skipif( + not torch.distributed.is_available(), reason="torch.distributed not available" + ) + def test_rpc_weight_sync_scheme_serialize_before_init(self): + """Test that uninitialized RPCWeightSyncScheme can be pickled.""" + + scheme = RPCWeightSyncScheme(strategy="state_dict") + + # Serialize and deserialize + pickled = pickle.dumps(scheme) + restored = pickle.loads(pickled) + + # Check that configuration is preserved + assert restored.strategy == "state_dict" + assert restored._sender is None + assert restored._receiver is None + + def test_scheme_reinitialization_after_unpickle(self): + """Test that a scheme can be re-initialized after unpickling. + + This is the expected workflow: pickle a scheme, unpickle it in a worker, + then call init_on_worker() to establish new runtime resources. 
+ """ + # Initialize and pickle a scheme + parent_pipe, child_pipe = mp.Pipe() + + scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) + + pickled = pickle.dumps(scheme) + + # Clean up original + parent_pipe.close() + + # Unpickle and re-initialize + restored = pickle.loads(pickled) + + # Should be able to initialize again with new pipes + new_parent, new_child = mp.Pipe() + + # Re-initialize on sender + restored.init_on_sender(model_id="policy", pipes=[new_parent]) + sender = restored.get_sender() + + assert sender is not None + assert restored._initialized_on_sender + + # Clean up + new_parent.close() + new_child.close() + child_pipe.close() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst", "-v"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index e9513fba033..3686368ae71 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -471,7 +471,10 @@ def _weight_update_impl( processed_weights = self._extract_weights_if_needed( weights, target_model_id ) - self._weight_senders[target_model_id].update_weights(processed_weights) + # Use new send() API with worker_ids support + self._weight_senders[target_model_id].send( + weights=processed_weights, worker_ids=worker_ids + ) elif self._weight_updater is not None: # unreachable raise RuntimeError @@ -533,7 +536,12 @@ def next(self): return None @abc.abstractmethod - def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: raise NotImplementedError @abc.abstractmethod @@ -2041,23 +2049,35 @@ def reset(self, index=None, **kwargs) -> None: ) self._shuttle["collector"] = collector_metadata - def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: """Shuts down all workers and/or closes the local environment. Args: timeout (float, optional): The timeout for closing pipes between workers. No effect for this class. close_env (bool, optional): Whether to close the environment. Defaults to `True`. + raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. """ - if not self.closed: - self.closed = True - del self._shuttle - if self._use_buffers: - del self._final_rollout - if close_env and not self.env.is_closed: - self.env.close() - del self.env - return + try: + if not self.closed: + self.closed = True + del self._shuttle + if self._use_buffers: + del self._final_rollout + if close_env and not self.env.is_closed: + self.env.close(raise_if_closed=raise_on_error) + del self.env + return + except Exception as e: + if raise_on_error: + raise e + else: + pass def __del__(self): try: @@ -2125,9 +2145,10 @@ def __repr__(self) -> str: try: env_str = indent(f"env={self.env}", 4 * " ") policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") - td_out_str = indent( - f"td_out={getattr(self, '_final_rollout', None)}", 4 * " " - ) + td_out_str = repr(getattr(self, "_final_rollout", None)) + if len(td_out_str) > 50: + td_out_str = td_out_str[:50] + "..." 
+ td_out_str = indent(f"td_out={td_out_str}", 4 * " ") string = ( f"{self.__class__.__name__}(" f"\n{env_str}," @@ -2181,6 +2202,34 @@ def getattr_rb(self, attr): # send command to rb to return the attr return getattr(self.replay_buffer, attr) + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). + + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the unwrapped policy instance for weight synchronization + # The unwrapped policy has the same parameter structure as what's + # extracted in the main process, avoiding key mismatches when + # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule) + if hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + if hasattr(self, model_id): + return getattr(self, model_id) + else: + raise ValueError(f"Unknown model_id: {model_id}") + class _MultiDataCollector(DataCollectorBase): """Runs a given number of DataCollectors on separate processes. @@ -2527,7 +2576,11 @@ def __init__( self._setup_preemptive_threshold(preemptive_threshold) # Run worker processes - self._run_processes() + try: + self._run_processes() + except Exception as e: + self.shutdown(raise_on_error=False) + raise e # Set up frame tracking and other options self._exclude_private_keys = True @@ -2917,15 +2970,7 @@ def _run_processes(self) -> None: 1, torch.get_num_threads() - total_workers ) # 1 more thread for this proc - # Initialize weight senders for multiprocess collectors - if self._weight_sync_schemes: - # Create one sender per model using scheme's factory method - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - self._weight_senders[model_id] = sender + # Weight senders will be initialized after workers are ready (via init_on_sender) torch.set_num_threads(self.num_threads) queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] @@ -3037,11 +3082,7 @@ def _run_processes(self) -> None: self.procs.append(proc) self.pipes.append(pipe_parent) - # Register worker with senders - if self._weight_senders: - for _, sender in self._weight_senders.items(): - sender.register_worker(i, pipe_parent) - + # Worker registration now handled by init_on_sender() after workers are ready for i, pipe_parent in enumerate(self.pipes): pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) try: @@ -3093,30 +3134,19 @@ def _run_processes(self) -> None: # Legacy string error message raise RuntimeError(msg) - # For SharedMemWeightSyncScheme, pre-register shared weights now that workers are ready - # This avoids deadlock when workers are busy collecting and can't respond to registration messages + # Initialize all weight sync schemes now that workers are ready + # This calls init_on_sender() for each scheme which: + # 1. Creates transports for all workers + # 2. Creates and configures the sender + # 3. 
For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock if self._weight_sync_schemes: for model_id, scheme in self._weight_sync_schemes.items(): - if isinstance(scheme, SharedMemWeightSyncScheme): - sender = self._weight_senders[model_id] - # Get the shared memory weights from _policy_weights_dict - # Use prepare_weights with None to trigger cache lookup - from torchrl.weight_update.weight_sync_schemes import _get_strategy - - strategy = _get_strategy(scheme.strategy) - weights = scheme.prepare_weights( - weights=None, - model_id=model_id, - strategy=strategy, - context=self, - ) - if weights is not None: - # Register the shared weights directly with each transport - # This ensures the transports use the same shared memory buffer - # that we'll update later, rather than creating a clone - for transport in sender._iterate_transports(): - if hasattr(transport, "register_weights"): - transport.register_weights(model_id, weights) + # Check if scheme has new API or legacy API + if hasattr(scheme, "init_on_sender"): + scheme.init_on_sender(model_id=model_id, context=self) + # Get the initialized sender + self._weight_senders[model_id] = scheme.get_sender() + # else: keep using legacy _weight_senders initialization from before self.queue_out = queue_out self.closed = False @@ -3231,18 +3261,30 @@ def __del__(self): # __del__ will not affect the program. pass - def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: """Shuts down all processes. This operation is irreversible. Args: timeout (float, optional): The timeout for closing pipes between workers. close_env (bool, optional): Whether to close the environment. Defaults to `True`. + raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. """ if not close_env: raise RuntimeError( f"Cannot shutdown {type(self).__name__} collector without environment being closed." ) - self._shutdown_main(timeout) + try: + self._shutdown_main(timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass def _shutdown_main(self, timeout: float | None = None) -> None: if timeout is None: @@ -3478,6 +3520,52 @@ def getattr_rb(self, attr): """Get an attribute from the replay buffer.""" return getattr(self.replay_buffer, attr) + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). + + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the fallback policy instance + if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: + return self._fallback_policy + elif hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + if hasattr(self, model_id): + return getattr(self, model_id) + else: + raise ValueError(f"Unknown model_id: {model_id}") + + def get_cached_weights(self, model_id: str): + """Get cached shared memory weights if available (for weight sync schemes). 
+
+        Args:
+            model_id: Model identifier
+
+        Returns:
+            Cached TensorDict weights or None if not available
+        """
+        if model_id == "policy" and hasattr(self, "_policy_weights_dict"):
+            # Get the policy device (first device if list)
+            policy_device = self.policy_device
+            if isinstance(policy_device, (list, tuple)):
+                policy_device = policy_device[0] if len(policy_device) > 0 else None
+
+            # Return cached weights for this device
+            return self._policy_weights_dict.get(policy_device)
+        return None
+
 
 @accept_remote_rref_udf_invocation
 class MultiSyncDataCollector(_MultiDataCollector):
@@ -3597,7 +3685,12 @@ def next(self):
         return super().next()
 
     # for RPC
-    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
+    def shutdown(
+        self,
+        timeout: float | None = None,
+        close_env: bool = True,
+        raise_on_error: bool = True,
+    ) -> None:
         if not close_env:
             raise RuntimeError(
                 f"Cannot shutdown {type(self).__name__} collector without environment being closed."
@@ -3606,7 +3699,8 @@ def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
         del self.out_buffer
         if hasattr(self, "buffers"):
             del self.buffers
-        return super().shutdown(timeout=timeout)
+        # Forward raise_on_error so error suppression is handled once, in the base class
+        return super().shutdown(timeout=timeout, raise_on_error=raise_on_error)
 
     # for RPC
     def set_seed(self, seed: int, static_seed: bool = False) -> int:
@@ -3998,14 +4097,19 @@ def next(self):
         return super().next()
 
     # for RPC
-    def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None:
+    def shutdown(
+        self,
+        timeout: float | None = None,
+        close_env: bool = True,
+        raise_on_error: bool = True,
+    ) -> None:
         if hasattr(self, "out_tensordicts"):
             del self.out_tensordicts
         if not close_env:
             raise RuntimeError(
                 f"Cannot shutdown {type(self).__name__} collector without environment being closed."
) - return super().shutdown(timeout=timeout) + return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) # for RPC def set_seed(self, seed: int, static_seed: bool = False) -> int: @@ -4360,8 +4464,15 @@ def next(self): return super().next() # for RPC - def shutdown(self, timeout: float | None = None, close_env: bool = True) -> None: - return super().shutdown(timeout=timeout, close_env=close_env) + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + return super().shutdown( + timeout=timeout, close_env=close_env, raise_on_error=raise_on_error + ) # for RPC def set_seed(self, seed: int, static_seed: bool = False) -> int: @@ -4449,13 +4560,20 @@ def _main_async_collector( # Set up weight receivers for worker process if weight_sync_schemes: inner_collector._weight_receivers = {} + inner_collector.pipe = pipe_child # Add pipe attribute for context for model_id, scheme in weight_sync_schemes.items(): - receiver = scheme.create_receiver() - receiver.set_context(inner_collector) - receiver.register_worker_transport(pipe_child) + # Check if scheme has new API or legacy API + if hasattr(scheme, "init_on_worker"): + scheme.init_on_worker(model_id=model_id, context=inner_collector) + receiver = scheme.get_receiver() + else: + # Legacy API + receiver = scheme.create_receiver() + receiver.set_context(inner_collector) + receiver.register_worker_transport(pipe_child) - model = _resolve_model(inner_collector, model_id) - receiver.register_model(model) + model = _resolve_model(inner_collector, model_id) + receiver.register_model(model) inner_collector._weight_receivers[model_id] = receiver else: @@ -4591,12 +4709,21 @@ def _main_async_collector( # Only apply if the model is an nn.Module (has learnable parameters) try: model = receiver._resolve_model_ref() - if isinstance(model, nn.Module): - receiver.apply_weights(shared_buffer) - except (ValueError, AttributeError): - # Model not registered or not an nn.Module (e.g., RandomPolicy) - # Skip weight application - this is expected for policies without parameters - pass + except (ValueError, AttributeError) as e: + # Model not registered or reference is invalid + if verbose: + torchrl_logger.warning( + f"worker {idx} could not resolve model '{model_id}': {e}" + ) + continue + + if isinstance(model, nn.Module): + receiver.apply_weights(shared_buffer) + else: + if verbose: + torchrl_logger.info( + f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" + ) if verbose: torchrl_logger.info( @@ -4644,6 +4771,13 @@ def _main_async_collector( inner_collector.init_random_frames = float("inf") else: inner_collector.init_random_frames = -1 + + # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the + # main message loop above (msg == "update_weights" case). The receiver.receive() + # pattern is only used for schemes with separate communication channels like + # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). + # Calling receiver.receive() here would interfere with the pipe-based message protocol. + next_data = next(dc_iter) if pipe_child.poll(_MIN_TIMEOUT): # in this case, main send a message to the worker while it was busy collecting trajectories. 
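The note above separates the two receive paths: pipe-based schemes are serviced by the worker's main message loop, while schemes with a side channel poll via ``receiver.receive()``. A minimal sketch of that polling pattern, assuming a hypothetical ``worker_receivers`` dict inside the worker loop (the actual loop lives in ``_main_async_collector``):

.. code-block:: python

    # Poll side-channel schemes, e.g. DistributedWeightSyncScheme, which
    # signals through a TCPStore rather than the worker pipe.
    for model_id, receiver in worker_receivers.items():  # hypothetical dict
        # receive() returns False when nothing arrived; when weights are
        # present it applies them to the registered model and sends an ack.
        if receiver.receive(timeout=0.001):
            torchrl_logger.info(f"applied new weights for '{model_id}'")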
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 67ca90363fe..b8b28345872 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -550,26 +550,18 @@ def check_list_length_consistency(*lists): self._weight_sync_schemes = weight_sync_schemes self._weight_senders = {} - # Set up weight senders now that remote collectors exist + # Set up weight senders using the new simplified API for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Register each remote collector as a separate worker - # This follows the same pattern as multiprocess collectors - for worker_idx, remote_collector in enumerate(self.remote_collectors): - # Create a transport for this specific collector - # Pass the collector as context so the transport knows which one to talk to - sender.register_worker(worker_idx, remote_collector) - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if model_id == "policy": - sender._source_model = self.policy + # Initialize the scheme on the sender (main process) side + # Pass remote collectors as the "workers" for Ray schemes + scheme.init_on_sender( + model_id=model_id, + remote_collectors=self.remote_collectors, + source_model=self.policy if model_id == "policy" else None, + ) + # Get the configured sender from the scheme + sender = scheme.get_sender() self._weight_senders[model_id] = sender self.weight_updater = None # Don't use legacy system diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 2f3bf24ae40..34b35da9446 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc + import weakref from collections.abc import Iterator from typing import Any, Literal, Protocol @@ -74,9 +75,8 @@ def send_weights(self, model_id: str, weights: Any) -> None: Sends weights and waits for acknowledgment to ensure delivery. """ - self.pipe.send(((model_id, weights), "update_weights")) - # Wait for acknowledgment - self.check_ack("updated") + self.send_weights_async(model_id, weights) + self.wait_ack() def send_weights_async(self, model_id: str, weights: Any) -> None: """Send weights through the pipe without waiting for acknowledgment. @@ -90,12 +90,30 @@ def wait_ack(self) -> None: self.check_ack("updated") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the pipe (used in worker process).""" + """Receive weights from the pipe (used in worker process). + + This method only handles weight update messages. Other messages + (like "close", "continue", etc.) are ignored and should be handled + by the main worker loop. + + Returns: + Tuple of (model_id, weights) if weights were received, None if no data available + or if a non-weight message was received. + """ if self.pipe.poll(timeout): data_in, msg = self.pipe.recv() if msg == "update_weights": model_id, weights = data_in return model_id, weights + else: + # Not a weight update message - put it back and return None + # This allows the main worker loop to handle other messages + # Note: We can't actually "put it back", so we'll just return None + # and the message is lost. 
This is why receive() should only be called + # when we're expecting weight updates, not in the main message loop. + return None + # No data available - return None instead of raising TimeoutError + # This allows non-blocking checks in the worker loop return None def send_ack(self, message: str = "updated") -> None: @@ -130,7 +148,7 @@ class SharedMemTransport: policy_weights: Dictionary mapping model_id to shared TensorDict weights. Can be empty if using lazy registration. auto_register: Whether to automatically register models on first weight send. - Default is True. Set to False to require explicit registration via + Default is True. Set to `False` to require explicit registration via register_weights(). """ @@ -142,9 +160,8 @@ def __init__( self._policy_weights = policy_weights if policy_weights is not None else {} self._auto_register = auto_register self._pipes = [] # List of pipes to send initial buffer references - self._registered_with_workers = ( - set() - ) # Track which model_ids have been sent to workers + # Track which model_ids have been sent to workers + self._registered_with_workers = set() def register_pipe(self, pipe: Any) -> None: """Register a pipe for sending buffer references on first weight send. @@ -162,23 +179,32 @@ def register_weights(self, model_id: str, weights: TensorDictBase) -> None: when auto_register=True (the default), but required when auto_register=False. If pipes are registered and this model hasn't been sent to workers yet, - this will trigger sending the buffer reference to all workers. + this will trigger sending the buffer reference to all workers. If pipes + aren't registered yet, weights are stored and will be sent when pipes + become available (during init_on_sender). """ if not isinstance(weights, TensorDictBase): raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}") is_new_registration = model_id not in self._policy_weights - self._policy_weights[model_id] = weights + if is_new_registration: + self._policy_weights[model_id] = weights + else: + raise RuntimeError("Re-registering weights is not supported.") # If this is a new registration and we have pipes, send buffer to workers - if ( - is_new_registration - and self._pipes - and model_id not in self._registered_with_workers - ): - self._send_buffer_to_workers(model_id, weights) - - def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None: + # If pipes aren't available yet, defer sending until init_on_sender is called + if self._pipes: + if model_id not in self._registered_with_workers: + self._send_buffer_to_workers(model_id, weights) + else: + raise RuntimeError( + f"Model '{model_id}' has already been registered with workers." + ) + + def _send_buffer_to_workers( + self, model_id: str, buffer: TensorDictBase, timeout: float = 10.0 + ) -> None: """Send shared memory buffer reference to all workers via pipes. This is called once per model_id when lazy registration occurs. 
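For reference, the worker-side half of this registration handshake mirrors the ``worker_shared_mem`` helper in ``test/test_weightsync.py``; a sketch (the message layout comes from that test, names are illustrative):

.. code-block:: python

    # Worker: block until the main process sends the shared buffer reference.
    (model_id, shared_weights), msg = pipe.recv()
    assert msg == "register_shared_weights"

    # Attach the shared storage to the local module; later in-place updates
    # on the sender side become visible here without further messages.
    shared_weights.to_module(policy)

    # Acknowledge so _send_buffer_to_workers() stops waiting on this pipe.
    pipe.send((None, "registered"))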
@@ -195,6 +221,8 @@ def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None:
 
         # Wait for acknowledgments from all workers
         for pipe in self._pipes:
+            if not pipe.poll(timeout):
+                raise TimeoutError("Timeout waiting for acknowledgment from worker")
             _, msg = pipe.recv()
             if msg != "registered":
                 raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'")
@@ -226,28 +254,21 @@ def send_weights(self, model_id: str, weights: Any) -> None:
                 )
 
             # Auto-register on first send
-            if isinstance(weights, TensorDictBase):
-                # Create shared memory copy
-                # Unflatten keys if they're flat (e.g., 'module.0.weight' -> nested structure)
-                # This is necessary for to_module() to work properly
-                weights_to_share = weights.clone()
-                # Check if keys are flattened by looking for dots in key names
-                if any("." in key for key in weights_to_share.keys()):
-                    weights_to_share = weights_to_share.unflatten_keys(".")
-                shared_buffer = weights_to_share.share_memory_()
-            elif isinstance(weights, dict):
-                # Convert dict to TensorDict and share
-                # Unflatten if keys are flat
-                weights_td = TensorDict(weights, batch_size=[])
-                if any("." in key for key in weights_td.keys()):
-                    weights_td = weights_td.unflatten_keys(".")
-                shared_buffer = weights_td.share_memory_()
-            else:
+            if isinstance(weights, dict):
+                weights = TensorDict(weights)
+            if not isinstance(weights, TensorDictBase):
                 raise ValueError(
                     f"Cannot auto-register model '{model_id}' with weights type: {type(weights)}. "
                     f"Supported types for auto-registration: TensorDictBase, dict. "
                     f"Please manually register shared weights using register_weights()."
                 )
+            # Unflatten keys if they're flat (e.g., 'module.0.weight' -> nested structure)
+            # This is necessary for to_module() to work properly
+            weights_to_share = weights
+            # Check if keys are flattened by looking for dots in key names
+            if any("." in key for key in weights_to_share.keys()):
+                weights_to_share = weights_to_share.unflatten_keys(".")
+            shared_buffer = weights_to_share.share_memory_()
 
             self._policy_weights[model_id] = shared_buffer
 
@@ -258,30 +279,15 @@ def send_weights(self, model_id: str, weights: Any) -> None:
         shared_weights = self._policy_weights[model_id]
 
         # Update shared memory in-place (workers see this automatically)
-        if isinstance(weights, TensorDictBase):
-            # Unflatten if needed to match shared buffer structure
-            weights_to_update = weights
-            if any("." in key for key in weights.keys()):
-                weights_to_update = weights.unflatten_keys(".")
-            shared_weights.data.update_(
-                weights_to_update.data
-                if hasattr(weights_to_update, "data")
-                else weights_to_update
-            )
-        elif isinstance(weights, dict):
-            # For dict updates, check if we need to unflatten keys
-            if any("." in key for key in weights.keys()):
-                # Convert to TensorDict, unflatten, then update
-                weights_td = TensorDict(weights, batch_size=[])
-                weights_td = weights_td.unflatten_keys(".")
-                shared_weights.data.update_(weights_td.data)
-            else:
-                # Direct key-by-key update for non-flattened dict
-                for key, value in weights.items():
-                    if key in shared_weights.keys(True, True):
-                        shared_weights.set(key, value)
-        else:
+        if isinstance(weights, dict):
+            weights = TensorDict(weights)
+        if not isinstance(weights, TensorDictBase):
             raise ValueError(f"Unsupported weights type: {type(weights)}")
+        # Unflatten if needed to match shared buffer structure
+        weights_to_update = weights
+        if any("."
in key for key in weights.keys()): + weights_to_update = weights.unflatten_keys(".") + shared_weights.data.update_(weights_to_update.data) def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: """No-op for shared memory - weights are already visible.""" @@ -306,7 +312,11 @@ class RayTransport: same pattern as multiprocess collectors. """ - def __init__(self, remote_collector=None): + def __init__( + self, + remote_collector=None, + tensor_transport: Literal["object_store", "nixl"] = "object_store", + ): try: import ray @@ -314,6 +324,7 @@ def __init__(self, remote_collector=None): except ImportError: raise ImportError("Ray is required for RayTransport") self._remote_collector = remote_collector + self._tensor_transport = tensor_transport def send_weights(self, model_id: str, weights: Any) -> None: """Send weights to the remote collector via Ray. @@ -327,7 +338,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: # Put weights in Ray's object store for efficient distribution # Ray will automatically deduplicate if the same weights are sent to multiple actors - weights_ref = self.ray.put(weights) + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) # Send to the remote collector and wait for completion # This ensures weights are applied before we continue @@ -344,7 +355,7 @@ def send_weights_async(self, model_id: str, weights: Any) -> None: if self._remote_collector is None: return - weights_ref = self.ray.put(weights) + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) self._pending_future = self._remote_collector.update_policy_weights_.remote( policy_or_weights=weights_ref ) @@ -354,6 +365,8 @@ def wait_ack(self) -> None: if hasattr(self, "_pending_future"): self.ray.wait([self._pending_future], num_returns=1) del self._pending_future + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: """Ray workers typically don't receive weights through this transport.""" @@ -372,7 +385,12 @@ class RayActorTransport: update_weights method rather than going through collector update methods. 
""" - def __init__(self, actor_ref=None, update_method: str = "tensordict"): + def __init__( + self, + actor_ref=None, + update_method: str = "tensordict", + tensor_transport: Literal["object_store", "nixl"] = "object_store", + ): try: import ray @@ -382,6 +400,7 @@ def __init__(self, actor_ref=None, update_method: str = "tensordict"): self._actor_ref = actor_ref self._update_method = update_method + self._tensor_transport = tensor_transport def set_actor(self, actor_ref): """Set the Ray actor reference to communicate with.""" @@ -392,7 +411,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: if self._actor_ref is None: return - weights_ref = self.ray.put(weights) + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) if self._update_method == "tensordict": self.ray.get( @@ -415,7 +434,7 @@ def send_weights_async(self, model_id: str, weights: Any) -> None: if self._actor_ref is None: return - weights_ref = self.ray.put(weights) + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) if self._update_method == "tensordict": self._pending_future = self._actor_ref._update_weights_tensordict.remote( @@ -433,6 +452,8 @@ def wait_ack(self) -> None: if hasattr(self, "_pending_future"): self.ray.get(self._pending_future) del self._pending_future + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: """Ray actor workers receive weights through direct method calls.""" @@ -538,6 +559,7 @@ def __init__(self, store=None, rank=None, sync=True): self._store = store self._rank = rank self._sync = sync + self._weights_buffer = None # TensorDict buffer for receiving weights def send_weights(self, model_id: str, weights: Any) -> None: """Send weights to the distributed worker. @@ -592,12 +614,70 @@ def wait_ack(self) -> None: self._store.delete_key(f"NODE_{self._rank}_out") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Distributed workers receive weights through torch.distributed primitives.""" + """Receive weights via torch.distributed, using TCPStore for signaling. + + This implements the RPC-like pattern: + 1. Check TCPStore for signal (non-blocking) + 2. If signal present, receive weights via torch.distributed + 3. Clean up signal and send acknowledgment + + Args: + timeout: Timeout for receiving (currently not used for TCPStore check) + + Returns: + Tuple of (model_id, weights) if weights were received, None otherwise. 
+        """
+        if self._store is None or self._rank is None:
+            return None
+
+        try:
+            # Non-blocking check of TCPStore "mailbox" for signal
+            msg = self._store.get(f"NODE_{self._rank}_in")
+
+            if msg == b"update_weights":
+                # Initialize weights buffer on first use
+                if self._weights_buffer is None:
+                    self._weights_buffer = TensorDict()
+
+                # Receive weights via torch.distributed
+                # recv() and irecv() update the TensorDict in place
+                if self._sync:
+                    self._weights_buffer.recv(src=0)
+                else:
+                    # irecv() blocks until weights are received
+                    self._weights_buffer.irecv(src=0)
+
+                # Clean up the signal
+                self._store.delete_key(f"NODE_{self._rank}_in")
+
+                # Note: Acknowledgment is sent separately via send_ack() if transport supports it
+                # This matches the pattern in WeightReceiver.receive()
+
+                # Return model_id and received weights
+                # For distributed transport, we use "policy" as default model_id
+                return ("policy", self._weights_buffer)
+            else:
+                raise ValueError(f"Expected 'update_weights' but got {msg}")
+        except KeyError:
+            # No message in store - no weights available
+            return None
-        return None
 
+    def send_ack(self, message: str = "updated") -> None:
+        """Send acknowledgment back to sender via TCPStore.
+
+        Args:
+            message: Acknowledgment message to send (default: "updated")
+        """
+        if self._store is None or self._rank is None:
+            return
+
+        self._store.set(f"NODE_{self._rank}_out", message.encode())
+
     def check_connection(self) -> bool:
         """Check if torch.distributed is initialized."""
-        import torch
+        import torch.distributed
 
         return torch.distributed.is_initialized()
 
@@ -687,33 +767,29 @@ def apply_weights(self, destination: Any, weights: Any) -> None:
 
         # Auto-detect format from weights type
         if isinstance(weights, dict):
-            weights = TensorDict(weights).unflatten_keys(".")
+            weights = TensorDict(weights)
+            if any("." in key for key in weights.keys()):
+                weights = weights.unflatten_keys(".")
+        if isinstance(destination, nn.Module):
+            destination = TensorDict.from_module(destination)
+        elif isinstance(destination, dict):
+            destination = TensorDict(destination)
+            if any(isinstance(key, str) and "." in key for key in destination.keys()):
+                destination = destination.unflatten_keys(".")
 
         if isinstance(weights, TensorDictBase):
             # Apply TensorDict format
-            if isinstance(destination, nn.Module):
-                destination = TensorDict.from_module(destination)
-
-            if isinstance(destination, dict):
-                destination_td = TensorDict(destination)
-                if (dest_keys := sorted(destination_td.keys(True, True))) != sorted(
-                    weights.keys(True, True)
-                ):
-                    weights = weights.unflatten_keys(".")
-                    weights_keys = sorted(weights.keys(True, True))
-                    if dest_keys != weights_keys:
-                        raise ValueError(
-                            f"The keys of the weights and destination do not match: {dest_keys} != {weights_keys}"
-                        )
-                destination = destination_td
             if isinstance(destination, TensorDictBase):
-                destination.data.update_(weights.data)
+                try:
+                    destination.data.update_(weights.data)
+                except Exception as e:
+                    raise KeyError(
+                        f"Error updating destination. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}"
+                    ) from e
             else:
                 raise ValueError(
                     f"Unsupported destination type for TensorDict: {type(destination)}"
                 )
-
         else:
             raise ValueError(
                 f"Unsupported weights type: {type(weights)}. Expected dict or TensorDictBase."
             )
@@ -742,10 +818,10 @@ def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy:
 
 
 class WeightSender:
-    """Sends weights for ONE model to ALL workers.
+    """Sends weights for ONE model to workers.
- This class handles sending weights to all workers via their transports. - Weight extraction is the responsibility of the caller. + A single sender can broadcast to all workers or send to specific workers. + Created and managed by WeightSyncScheme. Users should not instantiate directly. """ _transport: TransportBackend | None @@ -754,13 +830,16 @@ class WeightSender: def __init__(self, scheme: WeightSyncScheme): self._scheme = scheme self._transports: dict[int, TransportBackend] = {} # worker_idx -> transport - self._transport: TransportBackend = None + self._transport: TransportBackend | None = None self._model_id = "policy" # Default model ID self._strategy = _get_strategy(scheme.strategy) self._context_ref = None # weakref to collector for model resolution + self._pending_async = False # Track if async send is pending + + def _set_context(self, context: Any, model_id: str | None = None) -> None: + """Set the context object (collector) for model resolution (internal). - def set_context(self, context: Any, model_id: str | None = None) -> None: - """Set the context object (collector) for model resolution. + This is now handled by init_on_sender(). Only kept for internal use. Args: context: The collector instance. @@ -770,8 +849,10 @@ def set_context(self, context: Any, model_id: str | None = None) -> None: if model_id is not None: self._model_id = model_id - def register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """Register a worker's communication pipe. + def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """Register a worker's communication pipe (internal). + + This is now handled by init_on_sender(). Only kept for internal use. Args: worker_idx: The worker index. @@ -782,30 +863,62 @@ def register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: pipe_or_context ) - def _iterate_transports(self) -> Iterator[TransportBackend]: - if not self._transports: - yield self._transport + def _iterate_transports( + self, worker_ids: int | list[int] | None = None + ) -> Iterator[TransportBackend]: + """Iterate over transports for specified workers.""" + if worker_ids is None: + # All workers + if not self._transports: + yield self._transport + else: + yield from self._transports.values() else: - yield from self._transports.values() + # Specific workers + if isinstance(worker_ids, int): + worker_ids = [worker_ids] + for worker_id in worker_ids: + if worker_id in self._transports: + yield self._transports[worker_id] + else: + raise ValueError(f"Worker {worker_id} not registered") + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. - def update_weights(self, weights: Any) -> None: - """Send weights to ALL workers for this model. + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights Args: - weights: Weights to send (can be None, nn.Module, TensorDict, etc.). - Will be prepared by the scheme's prepare_weights method. - - Note: - This method sends weights to all workers in parallel (non-blocking), - then waits for all acknowledgments. This is much faster than sending - sequentially when there are many workers. + weights: Weights to send. 
Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. """ - model_id = getattr(self, "_model_id", "policy") + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) - # Get context for model resolution if available + model_id = getattr(self, "_model_id", "policy") context = self._context_ref() if self._context_ref is not None else None - # Let the scheme prepare the weights (extract, convert, cache lookup, etc.) + # Let the scheme prepare the weights prepared_weights = self._scheme.prepare_weights( weights=weights, model_id=model_id, @@ -813,7 +926,7 @@ def update_weights(self, weights: Any) -> None: context=context, ) - transports = list(self._iterate_transports()) + transports = list(self._iterate_transports(worker_ids)) # Send to all workers first (non-blocking if transport supports it) for transport in transports: @@ -828,6 +941,98 @@ def update_weights(self, weights: Any) -> None: if hasattr(transport, "wait_ack"): transport.wait_ack() + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). + + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. + + Args: + weights: Same as send() + worker_ids: Same as send() + + Raises: + RuntimeError: If a previous send_async() is still pending + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." + ) + + model_id = getattr(self, "_model_id", "policy") + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(model_id, prepared_weights) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) + + self._pending_async = True + + def wait_async(self) -> None: + """Wait for a pending async send to complete. + + Blocks until all workers have acknowledged the previous send_async(). + This must be called after send_async() before any subsequent sends. + + Raises: + RuntimeError: If no async send is pending + """ + if not self._pending_async: + raise RuntimeError("No async send is pending. Call send_async() first.") + + # Wait for all acknowledgments + for transport in self._pending_transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + self._pending_async = False + self._pending_transports = None + + # Legacy method - kept for backward compatibility + def update_weights(self, weights: Any) -> None: + """Send weights to ALL workers for this model (legacy). 
+ + Args: + weights: Weights to send (can be None, nn.Module, TensorDict, etc.). + + Note: + This is the legacy method. Use send() instead. + """ + self.send(weights=weights) + + def __getstate__(self): + """Pickle support: discard context weakref.""" + state = self.__dict__.copy() + state["_context_ref"] = None + state["_pending_async"] = False + state["_pending_transports"] = None + return state + + def __setstate__(self, state): + """Pickle support: restore state without context.""" + self.__dict__.update(state) + # ============================================================================ # Receiver (Worker Process Side) @@ -837,8 +1042,7 @@ def update_weights(self, weights: Any) -> None: class WeightReceiver: """Receives weights for ONE model in ONE worker. - This class handles receiving weights via transport and applying them to - a registered model in the worker process. + Created and managed by WeightSyncScheme. Users should not instantiate directly. """ def __init__(self, scheme: WeightSyncScheme): @@ -848,35 +1052,85 @@ def __init__(self, scheme: WeightSyncScheme): self._model_ref = None self._strategy = _get_strategy(scheme.strategy) - def set_context(self, context: Any) -> None: - """Set the context object (inner_collector) for resolving references. + def _set_context(self, context: Any) -> None: + """Set the context object (inner_collector) for resolving references (internal). + + This is now handled by init_on_worker(). Only kept for internal use. Args: context: The inner collector instance in the worker process. """ self._context_ref = weakref.ref(context) - def register_model(self, model_ref: Any) -> None: - """Register the model to apply weights to. + def _register_model(self, model_ref: Any) -> None: + """Register the model to apply weights to (internal). + + This is now handled by init_on_worker(). Only kept for internal use. Args: model_ref: Either a direct object reference or a string path like 'policy' or 'env.value_net'. """ self._model_ref = model_ref - def register_worker_transport(self, pipe: Any) -> None: - """Register this worker's communication pipe. + def _register_worker_transport(self, pipe: Any) -> None: + """Register this worker's communication pipe (internal). + + This is now handled by init_on_worker(). Only kept for internal use. Args: pipe: The pipe connection for this worker. """ self._transport = self._scheme.create_transport(pipe) + def receive(self, timeout: float = 0.001) -> bool: + """Check for and apply new weights (non-blocking). + + This method is called in the worker's main loop to check if + new weights have been sent. If weights are available, they + are applied to the registered model immediately. + + Args: + timeout: Maximum time to wait for weights (seconds). + Use 0 for immediate return. + + Returns: + True if weights were received and applied + False if no weights were available + + Note: For SharedMemWeightSyncScheme, this always returns False + since workers automatically see updates via shared memory. 
+ """ + if self._transport is None: + return False + + # Try to receive weights + result = self._transport.receive_weights(timeout=timeout) + if result is None: + return False + + model_id, weights = result + + # Apply weights to the model + if self._model_ref is None: + raise ValueError("No model registered") + + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights) + + # Send acknowledgment if transport supports it + if hasattr(self._transport, "send_ack"): + self._transport.send_ack("updated") + + return True + def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model. + """Apply received weights to registered model (legacy). Args: weights: The weights to apply. + + Note: + This is the legacy method. Use receive() in the worker loop instead. """ if self._model_ref is None: raise ValueError("No model registered") @@ -899,6 +1153,16 @@ def _resolve_model_ref(self) -> Any: return _resolve_model(context, self._model_ref) return self._model_ref + def __getstate__(self): + """Pickle support: discard context weakref.""" + state = self.__dict__.copy() + state["_context_ref"] = None + return state + + def __setstate__(self, state): + """Pickle support: restore state without context.""" + self.__dict__.update(state) + class RayModuleTransformSender(WeightSender): """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. @@ -918,8 +1182,10 @@ def __init__(self, scheme: RayModuleTransformScheme): self._context_ref = None self._model_id_str = None - def set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution. + def _set_context(self, context: Any, model_id: str) -> None: + """Set context for lazy actor resolution (internal). + + This is now handled by init_on_sender(). Only kept for internal use. Args: context: The collector instance. @@ -928,8 +1194,8 @@ def set_context(self, context: Any, model_id: str) -> None: self._context_ref = weakref.ref(context) self._model_id_str = model_id - def register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op. + def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """For Ray actors, worker registration is a no-op (internal). Ray actors are shared across all workers, so we don't need per-worker transports. The actor reference is resolved lazily on first use. @@ -977,8 +1243,10 @@ class RayModuleTransformReceiver(WeightReceiver): def __init__(self, scheme: RayModuleTransformScheme): super().__init__(scheme) - def register_worker_transport(self, actor_or_context: Any) -> None: - """Register the Ray actor's transport. + def _register_worker_transport(self, actor_or_context: Any) -> None: + """Register the Ray actor's transport (internal). + + This is now handled by init_on_worker(). Only kept for internal use. Args: actor_or_context: Either a Ray actor reference or a context object. @@ -1009,16 +1277,108 @@ def apply_weights(self, weights: Any) -> None: class WeightSyncScheme(metaclass=abc.ABCMeta): """Configuration for how to synchronize ONE model across workers. - A scheme is a pure configuration object that specifies: - - The transmission strategy (state_dict vs tensordict) - - How to create the transport for communication - - Each scheme is responsible for one model type but is shared across all workers. + A scheme manages synchronization of ONE model across workers. + The collector maintains a dict of {model_id: scheme} pairs. 
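The transport-driven ``receive()`` and the legacy ``apply_weights()`` can be contrasted in a single process. A sketch, assuming the ``state_dict`` strategy loads a plain tensor dict via ``load_state_dict``; the pipe pair exists only because ``init_on_worker`` requires one:

.. code-block:: python

    import multiprocessing as mp

    import torch
    from torch import nn
    from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme

    policy = nn.Linear(4, 2)
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    parent_pipe, child_pipe = mp.Pipe()
    scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy)
    receiver = scheme.get_receiver()

    # apply_weights() bypasses the transport entirely: useful in tests or
    # when the weights arrive through some other channel.
    receiver.apply_weights({"weight": torch.ones(2, 4), "bias": torch.zeros(2)})
    assert policy.weight.sum().item() == 8.0

    # With nothing on the pipe, receive() times out and reports False.
    assert not receiver.receive(timeout=0.001)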
""" def __init__(self, strategy: Literal["state_dict", "tensordict"] = "state_dict"): self.strategy = strategy + self._sender = None + self._receiver = None + self._initialized_on_sender = False + self._initialized_on_worker = False + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + This method is called once in the collector's _run_processes() method, + after workers have been started and are ready to receive messages. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (e.g., collector) providing: + - .pipes: list[mp.Connection] + - .get_model(model_id: str) -> nn.Module + - .get_cached_weights(model_id: str) -> TensorDict | None + - .num_workers: int + **kwargs: Alternative to context (pipes, num_workers, model, cached_weights, etc.) + """ + raise NotImplementedError + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + This method is called once in each worker's initialization. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (e.g., inner collector) providing: + - .pipe: mp.Connection + - .get_model(model_id: str) -> nn.Module + **kwargs: Alternative to context (pipe, model, etc.) + """ + raise NotImplementedError + + def get_sender(self) -> WeightSender: + """Get the sender instance. + + Returns: + Sender instance for sending weights to workers + + Raises: + RuntimeError: If init_on_sender() hasn't been called yet + """ + if not self._initialized_on_sender or self._sender is None: + raise RuntimeError( + f"Must call init_on_sender() before get_sender() on {type(self).__name__}" + ) + return self._sender + + def get_receiver(self) -> WeightReceiver: + """Get the receiver instance. + + Returns: + Receiver instance for receiving weights in this worker + + Raises: + RuntimeError: If init_on_worker() hasn't been called yet + """ + if not self._initialized_on_worker or self._receiver is None: + raise RuntimeError( + f"Must call init_on_worker() before get_receiver() on {type(self).__name__}" + ) + return self._receiver + def __getstate__(self): + """Prepare the scheme for pickling by excluding non-serializable runtime state. + + Sender and receiver objects contain pipes, weak references, and other + non-serializable resources that should not be pickled. These will be + re-initialized when needed after unpickling. + """ + state = self.__dict__.copy() + # Remove non-serializable runtime state + state["_sender"] = None + state["_receiver"] = None + state["_initialized_on_sender"] = False + state["_initialized_on_worker"] = False + return state + + def __setstate__(self, state): + """Restore the scheme from pickling.""" + self.__dict__.update(state) + + # Legacy methods - kept for backward compatibility @abc.abstractmethod def create_transport(self, pipe_or_context: Any) -> TransportBackend: """Create transport for communication. @@ -1032,7 +1392,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: ... def create_sender(self) -> WeightSender: - """Create a sender for this scheme. + """Create a sender for this scheme (legacy). Returns: WeightSender instance configured for this scheme. @@ -1040,7 +1400,7 @@ def create_sender(self) -> WeightSender: return WeightSender(self) def create_receiver(self) -> WeightReceiver: - """Create a receiver for this scheme. 
+ """Create a receiver for this scheme (legacy). Returns: WeightReceiver instance configured for this scheme. @@ -1103,8 +1463,87 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme): This scheme creates transports that communicate via multiprocessing pipes. """ + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing pipes and num_workers + **kwargs: Alternative to context (pipes, num_workers, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + pipes = getattr(context, "pipes", None) + num_workers = getattr(context, "num_workers", None) + else: + pipes = kwargs.get("pipes") + num_workers = kwargs.get("num_workers") + + if pipes is None: + raise ValueError("pipes must be provided via context or kwargs") + if num_workers is None: + num_workers = len(pipes) if pipes else 0 + + # Create sender and register all workers + sender = WeightSender(self) + sender._model_id = model_id + if context is not None: + sender._context_ref = weakref.ref(context) + + for worker_idx, pipe in enumerate(pipes): + sender._register_worker(worker_idx, pipe) + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing pipe and model + **kwargs: Alternative to context (pipe, model, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + pipe = getattr(context, "pipe", None) + if hasattr(context, "get_model"): + model = context.get_model(model_id) + else: + model = None + else: + pipe = kwargs.get("pipe") + model = kwargs.get("model") + + if pipe is None: + raise ValueError("pipe must be provided via context or kwargs") + + # Create receiver and register model + receiver = WeightReceiver(self) + if context is not None: + receiver._context_ref = weakref.ref(context) + receiver._register_worker_transport(pipe) + if model is not None: + receiver._register_model(model) + else: + # Register by model_id for later resolution + receiver._register_model(model_id) + + self._receiver = receiver + self._initialized_on_worker = True + def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe.""" + """Create an MPTransport using the provided pipe (legacy).""" return MPTransport(pipe) @@ -1163,11 +1602,117 @@ def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> Non model_id: Identifier for the model. weights: Shared memory TensorDict containing the model's weights. """ - self.policy_weights[model_id] = weights + # Don't set self.policy_weights[model_id] here - register_weights does that + # (self.policy_weights and transport._policy_weights are the same dict) self._shared_transport.register_weights(model_id, weights) + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + For SharedMemWeightSyncScheme, this handles: + 1. Getting cached shared memory weights from context + 2. Pre-registering the weights with the transport + 3. 
Distributing buffer references to all workers (avoiding later deadlock)
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object providing pipes, cached_weights
+            **kwargs: Alternative to context (pipes, cached_weights, etc.)
+        """
+        # Extract parameters from context or kwargs
+        if context is not None:
+            pipes = getattr(context, "pipes", None)
+            num_workers = getattr(context, "num_workers", None)
+            # Try to get cached shared memory weights
+            if hasattr(context, "get_cached_weights"):
+                cached_weights = context.get_cached_weights(model_id)
+            else:
+                cached_weights = None
+        else:
+            pipes = kwargs.get("pipes")
+            num_workers = kwargs.get("num_workers")
+            cached_weights = kwargs.get("cached_weights")
+
+        if pipes is None:
+            raise ValueError("pipes must be provided via context or kwargs")
+        if num_workers is None:
+            num_workers = len(pipes) if pipes else 0
+
+        # Register pipes with shared transport for lazy buffer distribution
+        for pipe in pipes:
+            self._shared_transport.register_pipe(pipe)
+
+        # If we have cached shared memory weights, pre-register them
+        if cached_weights is not None:
+            # Check if already registered to avoid re-registration error
+            if model_id not in self.policy_weights:
+                self.register_shared_weights(model_id, cached_weights)
+
+        # Send buffer references for any weights that were pre-registered
+        # before pipes were available (e.g., via an explicit register_shared_weights call)
+        if model_id in self.policy_weights:
+            if model_id not in self._shared_transport._registered_with_workers:
+                self._shared_transport._send_buffer_to_workers(
+                    model_id, self.policy_weights[model_id]
+                )
+
+        # Create sender with the shared transport
+        sender = WeightSender(self)
+        sender._model_id = model_id
+        sender._transport = self._shared_transport  # Use shared transport
+        if context is not None:
+            sender._context_ref = weakref.ref(context)
+
+        self._sender = sender
+        self._initialized_on_sender = True
+
+    def init_on_worker(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on worker process (receiver side).
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object providing the model
+            **kwargs: Alternative to context (model, etc.; a pipe argument
+                is accepted but ignored)
+        """
+        # Extract the model from context or kwargs. The pipe is deliberately
+        # not read here: the transport is shared and workers see updates
+        # automatically, so the receiver never touches a pipe.
+        if context is not None:
+            if hasattr(context, "get_model"):
+                model = context.get_model(model_id)
+            else:
+                model = None
+        else:
+            model = kwargs.get("model")
+
+        # Create receiver with the shared transport
+        receiver = WeightReceiver(self)
+        if context is not None:
+            receiver._context_ref = weakref.ref(context)
+        receiver._transport = self._shared_transport  # Use shared transport
+        if model is not None:
+            receiver._register_model(model)
+        else:
+            # Register by model_id for later resolution
+            receiver._register_model(model_id)
+
+        self._receiver = receiver
+        self._initialized_on_worker = True
+
     def create_transport(self, pipe_or_context: Any) -> TransportBackend:
-        """Create shared memory transport and register pipe for lazy buffer distribution.
+        """Create shared memory transport and register pipe for lazy buffer distribution (legacy).
 
         For lazy registration to work, we register each worker's pipe with the
         transport. On first weight send, the transport will send buffer references via these pipes.
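The shared-memory scheme's behavior rests on buffer aliasing rather than message passing. The single-process sketch below shows just that mechanism (no scheme, pipes, or workers); under the assumption that a shared-memory send amounts to an in-place update of the registered buffer, the ``fill_`` stands in for what ``sender.send()`` does once buffers are distributed:

.. code-block:: python

    import torch
    from tensordict import TensorDict
    from torch import nn

    policy = nn.Linear(4, 2)
    # Move the parameters' storage to shared memory and bind it to the module.
    shared = TensorDict.from_module(policy).share_memory_()
    shared.to_module(policy)

    # An in-place write to the shared buffer is immediately visible through
    # the module, with no message exchanged.
    with torch.no_grad():
        shared["weight"].fill_(1.0)
    assert policy.weight.sum().item() == 8.0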
@@ -1223,8 +1768,48 @@ class NoWeightSyncScheme(WeightSyncScheme):
     This scheme disables weight synchronization entirely.
     """
 
+    def init_on_sender(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on the main process (sender side).
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object (not used)
+            **kwargs: Optional parameters (not used)
+        """
+        # Create a no-op sender
+        sender = WeightSender(self)
+        sender._model_id = model_id
+
+        self._sender = sender
+        self._initialized_on_sender = True
+
+    def init_on_worker(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on worker process (receiver side).
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object (not used)
+            **kwargs: Optional parameters (not used)
+        """
+        # Create a no-op receiver
+        receiver = WeightReceiver(self)
+        receiver._model_ref = model_id
+
+        self._receiver = receiver
+        self._initialized_on_worker = True
+
     def create_transport(self, pipe_or_context: Any) -> TransportBackend:
-        """Returns None as no transport is needed."""
+        """Return a dummy no-op transport that drops all sends (legacy)."""
         # Return a dummy transport that does nothing
         class NoOpTransport:
             def send_weights(self, model_id: str, weights: Any) -> None:
@@ -1260,6 +1845,88 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend:
         """
         return RayTransport(remote_collector=pipe_or_context)
 
+    def init_on_sender(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on the main process (sender side).
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object providing remote_collectors
+            **kwargs: Alternative to context (remote_collectors, source_model, etc.)
+        """
+        # Extract parameters from context or kwargs
+        if context is not None:
+            remote_collectors = getattr(context, "remote_collectors", None)
+            num_workers = getattr(context, "num_workers", None) or getattr(
+                context, "num_collectors", None
+            )
+        else:
+            remote_collectors = kwargs.get("remote_collectors")
+            num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors")
+
+        if remote_collectors is None:
+            raise ValueError("remote_collectors must be provided via context or kwargs")
+        if num_workers is None:
+            num_workers = len(remote_collectors) if remote_collectors else 0
+
+        # Create sender and register all workers (Ray actors)
+        sender = WeightSender(self)
+        sender._model_id = model_id
+
+        # Register each Ray actor - _register_worker will create the transport
+        for worker_idx, remote_collector in enumerate(remote_collectors):
+            sender._register_worker(worker_idx, remote_collector)
+
+        # _set_context itself stores a weak reference to avoid circular refs,
+        # so the context is passed directly
+        if context is not None:
+            sender._set_context(context, model_id)
+
+        # Store source model reference if provided for automatic weight extraction
+        source_model = kwargs.get("source_model")
+        if source_model is not None:
+            sender._source_model = source_model
+
+        self._sender = sender
+        self._initialized_on_sender = True
+
+    def init_on_worker(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on worker process (receiver side).
+
+        For Ray workers, weight updates are handled via remote method calls,
+        so this is typically a no-op. The receiver is created but doesn't
+        need special initialization.
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object (typically the remote collector)
+            **kwargs: Optional parameters (pipe, model, etc.)
+        """
+        # Create receiver
+        receiver = WeightReceiver(self)
+
+        # Register model if provided
+        model = kwargs.get("model") or (
+            getattr(context, "policy", None) if context else None
+        )
+        if model is not None:
+            receiver._register_model(model)
+
+        # _set_context itself stores a weak reference to the context
+        if context is not None:
+            receiver._set_context(context)
+
+        self._receiver = receiver
+        self._initialized_on_worker = True
+
 
 class RayModuleTransformScheme(WeightSyncScheme):
     """Weight synchronization for RayModuleTransform actors.
@@ -1311,6 +1978,98 @@ def create_receiver(self) -> RayModuleTransformReceiver:
         """Create a specialized receiver for Ray actor communication."""
         return RayModuleTransformReceiver(self)
 
+    def init_on_sender(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on the main process (sender side).
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object providing actor references
+            **kwargs: Alternative to context (actors, actor_refs, source_model, etc.)
+        """
+        # Extract actor references from context or kwargs
+        if context is not None:
+            # Could be actor_refs, actors, or remote_collectors
+            actor_refs = (
+                getattr(context, "actor_refs", None)
+                or getattr(context, "actors", None)
+                or getattr(context, "remote_collectors", None)
+            )
+        else:
+            actor_refs = (
+                kwargs.get("actor_refs")
+                or kwargs.get("actors")
+                or kwargs.get("remote_collectors")
+            )
+
+        if actor_refs is None:
+            raise ValueError(
+                "actor_refs (or actors) must be provided via context or kwargs"
+            )
+
+        # Create specialized sender
+        sender = self.create_sender()
+        sender._model_id = model_id
+
+        # Register all actors - _register_worker will create the transport
+        for worker_idx, actor_ref in enumerate(actor_refs):
+            sender._register_worker(worker_idx, actor_ref)
+
+        # _set_context itself stores a weak reference to avoid circular refs,
+        # so the context is passed directly
+        if context is not None:
+            sender._set_context(context, model_id)
+
+        # Store source model if provided
+        source_model = kwargs.get("source_model")
+        if source_model is not None:
+            sender._source_model = source_model
+
+        self._sender = sender
+        self._initialized_on_sender = True
+
+    def init_on_worker(
+        self,
+        model_id: str,
+        context: Any = None,
+        **kwargs,
+    ) -> None:
+        """Initialize on worker process (receiver side).
+
+        Args:
+            model_id: Identifier for the model being synchronized
+            context: Optional context object (typically the actor itself)
+            **kwargs: Optional parameters (actor_ref, model, etc.)
+        """
+        # Create specialized receiver
+        receiver = self.create_receiver()
+
+        # Extract actor reference if needed
+        actor_ref = kwargs.get("actor_ref") or context
+        if actor_ref is not None:
+            # _register_worker_transport accepts the actor or context
+            # directly and builds the transport internally
+            receiver._register_worker_transport(actor_ref)
+
+        # Register model if provided
+        model = kwargs.get("model") or (
+            (getattr(context, "_actor_module", None) or getattr(context, "module", None))
+            if context
+            else None
+        )
+        if model is not None:
+            receiver._register_model(model)
+
+        # _set_context itself stores a weak reference to the context
+        if context is not None:
+            receiver._set_context(context)
+
+        self._receiver = receiver
+        self._initialized_on_worker = True
+
 
 class RPCWeightSyncScheme(WeightSyncScheme):
     """Weight synchronization for torch.distributed.rpc.