From a7e55caae9a212d56b442566b38565a8a2e4e683 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 14 Oct 2025 09:40:47 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- test/llm/test_updaters.py | 1554 ++++++++++++++++- .../modules/llm/backends/vllm/vllm_async.py | 293 ++-- .../modules/llm/backends/vllm/vllm_utils.py | 12 +- torchrl/weight_update/llm/__init__.py | 32 + .../weight_update/llm/vllm_double_buffer.py | 362 ++++ torchrl/weight_update/llm/vllm_nccl.py | 699 ++++++++ 6 files changed, 2763 insertions(+), 189 deletions(-) create mode 100644 torchrl/weight_update/llm/__init__.py create mode 100644 torchrl/weight_update/llm/vllm_double_buffer.py create mode 100644 torchrl/weight_update/llm/vllm_nccl.py diff --git a/test/llm/test_updaters.py b/test/llm/test_updaters.py index 54aa3f6e2a9..ea79f1ad61e 100644 --- a/test/llm/test_updaters.py +++ b/test/llm/test_updaters.py @@ -3,32 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -""" -Tests for vLLM weight updaters, including the new vLLMUpdaterV2. - -This module implements and tests vLLMUpdaterV2, which is an improved version -of the weight updater that automatically infers configuration from vLLM objects -instead of requiring manual specification of tensor parallel size and other parameters. - -Key improvements of vLLMUpdaterV2: -- Automatically detects tensor parallel size from vLLM engines; -- Supports multiple vLLM engine types: Ray workers, AsyncVLLM services, and local LLM instances; -- Simplifies API by removing need for manual configuration; -- Provides get_tp_size() method for introspection. - -The tests are organized by engine type to optimize GPU memory usage: -- Each test class manages its own vLLM instance and cleans up when done -- Abstract base class provides common test patterns -- Low KV cache settings are used to minimize GPU memory utilization -""" +import argparse import gc import importlib.util +import time from abc import ABC, abstractmethod import pytest import torch -from torchrl._utils import logger as torchrl_logger +from torchrl._utils import _DTYPE_TO_STR_DTYPE, _STR_DTYPE_TO_DTYPE, logger # Check for dependencies _has_vllm = importlib.util.find_spec("vllm") is not None @@ -127,9 +111,9 @@ def source_policy(self, model_name): try: del wrapper del model - torchrl_logger.info("Source policy cleaned up") + logger.info("Source policy cleaned up") except Exception as e: - torchrl_logger.warning(f"Error during source policy cleanup: {e}") + logger.warning(f"Error during source policy cleanup: {e}") finally: # Force garbage collection and CUDA memory cleanup gc.collect() @@ -143,7 +127,7 @@ def target_vllm_engine(self, model_name): def test_config_extraction(self, target_vllm_engine): """Test that configuration is correctly extracted from engine.""" - torchrl_logger.info( + logger.info( f"=== Testing config extraction for {type(target_vllm_engine).__name__} ===" ) @@ -163,13 +147,13 @@ def test_config_extraction(self, target_vllm_engine): model_metadata = target_vllm_engine.get_model_metadata() assert isinstance(model_metadata, dict) - torchrl_logger.info( + logger.info( f"✓ Config extraction test passed for {type(target_vllm_engine).__name__}" ) def test_updater_v2_creation(self, target_vllm_engine): """Test vLLMUpdaterV2 creation and configuration detection.""" - torchrl_logger.info( + logger.info( f"=== Testing vLLMUpdaterV2 creation for {type(target_vllm_engine).__name__} ===" ) @@ -182,14 +166,14 @@ def 
test_updater_v2_creation(self, target_vllm_engine): assert isinstance(updater.master_port, int) assert isinstance(updater.model_metadata, dict) - torchrl_logger.info( + logger.info( f"✓ vLLMUpdaterV2 creation test passed for {type(target_vllm_engine).__name__}" ) @pytest.mark.slow def test_weight_update_interface(self, source_policy, target_vllm_engine): """Test weight update using the vLLMUpdaterV2 interface.""" - torchrl_logger.info( + logger.info( f"=== Testing weight update for {type(target_vllm_engine).__name__} ===" ) @@ -199,7 +183,7 @@ def test_weight_update_interface(self, source_policy, target_vllm_engine): # Get model metadata from source policy model_metadata = vLLMUpdaterV2.get_model_metadata(source_policy) assert len(model_metadata) > 0 - torchrl_logger.info(f"Found {len(model_metadata)} parameters in model metadata") + logger.info(f"Found {len(model_metadata)} parameters in model metadata") # Initialize updater updater.init(model_metadata) @@ -207,13 +191,13 @@ def test_weight_update_interface(self, source_policy, target_vllm_engine): # Test weight update updater.push_weights_from_transformers(source_policy) - torchrl_logger.info( + logger.info( f"✓ Weight update test passed for {type(target_vllm_engine).__name__}" ) def test_error_handling(self): """Test error handling for invalid inputs.""" - torchrl_logger.info("=== Testing error handling ===") + logger.info("=== Testing error handling ===") # Test with non-RLvLLMEngine object class FakeEngine: @@ -222,7 +206,7 @@ class FakeEngine: with pytest.raises(TypeError, match="must implement RLvLLMEngine interface"): vLLMUpdaterV2(FakeEngine()) - torchrl_logger.info("✓ Error handling tests passed") + logger.info("✓ Error handling tests passed") @pytest.mark.skipif(not _has_ray, reason="missing ray dependencies") @@ -256,17 +240,15 @@ def target_vllm_engine(self, model_name): enable_prefix_caching=False, # Disable to save memory ) - torchrl_logger.info( - f"Created AsyncVLLM service with {service.num_replicas} replicas" - ) + logger.info(f"Created AsyncVLLM service with {service.num_replicas} replicas") yield service # Cleanup try: service.shutdown() - torchrl_logger.info("AsyncVLLM service shut down successfully") + logger.info("AsyncVLLM service shut down successfully") except Exception as e: - torchrl_logger.warning(f"Error during AsyncVLLM cleanup: {e}") + logger.warning(f"Error during AsyncVLLM cleanup: {e}") finally: # Force garbage collection and CUDA memory cleanup gc.collect() @@ -275,7 +257,7 @@ def target_vllm_engine(self, model_name): def test_async_vllm_specific_features(self, target_vllm_engine): """Test AsyncVLLM-specific features.""" - torchrl_logger.info("=== Testing AsyncVLLM-specific features ===") + logger.info("=== Testing AsyncVLLM-specific features ===") # Test that it's actually an AsyncVLLM instance assert isinstance(target_vllm_engine, AsyncVLLM) @@ -287,7 +269,7 @@ def test_async_vllm_specific_features(self, target_vllm_engine): # Test that actors are ready assert target_vllm_engine._launched is True - torchrl_logger.info("✓ AsyncVLLM-specific tests passed") + logger.info("✓ AsyncVLLM-specific tests passed") @pytest.mark.skipif(not _has_ray, reason="missing ray dependencies") @@ -318,16 +300,16 @@ def target_vllm_engine(self, model_name): max_num_seqs=1, # Minimal batch size ) - torchrl_logger.info("Created Ray worker") + logger.info("Created Ray worker") yield worker # Cleanup try: if hasattr(worker, "ray_actor") and ray is not None: ray.kill(worker.ray_actor) - torchrl_logger.info("Ray worker killed 
successfully") + logger.info("Ray worker killed successfully") except Exception as e: - torchrl_logger.warning(f"Error during Ray worker cleanup: {e}") + logger.warning(f"Error during Ray worker cleanup: {e}") finally: # Force garbage collection and CUDA memory cleanup gc.collect() @@ -336,7 +318,7 @@ def target_vllm_engine(self, model_name): def test_ray_worker_specific_features(self, target_vllm_engine): """Test Ray worker-specific features.""" - torchrl_logger.info("=== Testing Ray worker-specific features ===") + logger.info("=== Testing Ray worker-specific features ===") # Test that it's actually a RayLLMWorker instance assert isinstance(target_vllm_engine, RayLLMWorker) @@ -345,7 +327,7 @@ def test_ray_worker_specific_features(self, target_vllm_engine): assert hasattr(target_vllm_engine, "ray_actor") assert target_vllm_engine._tensor_parallel_size == 1 - torchrl_logger.info("✓ Ray worker-specific tests passed") + logger.info("✓ Ray worker-specific tests passed") class TestVLLMUpdaterV2WithLocalLLM(BaseVLLMUpdaterTest): @@ -368,7 +350,7 @@ def target_vllm_engine(self, model_name): max_num_seqs=1, # Minimal batch size ) - torchrl_logger.info("Created local LLM") + logger.info("Created local LLM") yield llm # Cleanup @@ -376,9 +358,9 @@ def target_vllm_engine(self, model_name): # For local LLM, we might need to explicitly delete the instance if hasattr(llm, "llm_instance"): del llm.llm_instance - torchrl_logger.info("Local LLM instance deleted") + logger.info("Local LLM instance deleted") except Exception as e: - torchrl_logger.warning(f"Error during local LLM cleanup: {e}") + logger.warning(f"Error during local LLM cleanup: {e}") finally: # Force garbage collection and CUDA memory cleanup gc.collect() @@ -387,7 +369,7 @@ def target_vllm_engine(self, model_name): def test_local_llm_specific_features(self, target_vllm_engine): """Test local LLM-specific features.""" - torchrl_logger.info("=== Testing local LLM-specific features ===") + logger.info("=== Testing local LLM-specific features ===") # Test that it's actually a LocalLLMWrapper instance assert isinstance(target_vllm_engine, LocalLLMWrapper) @@ -396,9 +378,1483 @@ def test_local_llm_specific_features(self, target_vllm_engine): assert hasattr(target_vllm_engine, "llm_instance") assert target_vllm_engine._tensor_parallel_size == 1 - torchrl_logger.info("✓ Local LLM-specific tests passed") + logger.info("✓ Local LLM-specific tests passed") + + +@pytest.mark.skipif(not _has_ray, reason="missing ray dependencies") +@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies") +@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 3, + reason="CUDA not available or not enough GPUs (need 3: 2 for vLLM workers, 1 for trainer)", +) +class TestWeightSyncVLLMNCCL: + """Test vLLM weight synchronization using the Sender/Receiver API. + + This test suite verifies weight synchronization between a transformer trainer + and vLLM inference workers using collective communication (NCCL). + """ + + @staticmethod + def serialize_metadata(metadata: dict[str, tuple[torch.dtype, torch.Size]]) -> dict: + """Convert metadata with torch dtypes and sizes to JSON-serializable format. 
+ + Args: + metadata: Dict mapping parameter names to (dtype, shape) tuples + + Returns: + JSON-serializable dict with string dtype representations + """ + serialized = {} + for name, (dtype, shape) in metadata.items(): + serialized[name] = { + "dtype": _DTYPE_TO_STR_DTYPE[dtype], + "shape": list(shape), + } + return serialized + + @staticmethod + def deserialize_metadata( + serialized: dict, + ) -> dict[str, tuple[torch.dtype, torch.Size]]: + """Convert JSON-serialized metadata back to torch dtypes and sizes. + + Args: + serialized: JSON dict with string dtype representations + + Returns: + Dict mapping parameter names to (dtype, shape) tuples + """ + metadata = {} + for name, info in serialized.items(): + dtype = _STR_DTYPE_TO_DTYPE[info["dtype"]] + shape = torch.Size(info["shape"]) + metadata[name] = (dtype, shape) + return metadata + + @staticmethod + def _make_worker_vllm(model_name: str = "Qwen/Qwen2.5-0.5B"): + """Create a vLLM wrapper with AsyncVLLM backend.""" + from torchrl.modules.llm.backends import AsyncVLLM + from torchrl.modules.llm.policies import vLLMWrapper + + async_engine = AsyncVLLM.from_pretrained( + model_name, + num_replicas=2, # Number of engine replicas + ) + wrapper = vLLMWrapper(async_engine, input_mode="history") + return wrapper + + @staticmethod + def _make_worker_transformer(model_name: str = "Qwen/Qwen2.5-0.5B"): + """Create a transformer model for training.""" + from transformers import AutoModelForCausalLM + + transformer = AutoModelForCausalLM.from_pretrained( + model_name, + dtype=torch.float16, + ) + transformer = transformer.cuda() + return transformer + + class WorkerVLLM: + """Ray actor for vLLM inference worker (receiver).""" + + def __init__( + self, + scheme_config: dict, + model_name: str = "Qwen/Qwen2.5-0.5B", + trainer_actor_name: str = "Trainer", + ): + pass + + # Store config for deferred initialization + self.scheme_config = scheme_config + self.model_name = model_name + self.trainer_actor_name = trainer_actor_name + self.wrapper = None + self.engine = None + self.receiver = None + self.scheme = None + self.trainer = None + self.model_metadata = None + + def setup(self): + """Set up vLLM engine (deferred from __init__ to avoid blocking).""" + # Create vLLM wrapper + self.wrapper = TestWeightSyncVLLMNCCL._make_worker_vllm(self.model_name) + self.engine = self.wrapper.model + + # Create scheme from config + from torchrl.weight_update.llm.vllm_nccl import VLLMWeightSyncScheme + + self.scheme = VLLMWeightSyncScheme(**self.scheme_config) + + # Create receiver (engine handles rank assignment automatically) + self.receiver = self.scheme.create_receiver(self.engine) + return "setup_complete" + + def init_metadata(self): + """Initialize the receiver by fetching metadata from trainer.""" + import ray + + if self.receiver is None: + raise RuntimeError("Must call setup() before init()") + + # Get trainer actor by name + logger.info(f"Getting trainer actor by name {self.trainer_actor_name}") + self.trainer = ray.get_actor(self.trainer_actor_name) + + # Fetch model metadata from trainer + logger.info( + "Fetching model metadata from trainer (requires max_concurrency>1)" + ) + self.model_metadata = ray.get(self.trainer.get_model_metadata.remote()) + + def init(self): + if self.model_metadata is None: + raise RuntimeError("Must call init_metadata() before init()") + + # Initialize receiver with metadata + logger.info("Initializing receiver...") + self.receiver.init_all_workers_group(self.model_metadata) + self.initialized = True + logger.info("Receiver 
initialized") + return "initialized" + + def get_engine(self): + """Get the vLLM engine reference for RPC coordination.""" + if self.engine is None: + raise RuntimeError("Must call setup() first") + return self.engine + + def get_sample_output(self): + """Get a sample output to verify model works.""" + # Simple inference test + return "vllm_ready" + + # @classmethod + # def run_forever( + # cls, scheme_config, parent_pipe, child_pipe, trainer_metadata_pipe, model_name="Qwen/Qwen2.5-0.5B" + # ): + # """A single threaded infinite loop capturing commands via a Pipe. + + # Args: + # scheme_config: Configuration for VLLMWeightSyncScheme + # parent_pipe: Parent end of the pipe (to be closed in child) + # child_pipe: Child end of the pipe for receiving commands + # trainer_metadata_pipe: Pipe to receive metadata from trainer + # model_name: Model name to load + # """ + # import os + + # # Set CUDA_VISIBLE_DEVICES for vLLM workers (GPUs 1,2) + # # vLLM will use these for its 2 replicas + # os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" + + # parent_pipe.close() # Close parent end in child process + + # # Update scheme_config to use device 0 (which is actually GPU 1 due to CUDA_VISIBLE_DEVICES) + # scheme_config = scheme_config.copy() + # scheme_config["device"] = 0 + + # worker = cls(scheme_config, model_name) + # child_pipe.send({"status": "success", "result": "instantiated"}) + + # while True: + # try: + # command = child_pipe.recv() + # if command == "shutdown": + # child_pipe.send({"status": "shutdown"}) + # break + # elif command == "setup": + # result = worker.setup() + # child_pipe.send({"status": "success", "result": result}) + # elif command == "init_metadata": + # # Receive metadata from trainer via separate pipe + # worker.model_metadata = trainer_metadata_pipe.recv() + # child_pipe.send( + # {"status": "success", "result": "metadata_received"} + # ) + # elif command == "init": + # result = worker.init() + # child_pipe.send({"status": "success", "result": result}) + # elif command == "update_weights_receiver": + # worker.receiver.update_weights() + # child_pipe.send({"status": "success", "result": "receiving_started"}) + # elif ( + # isinstance(command, dict) + # and command.get("cmd") == "get_sample_output" + # ): + # result = worker.get_sample_output() + # child_pipe.send({"status": "success", "result": result}) + # else: + # child_pipe.send( + # {"status": "error", "error": f"Unknown command: {command}"} + # ) + # except Exception as e: + # torchrl_logger.error(f"WorkerVLLM error: {e}", exc_info=True) + # child_pipe.send({"status": "error", "error": str(e)}) + # break + + # @classmethod + # def run_forever_http(cls, scheme_config, port, model_name="Qwen/Qwen2.5-0.5B"): + # """Run an HTTP server that accepts commands via REST endpoints. 
+ + # Args: + # scheme_config: Configuration for VLLMWeightSyncScheme + # port: Port to listen on + # model_name: Model name to load + # """ + # import os + + # from flask import Flask, jsonify, request + + # # Set CUDA_VISIBLE_DEVICES for vLLM workers (GPUs 1,2) + # os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" + + # # Update scheme_config to use device 0 (which is actually GPU 1 due to CUDA_VISIBLE_DEVICES) + # scheme_config = scheme_config.copy() + # scheme_config["device"] = 0 + + # # Initialize Ray in this subprocess - required for AsyncVLLM + # import ray + # if not ray.is_initialized(): + # ray.init() + # torchrl_logger.info("Ray initialized in WorkerVLLM subprocess") + + # app = Flask(f"WorkerVLLM_{port}") + + # # Defer worker creation until first request to allow Flask to start quickly + # worker = None + + # def ensure_worker(): + # nonlocal worker + # if worker is None: + # worker = cls(scheme_config, model_name) + # return worker + + # @app.route("/health", methods=["GET"]) + # def health(): + # """Health check endpoint that doesn't require worker initialization.""" + # return jsonify({"status": "ready"}) + + # @app.route("/setup", methods=["POST"]) + # def setup(): + # try: + # w = ensure_worker() + # result = w.setup() + # return jsonify({"status": "success", "result": result}) + # except Exception as e: + # torchrl_logger.error(f"Setup error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/init_metadata", methods=["POST"]) + # def init_metadata(): + # try: + # # Receive metadata in request body + # w = ensure_worker() + # received_data = request.json + # torchrl_logger.info(f"Received metadata with {len(received_data)} parameters") + # torchrl_logger.info(f"First 3 params: {list(received_data.keys())[:3]}") + # torchrl_logger.info(f"Last 3 params: {list(received_data.keys())[-3:]}") + # w.model_metadata = TestWeightSyncVLLMNCCL.deserialize_metadata(received_data) + # torchrl_logger.info(f"Deserialized metadata successfully") + # return jsonify({"status": "success", "result": "metadata_received"}) + # except Exception as e: + # torchrl_logger.error(f"Init metadata error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/init", methods=["POST"]) + # def init(): + # try: + # w = ensure_worker() + # result = w.init() + # return jsonify({"status": "success", "result": result}) + # except Exception as e: + # torchrl_logger.error(f"Init error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/update_weights_receiver", methods=["POST"]) + # def update_weights_receiver(): + # try: + # w = ensure_worker() + # w.receiver.update_weights() + # return jsonify({"status": "success", "result": "receiving_started"}) + # except Exception as e: + # torchrl_logger.error(f"Receiver update error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/get_sample_output", methods=["GET"]) + # def get_sample_output(): + # try: + # w = ensure_worker() + # result = w.get_sample_output() + # return jsonify({"status": "success", "result": result}) + # except Exception as e: + # torchrl_logger.error(f"Get sample output error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/shutdown", methods=["POST"]) + # def shutdown(): + # try: + # # Shutdown Ray before killing the process + # import ray + # if ray.is_initialized(): + # ray.shutdown() + # 
torchrl_logger.info("Ray shut down in WorkerVLLM subprocess") + + # func = request.environ.get("werkzeug.server.shutdown") + # if func is None: + # # Running under a different WSGI server + # import os + # import signal + + # os.kill(os.getpid(), signal.SIGTERM) + # else: + # func() + # return jsonify({"status": "shutdown"}) + # except Exception as e: + # torchrl_logger.error(f"Shutdown error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # torchrl_logger.info(f"WorkerVLLM HTTP server starting on port {port}") + # app.run(host="0.0.0.0", port=port, threaded=True) + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + # No GPUs needed for the actor itself - vLLM workers manage their own placement group (2 GPUs) + # AsyncVLLM service doesn't act as NCCL rank 0 when used with external trainer + return ray.remote(num_cpus=4, num_gpus=0, max_concurrency=4)(cls) + + class WorkerTransformer: + """Ray actor for transformer trainer (sender).""" + + def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): + from torchrl.weight_update.llm.vllm_nccl import ( + get_model_metadata, + VLLMWeightSyncScheme, + ) + + # Create transformer model + self.transformer = TestWeightSyncVLLMNCCL._make_worker_transformer( + model_name + ) + + # Create scheme from config + self.scheme = VLLMWeightSyncScheme(**scheme_config) + + # Create sender + self.sender = self.scheme.create_sender() + self.sender.register_model(self.transformer) + + # Extract and store model metadata + self.model_metadata = get_model_metadata(self.transformer) + + def init(self, vllm_engine=None): + """Initialize sender with optional vLLM engine for RPC coordination. + + Args: + vllm_engine: Optional vLLM engine reference for calling collective_rpc + """ + if self.model_metadata is None: + raise RuntimeError("Must call init_metadata() before init()") + + self.sender.init_all_workers_group( + self.model_metadata, vllm_engine=vllm_engine + ) + self.initialized = True + logger.info("Trainer initialized") + return "initialized" + + def get_model_metadata(self): + """Get model metadata to share with receiver.""" + return self.model_metadata + + def update_weights(self, modify_weights: bool = False): + """Trigger a weight update broadcast. + + Args: + modify_weights: If True, modifies weights before broadcasting + for verification purposes. + + Returns: + str: "updated" status message + """ + + # Optionally modify weights for testing + if modify_weights: + with torch.no_grad(): + first_param = next(self.transformer.parameters()) + first_param.add_(0.01) + + # Broadcast weights to all vLLM workers + self.sender.update_weights() + return "updated" + + def get_first_param_sum(self): + """Get sum of first parameter for verification.""" + return next(self.transformer.parameters()).sum().item() + + # @classmethod + # def run_forever( + # cls, scheme_config, parent_pipe, child_pipe, metadata_send_pipe, model_name="Qwen/Qwen2.5-0.5B" + # ): + # """A single threaded infinite loop capturing commands via a Pipe. 
+ + # Args: + # scheme_config: Configuration for VLLMWeightSyncScheme + # parent_pipe: Parent end of the pipe (to be closed in child) + # child_pipe: Child end of the pipe for receiving commands + # metadata_send_pipe: Pipe to send metadata to receiver + # model_name: Model name to load + # """ + # import os + + # # Set CUDA_VISIBLE_DEVICES for trainer (GPU 0 only) + # os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + # parent_pipe.close() # Close parent end in child process + + # # Update scheme_config to use device 0 (which is actually GPU 0) + # scheme_config = scheme_config.copy() + # scheme_config["device"] = 0 + + # worker = cls(scheme_config, model_name) + # child_pipe.send({"status": "success", "result": "instantiated"}) + + # while True: + # try: + # command = child_pipe.recv() + # if command == "shutdown": + # child_pipe.send({"status": "shutdown"}) + # break + # elif command == "init": + # result = worker.init() + # child_pipe.send({"status": "success", "result": result}) + # elif isinstance(command, dict): + # cmd_name = command.get("cmd") + # if cmd_name == "get_model_metadata": + # # Send metadata to receiver via separate pipe + # metadata_send_pipe.send(worker.model_metadata) + # child_pipe.send( + # {"status": "success", "result": "metadata_sent"} + # ) + # elif cmd_name == "update_weights": + # modify_weights = command.get("modify_weights", False) + # result = worker.update_weights(modify_weights) + # child_pipe.send({"status": "success", "result": result}) + # elif cmd_name == "get_first_param_sum": + # result = worker.get_first_param_sum() + # child_pipe.send({"status": "success", "result": result}) + # else: + # child_pipe.send( + # { + # "status": "error", + # "error": f"Unknown command: {cmd_name}", + # } + # ) + # else: + # child_pipe.send( + # {"status": "error", "error": f"Unknown command: {command}"} + # ) + # except Exception as e: + # torchrl_logger.error(f"WorkerTransformer error: {e}", exc_info=True) + # child_pipe.send({"status": "error", "error": str(e)}) + # break + + # @classmethod + # def run_forever_http(cls, scheme_config, port, model_name="Qwen/Qwen2.5-0.5B"): + # """Run an HTTP server that accepts commands via REST endpoints. 
+ + # Args: + # scheme_config: Configuration for VLLMWeightSyncScheme + # port: Port to listen on + # model_name: Model name to load + # """ + # import os + + # from flask import Flask, jsonify, request + + # # Set CUDA_VISIBLE_DEVICES for trainer (GPU 0 only) + # os.environ["CUDA_VISIBLE_DEVICES"] = "0" + + # # Update scheme_config to use device 0 (which is actually GPU 0) + # scheme_config = scheme_config.copy() + # scheme_config["device"] = 0 + + # app = Flask(f"WorkerTransformer_{port}") + + # # Defer worker creation until first request to allow Flask to start quickly + # worker = None + + # def ensure_worker(): + # nonlocal worker + # if worker is None: + # worker = cls(scheme_config, model_name) + # return worker + + # @app.route("/health", methods=["GET"]) + # def health(): + # """Health check endpoint that doesn't require worker initialization.""" + # return jsonify({"status": "ready"}) + + # @app.route("/init", methods=["POST"]) + # def init(): + # try: + # w = ensure_worker() + # result = w.init() + # return jsonify({"status": "success", "result": result}) + # except Exception as e: + # torchrl_logger.error(f"Init error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/get_model_metadata", methods=["GET"]) + # def get_model_metadata(): + # try: + # # Return metadata as JSON + # w = ensure_worker() + # serialized = TestWeightSyncVLLMNCCL.serialize_metadata(w.model_metadata) + # return jsonify({"status": "success", "result": serialized}) + # except Exception as e: + # torchrl_logger.error(f"Get metadata error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/update_weights", methods=["POST"]) + # def update_weights(): + # try: + # data = request.json or {} + # modify_weights = data.get("modify_weights", False) + # w = ensure_worker() + # result = w.update_weights(modify_weights) + # return jsonify({"status": "success", "result": result}) + # except Exception as e: + # torchrl_logger.error(f"Update weights error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/get_first_param_sum", methods=["GET"]) + # def get_first_param_sum(): + # try: + # w = ensure_worker() + # result = w.get_first_param_sum() + # return jsonify({"status": "success", "result": result}) + # except Exception as e: + # torchrl_logger.error(f"Get param sum error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # @app.route("/shutdown", methods=["POST"]) + # def shutdown(): + # try: + # func = request.environ.get("werkzeug.server.shutdown") + # if func is None: + # # Running under a different WSGI server + # import os + # import signal + + # os.kill(os.getpid(), signal.SIGTERM) + # else: + # func() + # return jsonify({"status": "shutdown"}) + # except Exception as e: + # torchrl_logger.error(f"Shutdown error: {e}", exc_info=True) + # return jsonify({"status": "error", "error": str(e)}), 500 + + # torchrl_logger.info( + # f"WorkerTransformer HTTP server starting on port {port}" + # ) + # app.run(host="0.0.0.0", port=port, threaded=True) + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + return ray.remote(num_cpus=4, num_gpus=1, max_concurrency=4)(cls) + + def test_weight_sync_vllm_collective_ray(self, request): + """Test weight sync between transformer trainer and vLLM workers. + + Uses Ray remote calls for RPC coordination. + + This test demonstrates the simplified API using named Ray actors: + 1. 
Trainer is created as a named actor "Trainer" + 2. vLLM receiver discovers trainer by name to fetch metadata + 3. Both initialize simultaneously for collective handshake + 4. Weight updates can be triggered via RPC to the trainer + """ + import ray + + if not ray.is_initialized(): + ray.init() + + # Determine model based on --runslow flag + if request.config.getoption("--runslow"): + model_name = "Qwen/Qwen2.5-3B" + logger.info("Using large model (3B) for slow test") + else: + model_name = "Qwen/Qwen2.5-0.5B" + logger.info("Using small model (0.5B) for fast test") + + try: + # Create scheme configuration + # Use a unique port for each test run to avoid conflicts + import random + + test_port = random.randint(30000, 40000) + scheme_config = { + "master_address": "localhost", + "master_port": test_port, + "gpus_per_replica": 1, # tp_size × dp_size × pp_size (1×1×1=1) + "num_replicas": 2, # Number of engine replicas + "strategy": "state_dict", + # device defaults to 0 - Ray sets CUDA_VISIBLE_DEVICES per actor + } + logger.info(f"Using NCCL port {test_port}") + # world_size = 1 (trainer) + 2 (replicas) × 1 (gpus_per_replica) = 3 + + logger.info( + "Creating receiver actor first (vLLM workers need 2 GPUs via placement group)..." + ) + # Create receiver actor first - it will find trainer by name + receiver = TestWeightSyncVLLMNCCL.WorkerVLLM.as_remote().remote( + scheme_config, model_name, trainer_actor_name="Trainer" + ) + + # Set up vLLM engine (creates placement group with 2 GPUs for workers) + logger.info("Setting up vLLM engine...") + ray.get(receiver.setup.remote()) + logger.info("vLLM engine setup complete") + + # Now create trainer actor (needs 1 GPU for training and NCCL rank 0) + logger.info("Creating trainer actor (needs 1 GPU)...") + trainer = ( + TestWeightSyncVLLMNCCL.WorkerTransformer.as_remote() + .options(name="Trainer") + .remote(scheme_config, model_name) + ) + logger.info("Trainer actor created") + + # Sequential initialization to avoid deadlock: + # 1. Receiver gets metadata from trainer (RPC) and completes setup + logger.info("Step 1: Receiver fetching metadata from trainer...") + ray.get(receiver.init_metadata.remote()) + + # Get vLLM engine reference from receiver for RPC coordination + logger.info("Getting vLLM engine reference from receiver...") + vllm_engine = ray.get(receiver.get_engine.remote()) + + # 2. Start NCCL init on both sides (parallel dispatch) + logger.info("Step 2: Starting NCCL init on both trainer and workers...") + # Dispatch both futures in parallel + nccl_worker_fut = ( + receiver.init.remote() + ) # Starts vLLM worker background threads + nccl_trainer_fut = trainer.init.remote( + vllm_engine=vllm_engine + ) # Pass engine for RPC + + # Wait for trainer first - it blocks until all ranks (including worker threads) participate + # This ensures the collective completes before we proceed + logger.info( + "Waiting for trainer NCCL init (blocks until all ranks ready)..." + ) + ray.get(nccl_trainer_fut) + + # Receiver future should already be done (it just dispatched threads and waited for RPCs) + logger.info("Waiting for receiver NCCL init...") + ray.get(nccl_worker_fut) + + # 3. 
NCCL collective completes - all ranks synchronized + logger.info("NCCL collective initialization complete!") + + # Get initial state + initial_sum = ray.get(trainer.get_first_param_sum.remote()) + + # Trigger weight update with modification + # Trainer now handles RPC coordination internally (periodic-mono pattern) + logger.info("=== Starting weight update ===") + t0 = time.time() + ray.get(trainer.update_weights.remote(modify_weights=True)) + t1 = time.time() + update_time = t1 - t0 + logger.info(f"=== NCCL weight update completed in {update_time:.3f}s ===") + + # Verify weights changed + updated_sum = ray.get(trainer.get_first_param_sum.remote()) + assert updated_sum != initial_sum, "Weights should have changed" + + # Verify receiver still functional + assert ray.get(receiver.get_sample_output.remote()) == "vllm_ready" + + finally: + if ray.is_initialized(): + ray.shutdown() + + # def test_weight_sync_vllm_collective_mp(self, request): + # """Test weight sync between transformer trainer and vLLM workers. + + # Uses multiprocessing with pipes for RPC coordination instead of Ray. + + # This test demonstrates the same behavior as test_weight_sync_vllm_collective_ray + # but using Python's multiprocessing: + # 1. Trainer and receiver run in separate processes + # 2. Main process coordinates via pipe commands + # 3. Metadata exchange happens via separate pipes + # 4. Both initialize simultaneously for collective handshake + # 5. Weight updates can be triggered via pipe commands + # """ + # import random + + # # Determine model based on --runslow flag + # if request.config.getoption("--runslow"): + # model_name = "Qwen/Qwen2.5-3B" + # torchrl_logger.info("Using large model (3B) for slow test") + # else: + # model_name = "Qwen/Qwen2.5-0.5B" + # torchrl_logger.info("Using small model (0.5B) for fast test") + + # # Create scheme configuration + # test_port = 10234 + # scheme_config = { + # "master_address": "localhost", + # "master_port": test_port, + # "gpus_per_replica": 1, + # "num_replicas": 2, + # "strategy": "state_dict", + # # device will be set in each worker's run_forever based on CUDA_VISIBLE_DEVICES + # # Trainer: CUDA_VISIBLE_DEVICES="0" -> uses GPU 0 + # # Receiver: CUDA_VISIBLE_DEVICES="1,2" -> vLLM uses GPUs 1,2 for its 2 replicas + # } + # torchrl_logger.info(f"Using NCCL port {test_port}") + + # # Create pipes for communication + # # Pipe for trainer commands + # trainer_parent_pipe, trainer_child_pipe = mp.Pipe() + # # Pipe for receiver commands + # receiver_parent_pipe, receiver_child_pipe = mp.Pipe() + # # Pipe for metadata exchange (trainer -> receiver) + # metadata_send_pipe, metadata_recv_pipe = mp.Pipe() + + # trainer_proc = None + # receiver_proc = None + # try: + # # Start receiver process (needs 2 GPUs for vLLM workers) + # torchrl_logger.info( + # "Starting receiver process (vLLM workers need 2 GPUs)..." 
+ # ) + # receiver_proc = mp.Process( + # target=TestWeightSyncVLLMNCCL.WorkerVLLM.run_forever, + # args=( + # scheme_config, + # receiver_parent_pipe, + # receiver_child_pipe, + # metadata_recv_pipe, + # model_name, + # ), + # ) + # receiver_proc.start() + # receiver_child_pipe.close() # Close child end in parent + + # # Start trainer process (needs 1 GPU) + # torchrl_logger.info("Starting trainer process (needs 1 GPU)...") + # trainer_proc = mp.Process( + # target=TestWeightSyncVLLMNCCL.WorkerTransformer.run_forever, + # args=( + # scheme_config, + # trainer_parent_pipe, + # trainer_child_pipe, + # metadata_send_pipe, + # model_name, + # ), + # ) + # trainer_proc.start() + # trainer_child_pipe.close() # Close child end in parent + + # # Helper to send command and wait for response + # def send_command(pipe, command, timeout=180.0): + # pipe.send(command) + # if pipe.poll(timeout): + # response = pipe.recv() + # if response.get("status") == "error": + # raise RuntimeError(f"Command failed: {response.get('error')}") + # return response + # else: + # raise TimeoutError(f"Command {command} timed out") + + # # Check for successful instantiation + # assert receiver_parent_pipe.recv()["status"] == "success" + # # Check for successful instantiation + # assert trainer_parent_pipe.recv()["status"] == "success" + + # # Step 1: Setup vLLM engine + # torchrl_logger.info("Setting up vLLM engine...") + # send_command(receiver_parent_pipe, "setup") + # torchrl_logger.info("vLLM engine setup complete") + + # # Step 2: Receiver gets metadata from trainer + # torchrl_logger.info("Step 1: Receiver fetching metadata from trainer...") + # # Trainer sends metadata + # send_command(trainer_parent_pipe, {"cmd": "get_model_metadata"}) + # # Receiver receives metadata + # send_command(receiver_parent_pipe, "init_metadata") + # torchrl_logger.info("Metadata exchange complete") + + # # Step 3: Start NCCL init on both sides (parallel) + # torchrl_logger.info( + # "Step 2: Starting NCCL init on both trainer and workers..." 
+ # ) + # # Send init commands to both (non-blocking on main process side) + # trainer_parent_pipe.send("init") + # receiver_parent_pipe.send("init") + + # # Wait for both to complete + # torchrl_logger.info("Waiting for trainer NCCL init...") + # trainer_response = None + # if trainer_parent_pipe.poll(60.0): # Longer timeout for NCCL + # trainer_response = trainer_parent_pipe.recv() + # if trainer_response.get("status") == "error": + # raise RuntimeError( + # f"Trainer init failed: {trainer_response.get('error')}" + # ) + # else: + # raise TimeoutError("Trainer NCCL init timed out") + + # torchrl_logger.info("Waiting for receiver NCCL init...") + # receiver_response = None + # if receiver_parent_pipe.poll(60.0): + # receiver_response = receiver_parent_pipe.recv() + # if receiver_response.get("status") == "error": + # raise RuntimeError( + # f"Receiver init failed: {receiver_response.get('error')}" + # ) + # else: + # raise TimeoutError("Receiver NCCL init timed out") + + # torchrl_logger.info("NCCL collective initialization complete!") + + # # Get initial state + # initial_response = send_command( + # trainer_parent_pipe, {"cmd": "get_first_param_sum"} + # ) + # initial_sum = initial_response.get("result") + # torchrl_logger.info(f"Initial param sum: {initial_sum}") + + # # Trigger weight update with modification using concurrent pattern + # torchrl_logger.info("Triggering concurrent weight update...") + + # # Send both commands without waiting (they'll execute concurrently) + # t0 = time.time() + # receiver_parent_pipe.send("update_weights_receiver") + # trainer_parent_pipe.send({"cmd": "update_weights", "modify_weights": True}) + + # # Wait for both responses + # if receiver_parent_pipe.poll(180.0): + # receiver_response = receiver_parent_pipe.recv() + # if receiver_response.get("status") == "error": + # raise RuntimeError(f"Receiver update failed: {receiver_response.get('error')}") + # else: + # raise TimeoutError("Receiver update timed out") + + # if trainer_parent_pipe.poll(180.0): + # trainer_response = trainer_parent_pipe.recv() + # if trainer_response.get("status") == "error": + # raise RuntimeError(f"Trainer update failed: {trainer_response.get('error')}") + # else: + # raise TimeoutError("Trainer update timed out") + + # t1 = time.time() + # update_time = t1 - t0 + # torchrl_logger.info(f"=== NCCL weight update completed in {update_time:.3f}s ===") + + # # Verify weights changed + # updated_response = send_command( + # trainer_parent_pipe, {"cmd": "get_first_param_sum"} + # ) + # updated_sum = updated_response.get("result") + # torchrl_logger.info(f"Updated param sum: {updated_sum}") + # assert updated_sum != initial_sum, "Weights should have changed" + + # # Verify receiver still functional + # sample_response = send_command( + # receiver_parent_pipe, {"cmd": "get_sample_output"} + # ) + # assert sample_response.get("result") == "vllm_ready" + + # torchrl_logger.info("Test completed successfully!") + + # finally: + # # Shutdown processes + # torchrl_logger.info("Shutting down processes...") + # try: + # trainer_parent_pipe.send("shutdown") + # receiver_parent_pipe.send("shutdown") + # except Exception as e: + # torchrl_logger.warning(f"Error sending shutdown: {e}") + + # # Wait for processes to exit + # if trainer_proc is not None and trainer_proc.is_alive(): + # trainer_proc.join(timeout=5.0) + # if trainer_proc.is_alive(): + # torchrl_logger.warning( + # "Trainer process did not exit, terminating..." 
+ # ) + # trainer_proc.terminate() + # trainer_proc.join(timeout=2.0) + + # if receiver_proc is not None and receiver_proc.is_alive(): + # receiver_proc.join(timeout=5.0) + # if receiver_proc.is_alive(): + # torchrl_logger.warning( + # "Receiver process did not exit, terminating..." + # ) + # receiver_proc.terminate() + # receiver_proc.join(timeout=2.0) + + # # Close pipes + # trainer_parent_pipe.close() + # receiver_parent_pipe.close() + # metadata_send_pipe.close() + # metadata_recv_pipe.close() + + # def test_weight_sync_vllm_collective_http(self, request): + # """Test weight sync between transformer trainer and vLLM workers. + + # Uses HTTP/REST for RPC coordination instead of Ray or pipes. + + # This test demonstrates the same behavior using HTTP: + # 1. Trainer and receiver run Flask servers in separate processes + # 2. Main process coordinates via HTTP POST/GET requests + # 3. Metadata exchange happens via REST endpoints + # 4. Both initialize simultaneously for collective handshake + # 5. Weight updates can be triggered via HTTP requests + + # Benefits: + # - Easy to debug with curl/browser + # - Works across any network + # - Language-agnostic (workers could be in different languages) + # - No special dependencies beyond Flask + # """ + # import random + + # import requests + + # # Determine model based on --runslow flag + # if request.config.getoption("--runslow"): + # model_name = "Qwen/Qwen2.5-3B" + # torchrl_logger.info("Using large model (3B) for slow test") + # else: + # model_name = "Qwen/Qwen2.5-0.5B" + # torchrl_logger.info("Using small model (0.5B) for fast test") + + # # Create scheme configuration + # test_port = 10235 + # scheme_config = { + # "master_address": "localhost", + # "master_port": test_port, + # "gpus_per_replica": 1, + # "num_replicas": 2, + # "strategy": "state_dict", + # } + # torchrl_logger.info(f"Using NCCL port {test_port}") + + # # Choose random ports for HTTP servers + # receiver_http_port = random.randint(5000, 5100) + # trainer_http_port = random.randint(5100, 5200) + + # try: + # # Start receiver HTTP server (needs 2 GPUs for vLLM workers) + # torchrl_logger.info( + # f"Starting receiver HTTP server on port {receiver_http_port}..." + # ) + # receiver_proc = mp.Process( + # target=TestWeightSyncVLLMNCCL.WorkerVLLM.run_forever_http, + # args=(scheme_config, receiver_http_port, model_name), + # ) + # receiver_proc.start() + + # # Start trainer HTTP server (needs 1 GPU) + # torchrl_logger.info( + # f"Starting trainer HTTP server on port {trainer_http_port}..." 
+ # ) + # trainer_proc = mp.Process( + # target=TestWeightSyncVLLMNCCL.WorkerTransformer.run_forever_http, + # args=(scheme_config, trainer_http_port, model_name), + # ) + # trainer_proc.start() + + # # Wait for servers to be ready by polling health endpoints + # torchrl_logger.info("Waiting for HTTP servers to start...") + # receiver_url = f"http://localhost:{receiver_http_port}" + # trainer_url = f"http://localhost:{trainer_http_port}" + + # # Poll health endpoints with timeout + # start_time = time.time() + # timeout = 180.0 + # receiver_ready = False + # trainer_ready = False + + # while time.time() - start_time < timeout: + # if not receiver_ready: + # try: + # resp = requests.get(f"{receiver_url}/health", timeout=1.0) + # if resp.status_code == 200: + # receiver_ready = True + # torchrl_logger.info("Receiver HTTP server is ready") + # except Exception: + # pass # Server not ready yet + + # if not trainer_ready: + # try: + # resp = requests.get(f"{trainer_url}/health", timeout=1.0) + # if resp.status_code == 200: + # trainer_ready = True + # torchrl_logger.info("Trainer HTTP server is ready") + # except Exception: + # pass # Server not ready yet + + # if receiver_ready and trainer_ready: + # break + + # time.sleep(0.5) + + # if not (receiver_ready and trainer_ready): + # raise TimeoutError( + # f"Servers did not start within {timeout}s. " + # f"Receiver: {receiver_ready}, Trainer: {trainer_ready}" + # ) + + # # Helper to make HTTP requests + # def http_request( + # base_url, endpoint, method="POST", data=None, timeout=180.0 + # ): + # url = f"{base_url}{endpoint}" + # try: + # if method == "GET": + # response = requests.get(url, timeout=timeout) + # else: + # response = requests.post(url, json=data, timeout=timeout) + + # response.raise_for_status() + # result = response.json() + + # if result.get("status") == "error": + # raise RuntimeError(f"Request failed: {result.get('error')}") + # return result + # except requests.exceptions.RequestException as e: + # raise RuntimeError(f"HTTP request to {url} failed: {e}") + + # # Step 1: Setup vLLM engine + # torchrl_logger.info("Setting up vLLM engine via HTTP...") + # http_request(receiver_url, "/setup") + # torchrl_logger.info("vLLM engine setup complete") + + # # Step 2: Receiver gets metadata from trainer + # torchrl_logger.info("Step 1: Fetching metadata from trainer via HTTP...") + # # Get metadata from trainer + # metadata_response = http_request( + # trainer_url, "/get_model_metadata", method="GET" + # ) + # metadata = metadata_response.get("result") + # torchrl_logger.info(f"Fetched metadata with {len(metadata)} parameters") + # torchrl_logger.info(f"First 3 params: {list(metadata.keys())[:3]}") + # torchrl_logger.info(f"Last 3 params: {list(metadata.keys())[-3:]}") + + # # Send metadata to receiver + # http_request(receiver_url, "/init_metadata", data=metadata) + # torchrl_logger.info("Metadata exchange complete") + + # # Step 3: Start NCCL init on both sides + # # Note: HTTP is synchronous, so we need to do this carefully + # # We'll use threading to make parallel requests + # torchrl_logger.info( + # "Step 2: Starting NCCL init on both trainer and workers..." 
+ # ) + + # import queue + # from threading import Thread + + # # Queues to collect results + # trainer_queue = queue.Queue() + # receiver_queue = queue.Queue() + + # def init_trainer(): + # try: + # result = http_request(trainer_url, "/init", timeout=180.0) + # trainer_queue.put(("success", result)) + # except Exception as e: + # trainer_queue.put(("error", str(e))) + + # def init_receiver(): + # try: + # result = http_request(receiver_url, "/init", timeout=180.0) + # receiver_queue.put(("success", result)) + # except Exception as e: + # receiver_queue.put(("error", str(e))) + + # # Start both initializations in parallel + # trainer_thread = Thread(target=init_trainer) + # receiver_thread = Thread(target=init_receiver) + + # trainer_thread.start() + # receiver_thread.start() + + # # Wait for both to complete + # torchrl_logger.info("Waiting for trainer NCCL init...") + # trainer_thread.join(timeout=180.0) + # status, result = trainer_queue.get() + # if status == "error": + # raise RuntimeError(f"Trainer init failed: {result}") + + # torchrl_logger.info("Waiting for receiver NCCL init...") + # receiver_thread.join(timeout=180.0) + # status, result = receiver_queue.get() + # if status == "error": + # raise RuntimeError(f"Receiver init failed: {result}") + + # torchrl_logger.info("NCCL collective initialization complete!") + + # # Get initial state + # initial_response = http_request( + # trainer_url, "/get_first_param_sum", method="GET" + # ) + # initial_sum = initial_response.get("result") + # torchrl_logger.info(f"Initial param sum: {initial_sum}") + + # # Trigger weight update with modification using concurrent pattern + # torchrl_logger.info("Triggering concurrent weight update...") + + # # Use threading to call both endpoints concurrently + # receiver_queue = queue.Queue() + # trainer_queue = queue.Queue() + + # def update_receiver(): + # try: + # result = http_request(receiver_url, "/update_weights_receiver", timeout=180.0) + # receiver_queue.put(("success", result)) + # except Exception as e: + # receiver_queue.put(("error", str(e))) + + # def update_trainer(): + # try: + # result = http_request(trainer_url, "/update_weights", data={"modify_weights": True}, timeout=180.0) + # trainer_queue.put(("success", result)) + # except Exception as e: + # trainer_queue.put(("error", str(e))) + + # # Start both updates in parallel + # t0 = time.time() + # receiver_update_thread = Thread(target=update_receiver) + # trainer_update_thread = Thread(target=update_trainer) + + # receiver_update_thread.start() + # trainer_update_thread.start() + + # # Wait for both to complete + # receiver_update_thread.join(timeout=180.0) + # trainer_update_thread.join(timeout=180.0) + + # # Check results + # status, result = receiver_queue.get() + # if status == "error": + # raise RuntimeError(f"Receiver update failed: {result}") + + # status, result = trainer_queue.get() + # if status == "error": + # raise RuntimeError(f"Trainer update failed: {result}") + + # t1 = time.time() + # update_time = t1 - t0 + # torchrl_logger.info(f"=== NCCL weight update completed in {update_time:.3f}s ===") + + # # Verify weights changed + # updated_response = http_request( + # trainer_url, "/get_first_param_sum", method="GET" + # ) + # updated_sum = updated_response.get("result") + # torchrl_logger.info(f"Updated param sum: {updated_sum}") + # assert updated_sum != initial_sum, "Weights should have changed" + + # # Verify receiver still functional + # sample_response = http_request( + # receiver_url, "/get_sample_output", method="GET" 
+ # ) + # assert sample_response.get("result") == "vllm_ready" + + # torchrl_logger.info("Test completed successfully!") + # torchrl_logger.info( + # f"You can debug with: curl http://localhost:{trainer_http_port}/get_first_param_sum" + # ) + + # finally: + # # Shutdown processes + # torchrl_logger.info("Shutting down HTTP servers...") + # try: + # requests.post( + # f"http://localhost:{trainer_http_port}/shutdown", timeout=2.0 + # ) + # except Exception as e: + # torchrl_logger.warning(f"Error shutting down trainer: {e}") + + # try: + # requests.post( + # f"http://localhost:{receiver_http_port}/shutdown", timeout=2.0 + # ) + # except Exception as e: + # torchrl_logger.warning(f"Error shutting down receiver: {e}") + + # # Wait for processes to exit + # if trainer_proc.is_alive(): + # trainer_proc.join(timeout=5.0) + # if trainer_proc.is_alive(): + # torchrl_logger.warning( + # "Trainer process did not exit, terminating..." + # ) + # trainer_proc.terminate() + # trainer_proc.join(timeout=2.0) + + # if receiver_proc.is_alive(): + # receiver_proc.join(timeout=5.0) + # if receiver_proc.is_alive(): + # torchrl_logger.warning( + # "Receiver process did not exit, terminating..." + # ) + # receiver_proc.terminate() + # receiver_proc.join(timeout=2.0) + + +@pytest.mark.skipif(not _has_ray, reason="missing ray dependencies") +@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies") +@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="CUDA not available or not enough GPUs (need 2: 1 for vLLM, 1 for trainer)", +) +class TestWeightSyncVLLMDoubleBuffer: + """Test vLLM weight synchronization using double-buffered shared storage. + + This test suite verifies weight synchronization between a transformer trainer + and vLLM inference workers using memory-mapped TensorDict storage. 
+ """ + + @staticmethod + def _make_worker_vllm(model_name: str = "Qwen/Qwen2.5-0.5B"): + """Create a vLLM wrapper with AsyncVLLM backend.""" + from torchrl.modules.llm.backends import AsyncVLLM + from torchrl.modules.llm.policies import vLLMWrapper + + async_engine = AsyncVLLM.from_pretrained( + model_name, + num_replicas=1, # Single replica for simplicity + ) + wrapper = vLLMWrapper(async_engine, input_mode="history") + return wrapper + + @staticmethod + def _make_worker_transformer(model_name: str = "Qwen/Qwen2.5-0.5B"): + """Create a transformer model for training.""" + from transformers import AutoModelForCausalLM + + transformer = AutoModelForCausalLM.from_pretrained( + model_name, + dtype=torch.float16, + ) + transformer = transformer.cuda() + return transformer + + class WorkerVLLM: + """Ray actor for vLLM inference worker (receiver).""" + + def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): + # Store config for deferred initialization + self.scheme_config = scheme_config + self.model_name = model_name + self.wrapper = None + self.engine = None + self.receiver = None + self.scheme = None + + def setup(self): + """Set up vLLM engine and receiver.""" + # Create vLLM wrapper + self.wrapper = TestWeightSyncVLLMDoubleBuffer._make_worker_vllm( + self.model_name + ) + self.engine = self.wrapper.model + + # Create scheme from config + from torchrl.weight_update.llm.vllm_double_buffer import ( + VLLMDoubleBufferSyncScheme, + ) + + self.scheme = VLLMDoubleBufferSyncScheme(**self.scheme_config) + + # Create receiver + self.receiver = self.scheme.create_receiver(self.engine) + logger.info("Receiver setup complete") + return "setup_complete" + + def poll_and_apply_weights(self): + """Poll for new weights and apply them to the engine.""" + if self.receiver is None: + raise RuntimeError("Must call setup() first") + + success = self.receiver.poll_and_apply() + return success + + def get_sample_output(self): + """Get a sample output to verify model works.""" + return "vllm_ready" + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + # vLLM worker needs 1 GPU + return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls) + + class WorkerTransformer: + """Ray actor for transformer trainer (sender).""" + + def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): + from torchrl.weight_update.llm.vllm_double_buffer import ( + VLLMDoubleBufferSyncScheme, + ) + + # Create transformer model + self.transformer = TestWeightSyncVLLMDoubleBuffer._make_worker_transformer( + model_name + ) + + # Create scheme from config + self.scheme = VLLMDoubleBufferSyncScheme(**scheme_config) + + # Create sender + self.sender = self.scheme.create_sender() + self.sender.register_model(self.transformer) + logger.info("Trainer setup complete") + + def update_weights(self, modify_weights: bool = False): + """Trigger a weight update by writing to shared storage. + + Args: + modify_weights: If True, modifies weights before writing + for verification purposes. 
+ + Returns: + str: "updated" status message + """ + # Optionally modify weights for testing + if modify_weights: + with torch.no_grad(): + first_param = next(self.transformer.parameters()) + first_param.add_(0.01) + + # Write weights to shared storage + self.sender.update_weights() + return "updated" + + def get_first_param_sum(self): + """Get sum of first parameter for verification.""" + return next(self.transformer.parameters()).sum().item() + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls) + + def test_weight_sync_vllm_double_buffer_ray(self, tmpdir, request): + """Test weight sync using double-buffered storage with Ray. + + This test demonstrates the simplified double-buffer API: + 1. Trainer writes weights to shared directory + 2. vLLM receiver polls and reads from shared directory + 3. No coordination needed - simple push/pull model + """ + import ray + + if not ray.is_initialized(): + ray.init() + + # Determine model based on --runslow flag + if request.config.getoption("--runslow"): + model_name = "Qwen/Qwen2.5-3B" + logger.info("Using large model (3B) for slow test") + else: + model_name = "Qwen/Qwen2.5-0.5B" + logger.info("Using small model (0.5B) for fast test") + + try: + # Create temporary directory for weight storage + logger.info(f"Using temporary directory for weights: {tmpdir}") + + # Create scheme configuration + scheme_config = { + "remote_addr": str(tmpdir), + "num_threads": 128, + "strategy": "state_dict", + } + + # Create trainer actor + logger.info("Creating trainer actor...") + trainer = ( + TestWeightSyncVLLMDoubleBuffer.WorkerTransformer.as_remote().remote( + scheme_config, model_name + ) + ) + logger.info("Trainer actor created") + + # Create receiver actor + logger.info("Creating receiver actor...") + receiver = TestWeightSyncVLLMDoubleBuffer.WorkerVLLM.as_remote().remote( + scheme_config, model_name + ) + + # Set up vLLM engine + logger.info("Setting up vLLM engine...") + ray.get(receiver.setup.remote()) + logger.info("vLLM engine setup complete") + + # Get initial state + initial_sum = ray.get(trainer.get_first_param_sum.remote()) + logger.info(f"Initial param sum: {initial_sum}") + + # Trigger weight update with modification and measure send timing + logger.info("=== Starting weight update timing measurement ===") + t0 = time.time() + ray.get(trainer.update_weights.remote(modify_weights=True)) + t1 = time.time() + send_time = t1 - t0 + logger.info(f"=== Weights written to storage in {send_time:.3f}s ===") + + # Verify weights changed on trainer side + updated_sum = ray.get(trainer.get_first_param_sum.remote()) + assert updated_sum != initial_sum, "Weights should have changed" + logger.info(f"Updated param sum: {updated_sum}") + + # Receiver polls and applies weights - measure receive timing + logger.info("Receiver polling for weights...") + t2 = time.time() + success = ray.get(receiver.poll_and_apply_weights.remote()) + t3 = time.time() + receive_time = t3 - t2 + total_time = t3 - t0 + assert success, "Weight application should succeed" + logger.info(f"=== Weights received and applied in {receive_time:.3f}s ===") + logger.info( + f"=== Total double-buffer update time: {total_time:.3f}s (send: {send_time:.3f}s, receive: {receive_time:.3f}s) ===" + ) + + # Verify receiver is still functional + assert ray.get(receiver.get_sample_output.remote()) == "vllm_ready" + logger.info("Test completed successfully!") + + finally: + if ray.is_initialized(): + ray.shutdown() if __name__ 
== "__main__": - # Simple smoke test - pytest.main([__file__, "-v", "-s"]) + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/modules/llm/backends/vllm/vllm_async.py b/torchrl/modules/llm/backends/vllm/vllm_async.py index 5cf2deb3c7a..73f440af533 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_async.py +++ b/torchrl/modules/llm/backends/vllm/vllm_async.py @@ -17,42 +17,23 @@ import random import uuid from collections.abc import Iterator, Sequence +from concurrent.futures import ThreadPoolExecutor, wait from typing import Any, Literal, TYPE_CHECKING +import ray + import torch + +from ray.util.placement_group import placement_group, remove_placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from torchrl._utils import logger as torchrl_logger # Import RLvLLMEngine and shared utilities from .base import RLvLLMEngine from .vllm_utils import stateless_init_process_group -try: - import ray - from ray.util.placement_group import placement_group, remove_placement_group - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -except ImportError: - ray = None - - def placement_group(*args, **kwargs): - """Placement group is not available when ray is not installed.""" - raise ImportError( - "ray is not installed. Please install it with `pip install ray`." - ) - - def remove_placement_group(*args, **kwargs): - """Remove placement group is not available when ray is not installed.""" - raise ImportError( - "ray is not installed. Please install it with `pip install ray`." - ) - - class PlacementGroupSchedulingStrategy: - """Placement group scheduling strategy is not available when ray is not installed.""" - - def __init__(self, *args, **kwargs): - raise ImportError( - "ray is not installed. Please install it with `pip install ray`." - ) +_has_vllm = True if TYPE_CHECKING: from vllm.engine.async_llm_engine import AsyncEngineArgs @@ -61,30 +42,8 @@ def __init__(self, *args, **kwargs): TIMEOUT_SECONDS = os.getenv("TORCHRL_VLLM_TIMEOUT_SECONDS", 300) -try: - import vllm - - _has_vllm = True -except ImportError: - vllm = None - _has_vllm = False - - -if not _has_vllm: - class Worker: - """Placeholder for Worker class when vLLM is not installed.""" - - def __init__(self, *args, **kwargs): - raise ImportError( - "vllm is not installed. Please install it with `pip install vllm`." - ) - -else: - from vllm.worker.worker import Worker - - -class _AsyncvLLMWorker(Worker): +class _AsyncvLLMWorker: """Async vLLM worker for Ray with weight update capabilities. This worker extends the base vLLM Worker to support async operations @@ -95,8 +54,8 @@ def __init__(self, *args, **kwargs): torchrl_logger.info(f"=> in {type(self).__name__}.__init__") torchrl_logger.info(f"visible devices {os.getenv('CUDA_VISIBLE_DEVICES')}") torchrl_logger.info(f"device count {torch.cuda.device_count()}") - super().__init__(*args, **kwargs) self.model_update_group = None + super().__init__(*args, **kwargs) def init_weight_update_group( self, @@ -105,7 +64,10 @@ def init_weight_update_group( rank_offset: int, world_size: int, ): - """Initialize weight update group for this worker. + """Initialize weight update group for this worker (non-blocking). + + This method starts NCCL initialization in a background thread and returns immediately, + allowing the RPC to complete. The NCCL collective will complete when the trainer joins. 
Args: master_address (str): The master address for distributed training. @@ -113,10 +75,12 @@ def init_weight_update_group( rank_offset (int): Rank offset for this worker in the global weight update group. world_size (int): Total number of processes in the weight update group. """ + import threading + from vllm.distributed.parallel_state import get_world_group torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group") - if self.model_update_group is not None: + if getattr(self, "model_update_group", None) is not None: torchrl_logger.info("Model update group already initialized") return @@ -128,18 +92,35 @@ def init_weight_update_group( # Calculate the global rank for weight update group rank = local_rank + rank_offset torchrl_logger.info( - f"Initializing {type(self).__name__} weight update group with " + f"Starting {type(self).__name__} weight update group init (non-blocking) with " f"{master_address=}, {master_port=}, {rank=}, {world_size=}, device={self.device}" ) - # Import synchronous version for workers too - from .vllm_utils import stateless_init_process_group + # Start NCCL init in a background thread so this RPC can return immediately + def _init_nccl_background(): + try: + from .vllm_utils import stateless_init_process_group - self.model_update_group = stateless_init_process_group( - master_address, master_port, rank, world_size, self.device - ) + torchrl_logger.info( + f"Worker rank {rank}: Starting NCCL init (will block until collective completes)..." + ) + self.model_update_group = stateless_init_process_group( + master_address, master_port, rank, world_size, self.device + ) + torchrl_logger.info(f"Worker rank {rank}: NCCL init complete!") + except Exception as e: + torchrl_logger.error(f"Worker rank {rank}: NCCL init failed: {e}") + raise + + thread = threading.Thread(target=_init_nccl_background, daemon=False) + thread.start() + + # Store thread reference for potential cleanup + self._nccl_init_thread = thread - torchrl_logger.info(f"{type(self).__name__}.init_weight_update_group success") + torchrl_logger.info( + f"{type(self).__name__}.init_weight_update_group dispatched (non-blocking)" + ) def update_weight(self, name: str, dtype_name: str, shape: tuple[int, ...]): """Update weight via broadcast from master (rank 0) - periodic-mono pattern. @@ -169,6 +150,41 @@ def check_nccl_group_ready(self): torchrl_logger.info(f"Worker NCCL group ready: {ready}") return ready + def load_weights_from_storage(self, storage_path: str, num_threads: int = 1): + """Load weights from shared storage (double-buffer approach). + + This method reads weights from a memory-mapped TensorDict directory + and loads them into the model. Used for file-based weight synchronization + as an alternative to NCCL collectives. 
+ + Args: + storage_path: Path to the directory containing memory-mapped weights + num_threads: Number of threads for reading (default: 1) + """ + from tensordict import TensorDict + + torchrl_logger.info(f"Worker loading weights from {storage_path}") + + # Read weights from shared storage + weights = TensorDict.load_memmap(storage_path) + weights = weights.flatten_keys(".") + + # Convert to list of (name, tensor) tuples + weights_list = list(weights.items()) + + torchrl_logger.info(f"Worker loading {len(weights_list)} weights into model") + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(self.model_runner.model.load_weights, weights) + for weights in weights_list + ] + wait(futures) + + torchrl_logger.info( + f"Worker successfully loaded {len(weights_list)} weights from storage" + ) + class _AsyncLLMEngine: """Extended AsyncLLMEngine with TorchRL-specific features. @@ -203,17 +219,9 @@ def __init__( from vllm import AsyncLLMEngine - worker_cls = "torchrl.modules.llm.backends.vllm.vllm_async._AsyncvLLMWorker" - if engine_args.worker_cls != "auto": - old_worker_cls = engine_args.worker_cls - torchrl_logger.warning( - f"Overriding worker_cls from {old_worker_cls} to {worker_cls}" - ) - if bundle_indices is not None: os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) - engine_args.worker_cls = worker_cls engine_args.enable_prefix_caching = enable_prefix_caching # Create the engine directly - this is the source of the blocking ray.get issue @@ -466,36 +474,6 @@ def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int: ) -def _get_bundle_indices(placement_group, index: int, length: int) -> list[int]: - """Get bundle indices for a placement group. - - Address https://github.com/ray-project/ray/issues/51117 - This function is used to get the bundle indices of a placement group - and ensure that the bundles placed on the same node are grouped together. - - Args: - placement_group: Ray placement group. - index (int): Index of the current replica. - length (int): Number of bundles per replica. - - Returns: - list[int]: Bundle indices for this replica. - """ - if ray is None: - raise ImportError( - "ray is not installed. Please install it with `pip install ray`." - ) - - pg_infos = ray.util.placement_group_table(placement_group) - - node_id_to_bundles = {} - for bundle, node_id in pg_infos["bundles_to_node_id"].items(): - node_id_to_bundles.setdefault(node_id, []).append(bundle) - - sorted_bundle_indices = sum(node_id_to_bundles.values(), []) - return sorted_bundle_indices[index * length : (index + 1) * length] - - # Create Ray remote versions if ray is not None and _has_vllm: _AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine) @@ -630,13 +608,6 @@ def _launch(self): torchrl_logger.warning("AsyncVLLMEngineService already launched") return - # Check if CUDA is available since vLLM requires GPU - if not torch.cuda.is_available(): - raise RuntimeError( - "AsyncVLLM requires CUDA but no GPU devices are available. " - "Please run on a machine with GPU support." - ) - torchrl_logger.info( f"Launching {self.num_replicas} async vLLM engine actors..." 
) @@ -651,10 +622,8 @@ def _launch(self): ) # Create individual placement group for this replica - bundles = [ - {"GPU": 1.0, "CPU": 1.0} - for _ in range(self.engine_args.tensor_parallel_size) - ] + num_gpus = _gpus_per_replica(self.engine_args) + bundles = [{"GPU": 1.0, "CPU": 1.0} for _ in range(num_gpus)] torchrl_logger.info( f"Creating placement group for replica {i + 1} with {len(bundles)} bundles" ) @@ -670,8 +639,8 @@ def _launch(self): # Calculate bundle indices for tensor parallelism bundle_indices = None - if self.engine_args.tensor_parallel_size > 1: - bundle_indices = list(range(self.engine_args.tensor_parallel_size)) + if num_gpus > 1: + bundle_indices = list(range(num_gpus)) bundle_index = 0 # Always use first bundle since each replica has its own placement group scheduling_strategy = PlacementGroupSchedulingStrategy( @@ -691,7 +660,6 @@ def _launch(self): bundle_indices=bundle_indices, enable_prefix_caching=self.engine_args.enable_prefix_caching, ) - self.actors.append(actor) torchrl_logger.info("Waiting for actors to be ready") @@ -1163,28 +1131,65 @@ def get_master_port(self) -> int: self._cached_master_port = 29500 # Default port return self._cached_master_port - def init_weight_update_group(self) -> None: - """Initialize the weight update communication group (RLvLLMEngine interface).""" + def init_weight_update_group( + self, + master_address: str, + master_port: int | str, + ) -> list[Any]: + """Forward the request to init NCCL weight update group to all actors. + + This method initializes the weight update group for all vLLM workers. + The external trainer should be rank 0, and vLLM workers will be ranks 1+. + + Args: + master_address: Master address for NCCL communication. + master_port: Master port for NCCL communication. + + Returns: + List of Ray futures for the initialization calls. + + Note: + The caller must wait on the returned futures (ray.get(refs)) to ensure + all workers have completed initialization before sending weights. 
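+
+        Example:
+            A minimal sketch of the caller side; the ``service`` handle and the
+            address/port values below are illustrative only:
+
+            >>> refs = service.init_weight_update_group("127.0.0.1", 29500)
+            >>> # the trainer joins the NCCL collective as rank 0 in parallel
+            >>> ray.get(refs)  # wait until every worker has dispatched its init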
+ """ if not self._launched: raise RuntimeError( "AsyncVLLM service must be launched before initializing weight update group" ) - master_address = self.get_master_address() - master_port = self.get_master_port() + gpus_per_replica = _gpus_per_replica(self.engine_args) + weight_sync_world_size = self.num_replicas * gpus_per_replica + 1 - # Call the internal method with the auto-detected parameters (like V1) - refs = self._init_weight_update_group_internal(master_address, master_port) + torchrl_logger.info( + f"Initializing weight update group for {self.num_replicas} replicas " + f"with {gpus_per_replica} GPUs each (world_size={weight_sync_world_size})" + ) - # CRITICAL: Initialize master NCCL group immediately (like V1) - don't wait for workers - torchrl_logger.info("Setting up master NCCL group (rank 0)...") - self._setup_nccl_master_group() + from vllm import envs - # Now wait for workers to complete (like V1 does) - if ray is not None: - ray.get(refs) + refs = [] + for i, actor in enumerate(self.actors): + rank_offset = 1 + i * gpus_per_replica + if envs and envs.VLLM_USE_V1: + actor_collective_rpc = actor.collective_rpc_v1 + else: + actor_collective_rpc = actor.collective_rpc_v0 + refs.append( + actor_collective_rpc.remote( + "init_weight_update_group", + args=( + master_address, + str(master_port), + rank_offset, + weight_sync_world_size, + ), + ) + ) + torchrl_logger.info( + f"Requested init for actor {i} with rank_offset {rank_offset}" + ) - torchrl_logger.info("AsyncVLLM weight update group initialized") + return refs def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None: """Update model weights across all replicas using NCCL broadcast. @@ -1909,16 +1914,22 @@ def make_async_vllm_engine( num_replicas: int = 1, verbose: bool = True, compile: bool = True, + tensor_parallel_size: int | None = None, + data_parallel_size: int | None = None, + pipeline_parallel_size: int | None = None, **kwargs, ) -> AsyncVLLM: """Create an async vLLM engine service. - Args: + Keyword Args: model_name (str): The model name to pass to vLLM. num_devices (int, optional): Number of devices to use, per replica. num_replicas (int): Number of engine replicas to create. verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True. compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True. + tensor_parallel_size (int, optional): Number of devices to use, per replica. Defaults to None. + data_parallel_size (int, optional): Number of data parallel groups to use. Defaults to None. + pipeline_parallel_size (int, optional): Number of pipeline parallel groups to use. Defaults to None. **kwargs: Additional arguments passed to AsyncEngineArgs. Returns: @@ -1944,17 +1955,6 @@ def make_async_vllm_engine( from vllm import AsyncEngineArgs - # Check if CUDA is available since vLLM requires GPU - if not torch.cuda.is_available(): - raise RuntimeError( - "AsyncVLLM requires CUDA but no GPU devices are available. " - "Please run on a machine with GPU support." 
- ) - - # Handle device specification - if num_devices is None: - num_devices = 1 - # Configure verbose logging if requested if verbose: import logging @@ -1969,6 +1969,21 @@ def make_async_vllm_engine( "Enabled verbose vLLM logging - throughput statistics will be displayed" ) + # Set tensor_parallel_size to num_devices if not set + if tensor_parallel_size is None: + if num_devices is None: + tensor_parallel_size = 1 + else: + tensor_parallel_size = num_devices + elif num_devices is not None and tensor_parallel_size != num_devices: + raise ValueError(f"tensor_parallel_size must be set to {num_devices}") + + if data_parallel_size is None: + data_parallel_size = 1 + + if pipeline_parallel_size is None: + pipeline_parallel_size = 1 + # Create engine args kwargs.setdefault("distributed_executor_backend", "ray") # Don't explicitly set enable_prefix_caching to avoid conflicts @@ -1984,8 +1999,10 @@ def make_async_vllm_engine( engine_args = AsyncEngineArgs( model=model_name, - tensor_parallel_size=num_devices, - worker_cls="torchrl.modules.llm.backends.vllm.vllm_async._AsyncvLLMWorker", + tensor_parallel_size=tensor_parallel_size, + data_parallel_size=data_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + worker_extension_cls="torchrl.modules.llm.backends.vllm.vllm_async._AsyncvLLMWorker", **kwargs, ) diff --git a/torchrl/modules/llm/backends/vllm/vllm_utils.py b/torchrl/modules/llm/backends/vllm/vllm_utils.py index efbfb29ef88..533bc274449 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_utils.py +++ b/torchrl/modules/llm/backends/vllm/vllm_utils.py @@ -7,6 +7,9 @@ from __future__ import annotations +import torch + +from torchrl._utils import logger as torchrl_logger try: from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -22,7 +25,7 @@ def stateless_init_process_group( - master_address: str | None, master_port: str | None, rank, world_size, device + master_address: str | None, master_port: str | None, rank, world_size, device=None ): """Initializes a stateless process group for distributed communication. @@ -36,7 +39,7 @@ def stateless_init_process_group( master_port (str | None): The port used by the master node. Automatically assigns an open port if not specified. rank (int): The rank of the current process. world_size (int): The total number of processes in the distributed group. - device: The device to use for communication. + device: The device to use for communication. Defaults to None. Returns: PyNcclCommunicator: A PyNcclCommunicator instance initialized with the created StatelessProcessGroup. @@ -56,9 +59,14 @@ def stateless_init_process_group( if master_port is None: master_port = get_open_port() if callable(get_open_port) else 29500 + torchrl_logger.info( + f"Initializing stateless process group: rank={rank}, world_size={world_size}, master_address={master_address}, master_port={master_port}" + ) pg = StatelessProcessGroup.create( host=master_address, port=int(master_port), rank=rank, world_size=world_size ) + if device is None: + device = torch.device("cuda:0") pynccl = PyNcclCommunicator(pg, device=device) return pynccl diff --git a/torchrl/weight_update/llm/__init__.py b/torchrl/weight_update/llm/__init__.py new file mode 100644 index 00000000000..a74ad9c70b0 --- /dev/null +++ b/torchrl/weight_update/llm/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .vllm_double_buffer import ( + VLLMDoubleBufferSyncScheme, + VLLMDoubleBufferTransport, + VLLMDoubleBufferWeightReceiver, + VLLMDoubleBufferWeightSender, +) +from .vllm_nccl import ( + get_model_metadata, + VLLMCollectiveTransport, + VLLMWeightReceiver, + VLLMWeightSender, + VLLMWeightSyncScheme, +) + +__all__ = [ + # vLLM NCCL-based weight sync + "VLLMWeightSyncScheme", + "VLLMWeightSender", + "VLLMWeightReceiver", + "VLLMCollectiveTransport", + "get_model_metadata", + # vLLM double-buffer weight sync + "VLLMDoubleBufferSyncScheme", + "VLLMDoubleBufferWeightSender", + "VLLMDoubleBufferWeightReceiver", + "VLLMDoubleBufferTransport", +] diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py new file mode 100644 index 00000000000..2482f250d0e --- /dev/null +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""vLLM weight synchronization using double-buffered shared memory. + +This module provides weight synchronization for vLLM engines using a double-buffer +approach with memory-mapped TensorDict storage. + +**Architecture Overview** + +The double-buffer synchronization uses a simpler architecture compared to NCCL: + +1. **Sender (Trainer)** + - Extracts weights from the training model + - Writes weights to shared directory using TensorDict.memmap + - No coordination needed - receiver pulls when ready + +2. **Receiver (vLLM Worker)** + - Uses RPC to tell all vLLM workers to load from shared directory + - Each worker reads weights and calls model.load_weights() + - Can trigger at any time (pull-based) + +**Key Differences from NCCL** + +- **Async vs Sync**: Double-buffer is asynchronous (no coordination required) +- **Push vs Pull**: Sender writes, receiver pulls when ready via RPC +- **Simplicity**: No NCCL collectives, uses file I/O +- **Storage**: Uses shared filesystem instead of GPU-GPU transfer + +**RPC Pattern** + +Like the NCCL implementation, this uses RPC to coordinate workers: +- RPC tells workers: "load weights from this directory" +- Workers read from shared storage independently +- Each worker calls `model_runner.model.load_weights()` + +**Usage Example** + +.. code-block:: python + + # Create scheme with shared directory + scheme = VLLMDoubleBufferSyncScheme( + remote_addr="/shared/weights", + num_threads=4 + ) + + # Sender side (trainer) + sender = scheme.create_sender() + sender.register_model(policy_model) + sender.update_weights() # Writes to /shared/weights + + # Receiver side (vLLM worker - AsyncVLLM) + receiver = scheme.create_receiver(vllm_engine) + receiver.poll_and_apply() # RPC to workers -> load from /shared/weights + +**Node-to-Node Transfer** + +For distributed setups, you can use different addresses: +- Sender writes to local path +- Use NFS, rsync, or other file sync mechanisms +- Receiver reads from its local mount point +""" + +from __future__ import annotations + +from typing import Any, Literal + +from tensordict import TensorDict, TensorDictBase +from torchrl._utils import logger +from torchrl.weight_update.weight_sync_schemes import ( + WeightReceiver, + WeightSender, + WeightStrategy, + WeightSyncScheme, +) + + +class VLLMDoubleBufferTransport: + """Transport for vLLM using double-buffered memory-mapped storage. 
+ + This transport writes weights to a shared directory and reads them back + using TensorDict's memory-mapping capabilities. + + Args: + remote_addr: Directory path where sender writes weights. + local_addr: Directory path where receiver reads weights. + If None, uses same path as remote_addr (for local testing). + num_threads: Number of threads for memmap operations. + """ + + def __init__( + self, remote_addr: str, local_addr: str | None = None, num_threads: int = 1 + ): + if local_addr is None: + local_addr = remote_addr + self.remote_addr = remote_addr + self.local_addr = local_addr + self.num_threads = num_threads + + def send_weights(self, model_id: str, weights: Any) -> None: + """Writes the weights to a shared directory. + + Args: + model_id: Identifier for the model (used for logging). + weights: TensorDict or dict of weights to write. + """ + if isinstance(weights, dict): + weights = TensorDict(weights, batch_size=[]) + elif isinstance(weights, TensorDictBase): + # Ensure it has a batch_size + if weights.batch_size == (): + weights = weights.clone() + + logger.info(f"Writing weights for model '{model_id}' to {self.remote_addr}") + weights.memmap(self.remote_addr, num_threads=self.num_threads) + logger.info(f"Weights written successfully to {self.remote_addr}") + + def receive_weights(self, timeout: float = 1.0) -> TensorDict: + """Reads the weights from the shared directory. + + Args: + timeout: Not used for file-based transport (kept for API compatibility). + + Returns: + TensorDict with flattened keys containing the weights. + """ + logger.info(f"Reading weights from {self.local_addr}") + weights = TensorDict.load_memmap(self.local_addr) + weights = weights.flatten_keys(".") + logger.info(f"Weights read successfully from {self.local_addr}") + return weights + + def check_connection(self) -> bool: + """Check if the transport is ready. + + For file-based transport, always returns True. + """ + return True + + +class VLLMDoubleBufferSyncScheme(WeightSyncScheme): + """Weight synchronization scheme for vLLM using double-buffered storage. + + This scheme uses memory-mapped TensorDict storage to transfer weights from + a trainer to vLLM inference workers. It's simpler than NCCL-based approaches + and doesn't require process group coordination. + + Args: + remote_addr: Directory path where sender writes weights. + local_addr: Directory path where receiver reads weights. + If None, uses same path as remote_addr (for local testing). + num_threads: Number of threads for memmap operations. Defaults to 1. + strategy: Weight extraction strategy ("tensordict" or "state_dict"). + + Example: + >>> # Local testing (same machine) + >>> scheme = VLLMDoubleBufferSyncScheme( + ... remote_addr="/tmp/weights", + ... strategy="tensordict" + ... ) + >>> + >>> # Distributed setup (different machines) + >>> # On trainer node: + >>> scheme = VLLMDoubleBufferSyncScheme( + ... remote_addr="/mnt/shared/weights", # NFS mount + ... num_threads=4 + ... ) + >>> + >>> # On vLLM worker node: + >>> scheme = VLLMDoubleBufferSyncScheme( + ... remote_addr="/mnt/shared/weights", # Same NFS mount + ... num_threads=4 + ... 
) + """ + + def __init__( + self, + remote_addr: str, + local_addr: str | None = None, + num_threads: int = 1, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + ): + self.remote_addr = remote_addr + self.local_addr = local_addr if local_addr is not None else remote_addr + self.num_threads = num_threads + self.strategy_name = strategy + + def create_transport( + self, pipe_or_context: Any = None + ) -> VLLMDoubleBufferTransport: + """Create transport for double-buffered storage. + + Args: + pipe_or_context: Not used for file-based transport (kept for API compatibility). + + Returns: + A VLLMDoubleBufferTransport instance. + """ + return VLLMDoubleBufferTransport( + remote_addr=self.remote_addr, + local_addr=self.local_addr, + num_threads=self.num_threads, + ) + + def create_sender(self) -> VLLMDoubleBufferWeightSender: + """Create a weight sender for the trainer process.""" + return VLLMDoubleBufferWeightSender(self) + + def create_receiver(self, vllm_engine) -> VLLMDoubleBufferWeightReceiver: + """Create a weight receiver for a vLLM worker process. + + Args: + vllm_engine: The vLLM engine instance (must have .llm_engine.model_executor attribute). + """ + return VLLMDoubleBufferWeightReceiver(self, vllm_engine) + + +class VLLMDoubleBufferWeightSender(WeightSender): + """Sends weights to vLLM workers using double-buffered storage. + + This sender extracts weights from a training model and writes them to + a shared directory using TensorDict.memmap. + + Example: + >>> sender = scheme.create_sender() + >>> sender.register_model(policy_model) + >>> + >>> # During training loop + >>> sender.update_weights() # Writes current weights to shared storage + """ + + def __init__(self, scheme: VLLMDoubleBufferSyncScheme): + self._scheme = scheme + self._strategy = WeightStrategy(extract_as=scheme.strategy_name) + self._model_ref = None + self._transport = None + + def register_model(self, model: Any) -> None: + """Register the model to extract weights from. + + Args: + model: The model to extract weights from (e.g., TransformersWrapper). + """ + import weakref + + self._model_ref = weakref.ref(model) + + # Create transport on registration + self._transport = self._scheme.create_transport() + logger.info( + f"Registered model for double-buffer weight sync to {self._scheme.remote_addr}" + ) + + def update_weights(self, weights: Any | None = None) -> None: + """Extract and write weights to shared storage. + + Args: + weights: Optional weights to send. If None, extracts from registered model. + """ + if self._transport is None: + raise RuntimeError("Transport not initialized. Call register_model first.") + + # Extract weights if not provided + if weights is None: + model = self._model_ref() + if model is None: + raise RuntimeError("Model reference is dead") + weights = self._strategy.extract_weights(model) + else: + # Ensure weights are in the right format + if hasattr(weights, "state_dict"): + # It's a module, extract + weights = self._strategy.extract_weights(weights) + + # Send via transport + self._transport.send_weights("vllm_model", weights) + + +class VLLMDoubleBufferWeightReceiver(WeightReceiver): + """Receives weights in a vLLM worker using double-buffered storage. + + This receiver reads weights from a shared directory and loads them into + the vLLM engine using the engine's load_weights interface. + + Example: + >>> receiver = scheme.create_receiver(vllm_engine) + >>> + >>> # Poll for new weights + >>> if receiver.poll_and_apply(): + ... 
print("Weights updated!") + """ + + def __init__(self, scheme: VLLMDoubleBufferSyncScheme, vllm_engine): + self._scheme = scheme + self._strategy = WeightStrategy(extract_as=scheme.strategy_name) + self._vllm_engine = vllm_engine + self._transport = scheme.create_transport() + logger.info( + f"Initialized double-buffer receiver reading from {self._scheme.local_addr}" + ) + + def apply_weights(self, weights: TensorDict) -> None: + """Apply weights to vLLM engine using RPC. + + This method uses RPC to tell all vLLM workers to load weights from + the shared storage directory. Similar to how AsyncVLLM._update_weights_with_nccl_broadcast_simple + uses collective_rpc to coordinate workers. + + Args: + weights: TensorDict with flattened keys containing weights. + """ + logger.info("Applying weights to vLLM engine via RPC") + + # Convert TensorDict to list of (name, tensor) tuples + weights_list = list(weights.items()) + + # Check if this is an AsyncVLLM instance (uses RPC to coordinate workers) + if hasattr(self._vllm_engine, "collective_rpc"): + # AsyncVLLM path: use RPC to tell all workers to load weights + logger.info( + f"Using RPC to load {len(weights_list)} weights across all replicas" + ) + + # Call collective_rpc to tell workers to load from shared storage + # The method 'load_weights_from_storage' will be called on each worker + futures = self._vllm_engine.collective_rpc( + method="load_weights_from_storage", + args=(str(self._scheme.local_addr), self._transport.num_threads), + ) + + # Wait for all workers to complete + import ray + + ray.get(futures) + logger.info("Weights loaded successfully via RPC") + else: + # Direct path for local LLM (non-AsyncVLLM) + logger.info("Using direct load for local LLM") + engine = ( + self._vllm_engine.llm_engine + if hasattr(self._vllm_engine, "llm_engine") + else self._vllm_engine + ) + worker = engine.model_executor.driver_worker + model = worker.model_runner.model + model.load_weights(weights_list) + logger.info("Weights loaded successfully") + + def poll_and_apply(self, timeout: float = 180.0) -> bool: + """Poll for and apply weights from shared storage. + + Args: + timeout: Not used for file-based transport (kept for API compatibility). + + Returns: + True if weights were successfully read and applied, False otherwise. + """ + weights = self._transport.receive_weights(timeout=timeout) + self.apply_weights(weights) + return True diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py new file mode 100644 index 00000000000..840a9883d14 --- /dev/null +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -0,0 +1,699 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""vLLM weight synchronization for the v2 API. + +This module provides weight synchronization for vLLM engines using a two-layer +architecture: + +**Architecture Overview** + +The weight synchronization uses two separate layers: + +1. **RPC Layer** (Coordination) + - Signals workers when a collective operation will begin + - Can be implemented with different backends (Ray, torch.distributed.rpc, etc.) + - Tells vLLM workers: "prepare to receive weights via collective" + - Currently supports Ray as the RPC backend + +2. 
**Collective Layer** (Data Transfer) + - Performs the actual weight broadcast using NCCL + - High-bandwidth GPU-to-GPU communication + - All ranks participate simultaneously in the collective + +**Why Two Layers?** + +Separating RPC and collectives provides: +- **Flexibility**: Swap RPC backends (Ray, RPC, gRPC) without changing collectives +- **Clarity**: Coordination logic separate from data transfer +- **Performance**: Use optimal transport for each (RPC for signals, NCCL for data) + +**Flow Example (Ray Backend)** + +.. code-block:: + + Trainer (rank 0) vLLM Workers (ranks 1+) + ================ ======================= + + # 1. RPC: Signal collective start + trainer.update_weights() ---------> [Ray RPC] --------> receiver.init_all_workers_group() + "I'm ready for collective" + + # 2. Collective: Broadcast weights + NCCL broadcast -------------------- [GPU-GPU] ---------> NCCL receive + (high bandwidth) (parallel) + + # 3. RPC: Confirmation (optional) + "broadcast done" <----------------- [Ray RPC] --------- "weights applied" + +**Extending to Other Backends** + +To add a new RPC backend (e.g., torch.distributed.rpc): + +1. Implement an RPC coordinator in the sender/receiver +2. Replace Ray remote calls with your RPC mechanism +3. Keep the collective layer unchanged (it's backend-agnostic) + +.. rubric:: Example + +.. code-block:: python + + class TorchRPCVLLMReceiver(VLLMWeightReceiver): + def init_all_workers_group(self, metadata): + # Use torch.distributed.rpc instead of Ray + torch.distributed.rpc.rpc_sync( + "trainer", + lambda: "ready", + ) + super().init_all_workers_group(metadata) # Collective init + +**Current Implementation (Ray Backend)** + +The test suite in ``test_weightsync.py`` demonstrates the Ray-based RPC: + +.. code-block:: python + + # Trainer actor (provides RPC endpoint) + trainer = RayWorkerTransformer.as_remote().options( + name="Trainer" # Named for discovery + ).remote(scheme_config) + + # Receiver actor (uses RPC to coordinate) + receiver = RayWorkerVLLM.as_remote().remote( + scheme_config, trainer_actor_name="Trainer" + ) + + # RPC Layer: Both actors call init() via Ray remote calls + # This coordinates the collective handshake + ray.get([trainer.init.remote(), receiver.init.remote()]) + + # RPC Layer: Trigger update via Ray remote call + # Collective Layer: NCCL broadcast happens automatically + ray.get(trainer.update_weights.remote(modify_weights=True)) + +In this setup: +- **Ray provides RPC**: Named actors, ``remote()`` calls, ``ray.get()`` +- **NCCL provides collectives**: GPU-GPU weight broadcast +- **Loose coupling**: Can replace Ray with any RPC mechanism +""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +import torch.distributed +from tensordict import TensorDictBase + +from torchrl._utils import logger as torchrl_logger +from torchrl.modules.llm.backends import stateless_init_process_group +from torchrl.weight_update.weight_sync_schemes import ( + WeightReceiver, + WeightSender, + WeightStrategy, + WeightSyncScheme, +) + +# ============================================================================ +# vLLM Transport using Collective Communication +# ============================================================================ + + +class VLLMCollectiveTransport: + """Transport for vLLM using collective communication (NCCL). + + **COLLECTIVE LAYER ONLY** - This class handles the data transfer layer. + RPC coordination is handled separately by the caller (sender/receiver). 
+ + This transport uses PyTorch distributed collectives to broadcast weights + from a trainer (rank 0) to vLLM workers (ranks 1+). + + **Separation of Concerns:** + - This class: NCCL collective operations (GPU-GPU data transfer) + - Caller (sender/receiver): RPC coordination (when to start collective) + + Args: + master_address: Address of the master node for distributed init. + master_port: Port of the master node for distributed init. + rank: Rank of this process (0 for trainer, 1+ for vLLM workers). + world_size: Total number of processes (1 + num_replicas * gpus_per_replica). + device: Device to use for communication (typically cuda:0). + vllm_engine: Optional vLLM engine reference (for receiver side). + + Note: + The RPC layer (e.g., Ray remote calls) must ensure all ranks call + init_all_workers_group() simultaneously before any collective operations. + """ + + def __init__( + self, + master_address: str, + master_port: int, + rank: int | None, + world_size: int, + device: torch.device | str | int | None = None, + vllm_engine: Any | None = None, + ): + self.master_address = master_address + self.master_port = master_port + self.rank = rank + self.world_size = world_size + self.vllm_engine = vllm_engine + self._comm_group = None + self._model_metadata = None + + # Ray sets CUDA_VISIBLE_DEVICES, so each actor sees only device 0 + # PyNcclCommunicator expects an integer device index + if device is None: + self.device = 0 # Default to device 0 (Ray convention) + elif isinstance(device, str): + # Extract device index from "cuda:X" + self.device = int(device.split(":")[-1]) if ":" in device else 0 + elif isinstance(device, torch.device): + # Extract index from torch.device + self.device = device.index if device.index is not None else 0 + else: + self.device = device + + def init_all_workers_group( + self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] + ): + """Initialize the collective communication group. + + Args: + model_metadata: Dict mapping param names to (dtype, shape) tuples. 
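+
+        Example:
+            A sketch of the expected metadata layout (assuming a plain
+            ``torch.nn.Module``; :func:`get_model_metadata` builds the same dict):
+
+            >>> metadata = {
+            ...     name: (p.dtype, p.shape)
+            ...     for name, p in model.state_dict().items()
+            ... }
+            >>> transport.init_all_workers_group(metadata)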
+ """ + self._model_metadata = model_metadata + + if self.rank == 0: + # Trainer side - initialize process group + torchrl_logger.info( + f"Initializing trainer collective group: rank={self.rank}, world_size={self.world_size}, device={self.device}" + ) + # Ray sets CUDA_VISIBLE_DEVICES, so we always use device 0 + # Set CUDA device before initializing NCCL to avoid segfaults + torch.cuda.set_device(self.device) + torchrl_logger.info(f"Set CUDA device to {self.device}") + + self._comm_group = stateless_init_process_group( + self.master_address, + self.master_port, + self.rank, + self.world_size, + device=self.device, + ) + torchrl_logger.info("Trainer collective group initialized successfully") + else: + # vLLM worker side - initialize through engine + if self.vllm_engine is None: + raise ValueError("vllm_engine must be provided for worker ranks") + + torchrl_logger.info( + "Initializing vLLM worker collective group through engine" + ) + # Call vLLM engine's init method - it returns futures for all workers + # Workers will start NCCL init in background threads and return immediately + refs = self.vllm_engine.init_weight_update_group( + master_address=self.master_address, + master_port=self.master_port, + ) + + # Wait for RPCs to complete - ensures workers have dispatched their NCCL init threads + import ray + + ray.get(refs) + torchrl_logger.info( + f"All {len(refs)} vLLM workers have dispatched NCCL init RPCs" + ) + + # Small delay to ensure worker background threads have entered the NCCL collective + # This prevents a race where the trainer starts NCCL before workers are ready + import time + + time.sleep(0.2) + + self._comm_group = True # Mark as initialized + torchrl_logger.info( + "vLLM workers should now be blocked in NCCL collective, ready for trainer" + ) + + def send_weights(self, model_id: str, weights: Any) -> None: + """Broadcast weights to all workers using NCCL. + + This method follows AsyncVLLM's periodic-mono pattern: + For each weight: RPC → NCCL broadcast → Wait for RPC completion + + This should only be called from rank 0 (trainer). + + Args: + model_id: ID of the model (used for logging). + weights: TensorDict or dict of weights to broadcast. + """ + # This code is a duplicate from AsyncVLLM + # We are waiting for vLLM server to accept tokens endpoints, at which point we will be + # able to remove all dependencies on Ray for vllm distributed features. + # This will allow a more natural integration with the sender/receiver API. + + import ray + + if self.rank != 0: + raise RuntimeError("send_weights should only be called from rank 0") + + if self._comm_group is None: + raise RuntimeError( + "Communication group not initialized. Call init_all_workers_group first." + ) + + if self._model_metadata is None: + raise RuntimeError("Model metadata not set") + + if self.vllm_engine is None: + raise RuntimeError( + "vllm_engine must be provided to sender for RPC coordination" + ) + + # Set CUDA device for this operation + torch.cuda.set_device(self.device) + + # Convert to dict if needed + if isinstance(weights, TensorDictBase): + weights_dict = weights.to_dict() + else: + weights_dict = weights + + torchrl_logger.info( + f"Broadcasting {len(weights_dict)} weights for model '{model_id}'" + ) + + # Broadcast each weight using periodic-mono pattern (like AsyncVLLM) + for name, (dtype, shape) in self._model_metadata.items(): + if name not in weights_dict: + raise ValueError( + f"Weight '{name}' not found in weights. Weights keys: {list(weights_dict.keys())[:10]}..." 
+ ) + + tensor = weights_dict[name].to(f"cuda:{self.device}") + dtype_name = str(dtype).split(".")[-1] # "torch.float16" -> "float16" + + # Step 1: Send RPC to workers for this weight + futures = self.vllm_engine.collective_rpc( + "update_weight", args=(name, dtype_name, tuple(shape)) + ) + + # Step 2: Immediately broadcast this weight + self._comm_group.broadcast( + tensor, + src=0, + stream=torch.cuda.current_stream(), + ) + + # Step 3: Wait for workers to complete this weight + ray.get(futures) + del tensor + + torch.cuda.synchronize() + torchrl_logger.info(f"Broadcast complete for model '{model_id}'") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from broadcaster. + + This should only be called from worker ranks (rank > 0). + This method is called by vLLM engine internally through collective operations. + + Returns: + None - vLLM handles weight application internally via collectives. + """ + # vLLM handles this through its own collective operations + # The weights are received and applied by the engine during broadcast + return None + + def check_connection(self) -> bool: + """Check if the communication group is initialized.""" + return self._comm_group is not None + + +# ============================================================================ +# vLLM Weight Synchronization Components +# ============================================================================ + + +class VLLMWeightSyncScheme(WeightSyncScheme): + """Weight synchronization scheme for vLLM engines. + + This scheme uses collective communication (NCCL) to broadcast weights from + a trainer to vLLM inference workers with parallelism support. + + Args: + master_address: Address of the master node. Defaults to "localhost". + master_port: Port of the master node. If None, will auto-assign. + gpus_per_replica: Number of GPUs per replica (tp_size × dp_size × pp_size). + num_replicas: Number of vLLM engine replicas. Defaults to 1. + strategy: Weight extraction strategy ("tensordict" or "state_dict"). + device: Device index to use for communication. Defaults to 0. + Note: When using Ray, each actor sees only its assigned GPU as device 0 + due to CUDA_VISIBLE_DEVICES isolation. You should typically use 0. + + .. warning:: + Collective communication requires ALL ranks to participate simultaneously. + Both the sender (trainer, rank 0) and all receivers (vLLM workers, ranks 1+) + must call ``init_all_workers_group()`` at approximately the same time for the collective + handshake to succeed. Do NOT wait for one init to complete before starting + the other - start both and wait for both together. + + Note: + The world_size for NCCL will be: 1 (trainer) + num_replicas × gpus_per_replica (vLLM workers) + + Example: + >>> # Single replica with 2 GPUs (e.g., tp_size=2) + >>> scheme = VLLMWeightSyncScheme( + ... master_port=12345, + ... gpus_per_replica=2, + ... num_replicas=1, + ... strategy="tensordict" + ... ) # world_size = 1 + 1*2 = 3 + >>> + >>> # Multiple replicas with 1 GPU each + >>> scheme = VLLMWeightSyncScheme( + ... master_port=12345, + ... gpus_per_replica=1, + ... num_replicas=2, + ... strategy="tensordict" + ... ) # world_size = 1 + 2*1 = 3 + >>> + >>> # Multiple replicas with tp_size=2, dp_size=1, pp_size=1 + >>> scheme = VLLMWeightSyncScheme( + ... master_port=12345, + ... gpus_per_replica=2, # 2*1*1 + ... num_replicas=3, + ... strategy="tensordict" + ... 
) # world_size = 1 + 3*2 = 7 + >>> + >>> # In trainer process (rank 0) + >>> sender = VLLMWeightSender(scheme) + >>> sender.register_model(policy) + >>> + >>> # In vLLM worker process (rank 1+) + >>> receiver = VLLMWeightReceiver(scheme, vllm_engine) + >>> + >>> # IMPORTANT: Both must init simultaneously for collective handshake + >>> # With Ray: + >>> init_sender = sender_actor.init_all_workers_group.remote(metadata) + >>> init_receiver = receiver_actor.init_all_workers_group.remote(metadata) + >>> ray.get([init_sender, init_receiver]) # Wait for both together + >>> + >>> # After init, updates work normally + >>> sender.update_weights() + >>> # Weights are received automatically via collectives + """ + + def __init__( + self, + master_address: str | None = None, + master_port: int | None = None, + gpus_per_replica: int = 1, + num_replicas: int = 1, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + device: torch.device | str | int = 0, + ): + self.master_address = ( + master_address if master_address is not None else "localhost" + ) + self.master_port = master_port + self.gpus_per_replica = gpus_per_replica + self.num_replicas = num_replicas + self.strategy_name = strategy + # Ray sets CUDA_VISIBLE_DEVICES for each actor, so device 0 is typical + self.device = device + + # Auto-assign port if not provided + if self.master_port is None: + try: + from vllm.utils import get_open_port + + self.master_port = get_open_port() + except ImportError: + # Fallback if vLLM not available + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + self.master_port = s.getsockname()[1] + + def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: + """Create transport for collective communication. + + For vLLM, this creates a transport but requires additional setup via init_all_workers_group(). + This method is required by the base class but transport creation for vLLM + is more complex and typically handled by sender/receiver initialization. + + Args: + pipe_or_context: Not used for vLLM (kept for API compatibility). + + Returns: + A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called). + """ + # Return a transport with default rank 0 (trainer) + # Actual initialization happens in sender/receiver + world_size = 1 + self.num_replicas * self.gpus_per_replica + return VLLMCollectiveTransport( + master_address=self.master_address, + master_port=self.master_port, + rank=0, + world_size=world_size, + device=self.device, + ) + + def create_sender(self) -> VLLMWeightSender: + """Create a weight sender for the trainer process.""" + return VLLMWeightSender(self) + + def create_receiver(self, vllm_engine) -> VLLMWeightReceiver: + """Create a weight receiver for a vLLM worker process. + + Args: + vllm_engine: The vLLM engine instance (must implement RLvLLMEngine interface). + """ + return VLLMWeightReceiver(self, vllm_engine) + + +class VLLMWeightSender(WeightSender): + """Sends weights to vLLM workers using collective communication. + + **RPC + Collective Implementation** + + This class implements both layers: + + 1. **RPC Layer**: Currently uses Ray remote calls (implicit in test setup) + - Can be extended to other RPC backends (torch.distributed.rpc, gRPC) + - In the test, Ray actors provide the RPC mechanism + + 2. 
**Collective Layer**: Uses VLLMCollectiveTransport for NCCL broadcast + - Broadcasts weights from trainer (rank 0) to workers (ranks 1+) + - High-bandwidth GPU-to-GPU transfer + + **Extending RPC Backends** + + To use a different RPC backend, subclass and override coordination: + + .. code-block:: python + + class TorchRPCVLLMSender(VLLMWeightSender): + def update_weights(self, weights=None): + # Custom RPC: Signal workers to prepare + for worker in self.workers: + torch.distributed.rpc.rpc_async(worker, "prepare_receive") + + # Then do collective (unchanged) + super().update_weights(weights) + """ + + def __init__(self, scheme: VLLMWeightSyncScheme): + self._scheme = scheme + self._strategy = WeightStrategy(extract_as=scheme.strategy_name) + self._model_ref = None + self._transport = None + self._model_metadata = None + + def register_model(self, model: Any) -> None: + """Register the model to extract weights from.""" + import weakref + + self._model_ref = weakref.ref(model) + + def init_all_workers_group( + self, + model_metadata: dict[str, tuple[torch.dtype, torch.Size]], + vllm_engine: Any | None = None, + ): + """Initialize the collective communication group. + + Args: + model_metadata: Dict mapping param names to (dtype, shape) tuples. + vllm_engine: Optional vLLM engine for RPC coordination. Required for NCCL broadcasts. + """ + self._model_metadata = model_metadata + self._vllm_engine = vllm_engine + + # Create transport for trainer (rank 0) + world_size = 1 + self._scheme.num_replicas * self._scheme.gpus_per_replica + self._transport = VLLMCollectiveTransport( + master_address=self._scheme.master_address, + master_port=self._scheme.master_port, + rank=0, # Trainer is always rank 0 + world_size=world_size, + device=self._scheme.device, + vllm_engine=vllm_engine, + ) + torchrl_logger.info( + f"Initializing transport from sender with world_size={world_size}" + ) + self._transport.init_all_workers_group(model_metadata) + + def update_weights(self, weights: Any | None = None) -> None: + """Extract and broadcast weights to vLLM workers. + + Args: + weights: Optional weights to send. If None, extracts from registered model. + """ + if self._transport is None: + raise RuntimeError( + "Transport not initialized. Call init_all_workers_group first." + ) + + # Extract weights if not provided + if weights is None: + model = self._model_ref() + if model is None: + raise RuntimeError("Model reference is dead") + weights = self._strategy.extract_weights(model) + else: + # Ensure weights are in the right format + if hasattr(weights, "state_dict"): + # It's a module, extract + weights = self._strategy.extract_weights(weights) + + # Send via transport + self._transport.send_weights("vllm_model", weights) + + +class VLLMWeightReceiver(WeightReceiver): + """Receives weights in a vLLM worker using collective communication. + + **RPC + Collective Implementation** + + This class implements both layers: + + 1. **RPC Layer**: Currently uses Ray for coordination + - `init()` in test uses Ray `ray.get_actor()` to find trainer + - Fetches metadata via Ray remote call + - Signals readiness to participate in collective + + 2. **Collective Layer**: Participates in NCCL broadcast + - Receives weights via collective operations + - vLLM engine applies weights internally during broadcast + + **Extending RPC Backends** + + To use a different RPC backend: + + .. 
code-block:: python + + class TorchRPCVLLMReceiver(VLLMWeightReceiver): + def init(self): + # Custom RPC: Get metadata from trainer + metadata = torch.distributed.rpc.rpc_sync( + "trainer", + lambda: get_metadata() + ) + + # Then init collective (unchanged) + self.receiver.init_all_workers_group(metadata) + + Note: + The RPC and collective layers are loosely coupled. The RPC layer + ensures all ranks are ready before the collective starts, but the + actual data transfer is independent of the RPC mechanism. + """ + + def __init__(self, scheme: VLLMWeightSyncScheme, vllm_engine): + self._scheme = scheme + self._strategy = WeightStrategy(extract_as=scheme.strategy_name) + self._vllm_engine = vllm_engine + self._transport = None + + def init_all_workers_group( + self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] + ): + """Initialize the collective communication group. + + Args: + model_metadata: Dict mapping param names to (dtype, shape) tuples. + """ + # For vLLM receiver, we use rank=1 as a placeholder + # The engine handles actual rank assignment internally for all workers + world_size = 1 + self._scheme.num_replicas * self._scheme.gpus_per_replica + self._transport = VLLMCollectiveTransport( + master_address=self._scheme.master_address, + master_port=self._scheme.master_port, + rank=None, # Placeholder - engine assigns actual ranks + world_size=world_size, + device=self._scheme.device, + vllm_engine=self._vllm_engine, + ) + torchrl_logger.info( + f"Initializing transport from receiver with world_size={world_size}." + ) + self._transport.init_all_workers_group(model_metadata) + + def apply_weights(self, weights: Any) -> None: + """Apply weights to vLLM engine. + + Note: For vLLM, weights are applied automatically during the collective + broadcast operation. This method is a no-op but kept for API consistency. + """ + # vLLM handles weight application through its collective operations + # The weights are already applied by the time broadcast completes + + def poll_and_apply(self, timeout: float = 0.1) -> bool: + """Poll for and apply weights. + + Returns: + False - vLLM uses push-based updates via collectives, not polling. + """ + # vLLM uses collective broadcasts (push), not polling + # This is handled by the engine's collective operations + return False + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def get_model_metadata(model) -> dict[str, tuple[torch.dtype, torch.Size]]: + """Extract model metadata from a model. + + Args: + model: A model with state_dict() or a model wrapper. + + Returns: + Dict mapping parameter names to (dtype, shape) tuples. + + Note: + This function must extract keys in the same format as WeightStrategy.extract_weights() + to ensure consistency between metadata and actual weight keys during broadcasting. + """ + # Extract state_dict directly from the model + # This ensures keys match what extract_weights() will produce + if hasattr(model, "state_dict"): + if hasattr(model, "merge_and_unload"): + # LoRA model + sd = model.merge_and_unload().state_dict() + else: + sd = model.state_dict() + else: + raise TypeError(f"Cannot extract state_dict from {type(model)}") + + return {k: (v.dtype, v.shape) for k, v in sd.items()}
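
For reference, a minimal end-to-end sketch of the trainer side of the NCCL scheme above. The ``policy`` and ``service`` handles are illustrative, and the vLLM receiver must run its own ``init_all_workers_group`` concurrently for the collective handshake to complete:

.. code-block:: python

    from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata

    scheme = VLLMWeightSyncScheme(
        master_port=29500, gpus_per_replica=1, num_replicas=1, strategy="tensordict"
    )
    sender = scheme.create_sender()
    sender.register_model(policy)

    metadata = get_model_metadata(policy)
    # vLLM workers must call receiver.init_all_workers_group(metadata) at the same time
    sender.init_all_workers_group(metadata, vllm_engine=service)
    sender.update_weights()  # NCCL-broadcasts the current weights to all replicas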