From 8b25d909043c60c1d43b31ad3049b6d72b0f7286 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 19 Oct 2025 18:07:27 -0700 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- .../grpo/requirements_gsm8k.txt | 24 +- .../grpo/requirements_ifeval.txt | 32 +- test/llm/ray_helpers.py | 293 ++++++++++++++++++ test/llm/test_updaters.py | 284 +---------------- 4 files changed, 336 insertions(+), 297 deletions(-) create mode 100644 test/llm/ray_helpers.py diff --git a/sota-implementations/grpo/requirements_gsm8k.txt b/sota-implementations/grpo/requirements_gsm8k.txt index 4a6182cdf52..2b81e1a2df7 100644 --- a/sota-implementations/grpo/requirements_gsm8k.txt +++ b/sota-implementations/grpo/requirements_gsm8k.txt @@ -1,13 +1,11 @@ -torch==2.7.0 -transformers==4.52.4 -peft==0.15.2 -bitsandbytes==0.46.0 -datasets==3.6.0 -wandb==0.19.11 -hydra-core==1.3.2 -ray==2.46.0 -tqdm==4.67.1 -tensordict==0.9.0 -vllm==0.9.0.1 -accelerate==1.7.0 -xformers==0.0.30 +vllm==0.11.0 +peft +bitsandbytes +datasets +wandb +hydra-core +ray +tqdm +tensordict +accelerate +xformers diff --git a/sota-implementations/grpo/requirements_ifeval.txt b/sota-implementations/grpo/requirements_ifeval.txt index dbd2735d979..c3d889ddee1 100644 --- a/sota-implementations/grpo/requirements_ifeval.txt +++ b/sota-implementations/grpo/requirements_ifeval.txt @@ -1,16 +1,16 @@ -torch==2.7.0 -transformers==4.52.4 -peft==0.15.2 -bitsandbytes==0.46.0 -datasets==3.6.0 -wandb==0.19.11 -hydra-core==1.3.2 -ray==2.46.0 -tqdm==4.67.1 -tensordict==0.9.0 -vllm==0.9.0.1 -accelerate==1.7.0 -xformers==0.0.30 -nltk==3.9.1 -langdetect==1.0.9 -immutabledict==4.2.1 +vllm==0.11.0 +torch +transformers +peft +bitsandbytes +datasets +wandb +hydra-core +ray +tqdm +tensordict +accelerate +xformers +nltk +langdetect +immutabledict diff --git a/test/llm/ray_helpers.py b/test/llm/ray_helpers.py new file mode 100644 index 00000000000..fc4c448606d --- /dev/null +++ b/test/llm/ray_helpers.py @@ -0,0 +1,293 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Helper classes for Ray-based weight synchronization tests. + +This module contains Ray actor classes that need to be importable by Ray workers. +These classes are used in test_updaters.py but must be defined at module level +so Ray can serialize and import them on remote workers. 
+""" + +import torch +from torchrl._utils import logger + + +class WorkerVLLMNCCL: + """Ray actor for vLLM inference worker (receiver) using NCCL collective communication.""" + + def __init__( + self, + scheme_config: dict, + model_name: str = "Qwen/Qwen2.5-0.5B", + trainer_actor_name: str = "Trainer", + ): + pass + + # Store config for deferred initialization + self.scheme_config = scheme_config + self.model_name = model_name + self.trainer_actor_name = trainer_actor_name + self.wrapper = None + self.engine = None + self.receiver = None + self.scheme = None + self.trainer = None + self.model_metadata = None + + def setup(self): + """Set up vLLM engine (deferred from __init__ to avoid blocking).""" + from torchrl.modules.llm.backends import AsyncVLLM + from torchrl.modules.llm.policies import vLLMWrapper + + # Create vLLM wrapper + async_engine = AsyncVLLM.from_pretrained( + self.model_name, + num_replicas=2, # Number of engine replicas + ) + self.wrapper = vLLMWrapper(async_engine, input_mode="history") + self.engine = self.wrapper.model + + # Create scheme from config + from torchrl.weight_update.llm.vllm_nccl import VLLMWeightSyncScheme + + self.scheme = VLLMWeightSyncScheme(**self.scheme_config) + + # Create receiver (engine handles rank assignment automatically) + self.receiver = self.scheme.create_receiver(self.engine) + return "setup_complete" + + def init_metadata(self): + """Initialize the receiver by fetching metadata from trainer.""" + import ray + + if self.receiver is None: + raise RuntimeError("Must call setup() before init()") + + # Get trainer actor by name + logger.info(f"Getting trainer actor by name {self.trainer_actor_name}") + self.trainer = ray.get_actor(self.trainer_actor_name) + + # Fetch model metadata from trainer + logger.info("Fetching model metadata from trainer (requires max_concurrency>1)") + self.model_metadata = ray.get(self.trainer.get_model_metadata.remote()) + + def init(self): + if self.model_metadata is None: + raise RuntimeError("Must call init_metadata() before init()") + + # Initialize receiver with metadata + logger.info("Initializing receiver...") + self.receiver.init_all_workers_group(self.model_metadata) + self.initialized = True + logger.info("Receiver initialized") + return "initialized" + + def get_engine(self): + """Get the vLLM engine reference for RPC coordination.""" + if self.engine is None: + raise RuntimeError("Must call setup() first") + return self.engine + + def get_sample_output(self): + """Get a sample output to verify model works.""" + # Simple inference test + return "vllm_ready" + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + # No GPUs needed for the actor itself - vLLM workers manage their own placement group (2 GPUs) + # AsyncVLLM service doesn't act as NCCL rank 0 when used with external trainer + return ray.remote(num_cpus=4, num_gpus=0, max_concurrency=4)(cls) + + +class WorkerTransformerNCCL: + """Ray actor for transformer trainer (sender) using NCCL collective communication.""" + + def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): + from torchrl.weight_update.llm.vllm_nccl import ( + get_model_metadata, + VLLMWeightSyncScheme, + ) + from transformers import AutoModelForCausalLM + + # Create transformer model + transformer = AutoModelForCausalLM.from_pretrained( + model_name, + dtype=torch.float16, + ) + self.transformer = transformer.cuda() + + # Create scheme from config + self.scheme = VLLMWeightSyncScheme(**scheme_config) + + # Create sender + self.sender = 
self.scheme.create_sender() + self.sender.register_model(self.transformer) + + # Extract and store model metadata + self.model_metadata = get_model_metadata(self.transformer) + + def init(self, vllm_engine=None): + """Initialize sender with optional vLLM engine for RPC coordination. + + Args: + vllm_engine: Optional vLLM engine reference for calling collective_rpc + """ + if self.model_metadata is None: + raise RuntimeError("Must call init_metadata() before init()") + + self.sender.init_all_workers_group(self.model_metadata, vllm_engine=vllm_engine) + self.initialized = True + logger.info("Trainer initialized") + return "initialized" + + def get_model_metadata(self): + """Get model metadata to share with receiver.""" + return self.model_metadata + + def update_weights(self, modify_weights: bool = False): + """Trigger a weight update broadcast. + + Args: + modify_weights: If True, modifies weights before broadcasting + for verification purposes. + + Returns: + str: "updated" status message + """ + + # Optionally modify weights for testing + if modify_weights: + with torch.no_grad(): + first_param = next(self.transformer.parameters()) + first_param.add_(0.01) + + # Broadcast weights to all vLLM workers + self.sender.update_weights() + return "updated" + + def get_first_param_sum(self): + """Get sum of first parameter for verification.""" + return next(self.transformer.parameters()).sum().item() + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + return ray.remote(num_cpus=4, num_gpus=1, max_concurrency=4)(cls) + + +class WorkerVLLMDoubleBuffer: + """Ray actor for vLLM inference worker (receiver) using double-buffered storage.""" + + def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): + # Store config for deferred initialization + self.scheme_config = scheme_config + self.model_name = model_name + self.wrapper = None + self.engine = None + self.receiver = None + self.scheme = None + + def setup(self): + """Set up vLLM engine and receiver.""" + from torchrl.modules.llm.backends import AsyncVLLM + from torchrl.modules.llm.policies import vLLMWrapper + + # Create vLLM wrapper + async_engine = AsyncVLLM.from_pretrained( + self.model_name, + num_replicas=1, # Single replica for simplicity + ) + self.wrapper = vLLMWrapper(async_engine, input_mode="history") + self.engine = self.wrapper.model + + # Create scheme from config + from torchrl.weight_update.llm.vllm_double_buffer import ( + VLLMDoubleBufferSyncScheme, + ) + + self.scheme = VLLMDoubleBufferSyncScheme(**self.scheme_config) + + # Create receiver + self.receiver = self.scheme.create_receiver(self.engine) + logger.info("Receiver setup complete") + return "setup_complete" + + def poll_and_apply_weights(self): + """Poll for new weights and apply them to the engine.""" + if self.receiver is None: + raise RuntimeError("Must call setup() first") + + success = self.receiver.poll_and_apply() + return success + + def get_sample_output(self): + """Get a sample output to verify model works.""" + return "vllm_ready" + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + # vLLM worker needs 1 GPU + return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls) + + +class WorkerTransformerDoubleBuffer: + """Ray actor for transformer trainer (sender) using double-buffered storage.""" + + def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): + from torchrl.weight_update.llm.vllm_double_buffer import ( + VLLMDoubleBufferSyncScheme, + ) + from transformers import 
AutoModelForCausalLM + + # Create transformer model + transformer = AutoModelForCausalLM.from_pretrained( + model_name, + dtype=torch.float16, + ) + self.transformer = transformer.cuda() + + # Create scheme from config + self.scheme = VLLMDoubleBufferSyncScheme(**scheme_config) + + # Create sender + self.sender = self.scheme.create_sender() + self.sender.register_model(self.transformer) + logger.info("Trainer setup complete") + + def update_weights(self, modify_weights: bool = False): + """Trigger a weight update by writing to shared storage. + + Args: + modify_weights: If True, modifies weights before writing + for verification purposes. + + Returns: + str: "updated" status message + """ + # Optionally modify weights for testing + if modify_weights: + with torch.no_grad(): + first_param = next(self.transformer.parameters()) + first_param.add_(0.01) + + # Write weights to shared storage + self.sender.update_weights() + return "updated" + + def get_first_param_sum(self): + """Get sum of first parameter for verification.""" + return next(self.transformer.parameters()).sum().item() + + @classmethod + def as_remote(cls, *args, **kwargs): + import ray + + return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls) diff --git a/test/llm/test_updaters.py b/test/llm/test_updaters.py index 2ad078535cb..bb6b694cf3e 100644 --- a/test/llm/test_updaters.py +++ b/test/llm/test_updaters.py @@ -40,6 +40,13 @@ def get_open_port(): if _has_ray: import ray + + from .ray_helpers import ( + WorkerTransformerDoubleBuffer, + WorkerTransformerNCCL, + WorkerVLLMDoubleBuffer, + WorkerVLLMNCCL, + ) else: ray = None @@ -289,6 +296,8 @@ def setup_ray(self): @pytest.fixture(scope="class") def target_vllm_engine(self, model_name): """Create Ray worker with low memory settings.""" + if not _has_vllm: + pytest.skip("vllm not installed") # Create Ray worker with minimal memory usage worker = make_vllm_worker( model_name=model_name, @@ -339,6 +348,8 @@ class TestVLLMUpdaterV2WithLocalLLM(BaseVLLMUpdaterTest): @pytest.fixture(scope="class") def target_vllm_engine(self, model_name): """Create local LLM with low memory settings.""" + if not _has_vllm: + pytest.skip("vllm not installed") # Create local LLM with minimal memory usage llm = make_vllm_worker( model_name=model_name, @@ -457,165 +468,6 @@ def _make_worker_transformer(model_name: str = "Qwen/Qwen2.5-0.5B"): transformer = transformer.cuda() return transformer - class WorkerVLLM: - """Ray actor for vLLM inference worker (receiver).""" - - def __init__( - self, - scheme_config: dict, - model_name: str = "Qwen/Qwen2.5-0.5B", - trainer_actor_name: str = "Trainer", - ): - pass - - # Store config for deferred initialization - self.scheme_config = scheme_config - self.model_name = model_name - self.trainer_actor_name = trainer_actor_name - self.wrapper = None - self.engine = None - self.receiver = None - self.scheme = None - self.trainer = None - self.model_metadata = None - - def setup(self): - """Set up vLLM engine (deferred from __init__ to avoid blocking).""" - # Create vLLM wrapper - self.wrapper = TestWeightSyncVLLMNCCL._make_worker_vllm(self.model_name) - self.engine = self.wrapper.model - - # Create scheme from config - from torchrl.weight_update.llm.vllm_nccl import VLLMWeightSyncScheme - - self.scheme = VLLMWeightSyncScheme(**self.scheme_config) - - # Create receiver (engine handles rank assignment automatically) - self.receiver = self.scheme.create_receiver(self.engine) - return "setup_complete" - - def init_metadata(self): - """Initialize the receiver by 
fetching metadata from trainer.""" - import ray - - if self.receiver is None: - raise RuntimeError("Must call setup() before init()") - - # Get trainer actor by name - logger.info(f"Getting trainer actor by name {self.trainer_actor_name}") - self.trainer = ray.get_actor(self.trainer_actor_name) - - # Fetch model metadata from trainer - logger.info( - "Fetching model metadata from trainer (requires max_concurrency>1)" - ) - self.model_metadata = ray.get(self.trainer.get_model_metadata.remote()) - - def init(self): - if self.model_metadata is None: - raise RuntimeError("Must call init_metadata() before init()") - - # Initialize receiver with metadata - logger.info("Initializing receiver...") - self.receiver.init_all_workers_group(self.model_metadata) - self.initialized = True - logger.info("Receiver initialized") - return "initialized" - - def get_engine(self): - """Get the vLLM engine reference for RPC coordination.""" - if self.engine is None: - raise RuntimeError("Must call setup() first") - return self.engine - - def get_sample_output(self): - """Get a sample output to verify model works.""" - # Simple inference test - return "vllm_ready" - - @classmethod - def as_remote(cls, *args, **kwargs): - import ray - - # No GPUs needed for the actor itself - vLLM workers manage their own placement group (2 GPUs) - # AsyncVLLM service doesn't act as NCCL rank 0 when used with external trainer - return ray.remote(num_cpus=4, num_gpus=0, max_concurrency=4)(cls) - - class WorkerTransformer: - """Ray actor for transformer trainer (sender).""" - - def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): - from torchrl.weight_update.llm.vllm_nccl import ( - get_model_metadata, - VLLMWeightSyncScheme, - ) - - # Create transformer model - self.transformer = TestWeightSyncVLLMNCCL._make_worker_transformer( - model_name - ) - - # Create scheme from config - self.scheme = VLLMWeightSyncScheme(**scheme_config) - - # Create sender - self.sender = self.scheme.create_sender() - self.sender.register_model(self.transformer) - - # Extract and store model metadata - self.model_metadata = get_model_metadata(self.transformer) - - def init(self, vllm_engine=None): - """Initialize sender with optional vLLM engine for RPC coordination. - - Args: - vllm_engine: Optional vLLM engine reference for calling collective_rpc - """ - if self.model_metadata is None: - raise RuntimeError("Must call init_metadata() before init()") - - self.sender.init_all_workers_group( - self.model_metadata, vllm_engine=vllm_engine - ) - self.initialized = True - logger.info("Trainer initialized") - return "initialized" - - def get_model_metadata(self): - """Get model metadata to share with receiver.""" - return self.model_metadata - - def update_weights(self, modify_weights: bool = False): - """Trigger a weight update broadcast. - - Args: - modify_weights: If True, modifies weights before broadcasting - for verification purposes. 
- - Returns: - str: "updated" status message - """ - - # Optionally modify weights for testing - if modify_weights: - with torch.no_grad(): - first_param = next(self.transformer.parameters()) - first_param.add_(0.01) - - # Broadcast weights to all vLLM workers - self.sender.update_weights() - return "updated" - - def get_first_param_sum(self): - """Get sum of first parameter for verification.""" - return next(self.transformer.parameters()).sum().item() - - @classmethod - def as_remote(cls, *args, **kwargs): - import ray - - return ray.remote(num_cpus=4, num_gpus=1, max_concurrency=4)(cls) - def test_weight_sync_vllm_collective_ray(self, request): """Test weight sync between transformer trainer and vLLM workers. @@ -661,7 +513,7 @@ def test_weight_sync_vllm_collective_ray(self, request): "Creating receiver actor first (vLLM workers need 2 GPUs via placement group)..." ) # Create receiver actor first - it will find trainer by name - receiver = TestWeightSyncVLLMNCCL.WorkerVLLM.as_remote().remote( + receiver = WorkerVLLMNCCL.as_remote().remote( scheme_config, model_name, trainer_actor_name="Trainer" ) @@ -673,7 +525,7 @@ def test_weight_sync_vllm_collective_ray(self, request): # Now create trainer actor (needs 1 GPU for training and NCCL rank 0) logger.info("Creating trainer actor (needs 1 GPU)...") trainer = ( - TestWeightSyncVLLMNCCL.WorkerTransformer.as_remote() + WorkerTransformerNCCL.as_remote() .options(name="Trainer") .remote(scheme_config, model_name) ) @@ -775,108 +627,6 @@ def _make_worker_transformer(model_name: str = "Qwen/Qwen2.5-0.5B"): transformer = transformer.cuda() return transformer - class WorkerVLLM: - """Ray actor for vLLM inference worker (receiver).""" - - def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): - # Store config for deferred initialization - self.scheme_config = scheme_config - self.model_name = model_name - self.wrapper = None - self.engine = None - self.receiver = None - self.scheme = None - - def setup(self): - """Set up vLLM engine and receiver.""" - # Create vLLM wrapper - self.wrapper = TestWeightSyncVLLMDoubleBuffer._make_worker_vllm( - self.model_name - ) - self.engine = self.wrapper.model - - # Create scheme from config - from torchrl.weight_update.llm.vllm_double_buffer import ( - VLLMDoubleBufferSyncScheme, - ) - - self.scheme = VLLMDoubleBufferSyncScheme(**self.scheme_config) - - # Create receiver - self.receiver = self.scheme.create_receiver(self.engine) - logger.info("Receiver setup complete") - return "setup_complete" - - def poll_and_apply_weights(self): - """Poll for new weights and apply them to the engine.""" - if self.receiver is None: - raise RuntimeError("Must call setup() first") - - success = self.receiver.poll_and_apply() - return success - - def get_sample_output(self): - """Get a sample output to verify model works.""" - return "vllm_ready" - - @classmethod - def as_remote(cls, *args, **kwargs): - import ray - - # vLLM worker needs 1 GPU - return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls) - - class WorkerTransformer: - """Ray actor for transformer trainer (sender).""" - - def __init__(self, scheme_config: dict, model_name: str = "Qwen/Qwen2.5-0.5B"): - from torchrl.weight_update.llm.vllm_double_buffer import ( - VLLMDoubleBufferSyncScheme, - ) - - # Create transformer model - self.transformer = TestWeightSyncVLLMDoubleBuffer._make_worker_transformer( - model_name - ) - - # Create scheme from config - self.scheme = VLLMDoubleBufferSyncScheme(**scheme_config) - - # Create sender - 
self.sender = self.scheme.create_sender() - self.sender.register_model(self.transformer) - logger.info("Trainer setup complete") - - def update_weights(self, modify_weights: bool = False): - """Trigger a weight update by writing to shared storage. - - Args: - modify_weights: If True, modifies weights before writing - for verification purposes. - - Returns: - str: "updated" status message - """ - # Optionally modify weights for testing - if modify_weights: - with torch.no_grad(): - first_param = next(self.transformer.parameters()) - first_param.add_(0.01) - - # Write weights to shared storage - self.sender.update_weights() - return "updated" - - def get_first_param_sum(self): - """Get sum of first parameter for verification.""" - return next(self.transformer.parameters()).sum().item() - - @classmethod - def as_remote(cls, *args, **kwargs): - import ray - - return ray.remote(num_cpus=2, num_gpus=1, max_concurrency=4)(cls) - def test_weight_sync_vllm_double_buffer_ray(self, tmpdir, request): """Test weight sync using double-buffered storage with Ray. @@ -911,16 +661,14 @@ def test_weight_sync_vllm_double_buffer_ray(self, tmpdir, request): # Create trainer actor logger.info("Creating trainer actor...") - trainer = ( - TestWeightSyncVLLMDoubleBuffer.WorkerTransformer.as_remote().remote( - scheme_config, model_name - ) + trainer = WorkerTransformerDoubleBuffer.as_remote().remote( + scheme_config, model_name ) logger.info("Trainer actor created") # Create receiver actor logger.info("Creating receiver actor...") - receiver = TestWeightSyncVLLMDoubleBuffer.WorkerVLLM.as_remote().remote( + receiver = WorkerVLLMDoubleBuffer.as_remote().remote( scheme_config, model_name ) From 9604f48b8de3bc814b8801ace778b29fa97f7485 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 19 Oct 2025 18:25:09 -0700 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- .gitignore | 8 ++ .../expert-iteration-async.py | 16 +-- .../expert-iteration/expert-iteration-sync.py | 16 +-- sota-implementations/grpo/grpo-async.py | 16 +-- sota-implementations/grpo/grpo-sync.py | 16 +-- test/llm/test_updaters.py | 2 +- torchrl/__init__.py | 4 + torchrl/_utils.py | 100 ++++++++++++++++++ 8 files changed, 125 insertions(+), 53 deletions(-) diff --git a/.gitignore b/.gitignore index 3821a3a77da..52927eff312 100644 --- a/.gitignore +++ b/.gitignore @@ -189,3 +189,11 @@ log Roms scratch/* + +# Large directories from git history that should not be committed +dev/ +main/ +*.html + +# Additional cache directories +.ruff_cache/ diff --git a/sota-implementations/expert-iteration/expert-iteration-async.py b/sota-implementations/expert-iteration/expert-iteration-async.py index 75f8d39462d..7b736fa24e4 100644 --- a/sota-implementations/expert-iteration/expert-iteration-async.py +++ b/sota-implementations/expert-iteration/expert-iteration-async.py @@ -12,7 +12,7 @@ import hydra -from torchrl import torchrl_logger +from torchrl import merge_ray_runtime_env, torchrl_logger from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger from torchrl.weight_update.llm import get_model_metadata @@ -397,19 +397,9 @@ def main(cfg): if not k.startswith("_") } - # Add computed GPU configuration + # Add computed GPU configuration and merge with default runtime_env ray_init_config["num_gpus"] = device_config["ray_num_gpus"] - # Ensure runtime_env and env_vars exist - if "runtime_env" not in ray_init_config: - ray_init_config["runtime_env"] = {} - if not isinstance(ray_init_config["runtime_env"], dict): - 
ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"]) - if "env_vars" not in ray_init_config["runtime_env"]: - ray_init_config["runtime_env"]["env_vars"] = {} - if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict): - ray_init_config["runtime_env"]["env_vars"] = dict( - ray_init_config["runtime_env"]["env_vars"] - ) + ray_init_config = merge_ray_runtime_env(ray_init_config) torchrl_logger.info(f"Ray init config: {ray_init_config=}") ray.init(**ray_init_config) diff --git a/sota-implementations/expert-iteration/expert-iteration-sync.py b/sota-implementations/expert-iteration/expert-iteration-sync.py index 126c188b6e9..f5af9d245d9 100644 --- a/sota-implementations/expert-iteration/expert-iteration-sync.py +++ b/sota-implementations/expert-iteration/expert-iteration-sync.py @@ -12,7 +12,7 @@ import hydra -from torchrl import torchrl_logger +from torchrl import merge_ray_runtime_env, torchrl_logger from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger from torchrl.weight_update.llm import get_model_metadata @@ -398,19 +398,9 @@ def main(cfg): if not k.startswith("_") } - # Add computed GPU configuration + # Add computed GPU configuration and merge with default runtime_env ray_init_config["num_gpus"] = device_config["ray_num_gpus"] - # Ensure runtime_env and env_vars exist - if "runtime_env" not in ray_init_config: - ray_init_config["runtime_env"] = {} - if not isinstance(ray_init_config["runtime_env"], dict): - ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"]) - if "env_vars" not in ray_init_config["runtime_env"]: - ray_init_config["runtime_env"]["env_vars"] = {} - if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict): - ray_init_config["runtime_env"]["env_vars"] = dict( - ray_init_config["runtime_env"]["env_vars"] - ) + ray_init_config = merge_ray_runtime_env(ray_init_config) torchrl_logger.info(f"Ray init config: {ray_init_config=}") ray.init(**ray_init_config) diff --git a/sota-implementations/grpo/grpo-async.py b/sota-implementations/grpo/grpo-async.py index e94d25c56fc..933b8832b18 100644 --- a/sota-implementations/grpo/grpo-async.py +++ b/sota-implementations/grpo/grpo-async.py @@ -13,7 +13,7 @@ import hydra -from torchrl import torchrl_logger +from torchrl import merge_ray_runtime_env, torchrl_logger from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger from torchrl.weight_update.llm import get_model_metadata @@ -319,19 +319,9 @@ def main(cfg): if not k.startswith("_") } - # Add computed GPU configuration + # Add computed GPU configuration and merge with default runtime_env ray_init_config["num_gpus"] = device_config["ray_num_gpus"] - # Ensure runtime_env and env_vars exist - if "runtime_env" not in ray_init_config: - ray_init_config["runtime_env"] = {} - if not isinstance(ray_init_config["runtime_env"], dict): - ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"]) - if "env_vars" not in ray_init_config["runtime_env"]: - ray_init_config["runtime_env"]["env_vars"] = {} - if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict): - ray_init_config["runtime_env"]["env_vars"] = dict( - ray_init_config["runtime_env"]["env_vars"] - ) + ray_init_config = merge_ray_runtime_env(ray_init_config) torchrl_logger.info(f"Ray init config: {ray_init_config=}") ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY") if ray_managed_externally: diff --git a/sota-implementations/grpo/grpo-sync.py 
b/sota-implementations/grpo/grpo-sync.py index 309581d6c75..48538f5e5df 100644 --- a/sota-implementations/grpo/grpo-sync.py +++ b/sota-implementations/grpo/grpo-sync.py @@ -12,7 +12,7 @@ import hydra -from torchrl import torchrl_logger +from torchrl import merge_ray_runtime_env, torchrl_logger from torchrl.data.llm.history import History from torchrl.record.loggers.wandb import WandbLogger from torchrl.weight_update.llm import get_model_metadata @@ -319,19 +319,9 @@ def main(cfg): if not k.startswith("_") } - # Add computed GPU configuration + # Add computed GPU configuration and merge with default runtime_env ray_init_config["num_gpus"] = device_config["ray_num_gpus"] - # Ensure runtime_env and env_vars exist - if "runtime_env" not in ray_init_config: - ray_init_config["runtime_env"] = {} - if not isinstance(ray_init_config["runtime_env"], dict): - ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"]) - if "env_vars" not in ray_init_config["runtime_env"]: - ray_init_config["runtime_env"]["env_vars"] = {} - if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict): - ray_init_config["runtime_env"]["env_vars"] = dict( - ray_init_config["runtime_env"]["env_vars"] - ) + ray_init_config = merge_ray_runtime_env(ray_init_config) torchrl_logger.info(f"Ray init config: {ray_init_config=}") ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY") if ray_managed_externally: diff --git a/test/llm/test_updaters.py b/test/llm/test_updaters.py index bb6b694cf3e..2ed746aca8e 100644 --- a/test/llm/test_updaters.py +++ b/test/llm/test_updaters.py @@ -41,7 +41,7 @@ def get_open_port(): if _has_ray: import ray - from .ray_helpers import ( + from ray_helpers import ( WorkerTransformerDoubleBuffer, WorkerTransformerNCCL, WorkerVLLMDoubleBuffer, diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 1594bc1e3e4..5ea95ae26a8 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -54,8 +54,10 @@ from torchrl._utils import ( auto_unwrap_transformed_env, compile_with_warmup, + get_ray_default_runtime_env, implement_for, logger, + merge_ray_runtime_env, set_auto_unwrap_transformed_env, timeit, ) @@ -113,7 +115,9 @@ def _inv(self): __all__ = [ "auto_unwrap_transformed_env", "compile_with_warmup", + "get_ray_default_runtime_env", "implement_for", + "merge_ray_runtime_env", "set_auto_unwrap_transformed_env", "timeit", "logger", diff --git a/torchrl/_utils.py b/torchrl/_utils.py index d8a06e37641..50cfa8af7d3 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -962,3 +962,103 @@ def as_remote(cls, remote_config: dict[str, Any] | None = None): remote_collector = ray.remote(**remote_config)(cls) remote_collector.is_remote = True return remote_collector + + +def get_ray_default_runtime_env() -> dict[str, Any]: + """Get the default Ray runtime environment configuration for TorchRL. + + This function returns a runtime environment configuration that excludes + large directories and files that should not be uploaded to Ray workers. + This helps prevent issues with Ray's working_dir size limits (512MB default). + + Returns: + dict: A dictionary containing the default runtime_env configuration with + excludes for common large directories. + + Examples: + >>> import ray + >>> from torchrl._utils import get_ray_default_runtime_env + >>> ray_init_config = {"num_cpus": 4} + >>> ray_init_config["runtime_env"] = get_ray_default_runtime_env() + >>> ray.init(**ray_init_config) + + Note: + The excludes list includes: + - Virtual environments (.venv/, venv/, etc.) 
+ - Test files and caches + - Documentation builds + - Benchmarks + - Examples and tutorials + - CI/CD configurations + - IDE configurations + + """ + return { + "excludes": [ + ".venv/", + "venv/", + "env/", + "ENV/", + "env.bak/", + "venv.bak/", + "test/", + "tests/", + "docs/", + "benchmarks/", + "tutorials/", + "examples/", + ".github/", + ".pytest_cache/", + ".mypy_cache/", + ".ruff_cache/", + "__pycache__/", + "*.pyc", + "*.pyo", + "*.egg-info/", + ".idea/", + ".vscode/", + "dev/", + "main/", + "*.html", + ] + } + + +def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]: + """Merge user-provided ray_init_config with default runtime_env excludes. + + This function ensures that the default TorchRL runtime_env excludes are applied + to prevent large directories from being uploaded to Ray workers, while preserving + any user-provided configuration. + + Args: + ray_init_config (dict): The ray init configuration dictionary to merge. + + Returns: + dict: The merged configuration with default runtime_env excludes applied. + + Examples: + >>> from torchrl._utils import merge_ray_runtime_env + >>> ray_init_config = {"num_cpus": 4} + >>> ray_init_config = merge_ray_runtime_env(ray_init_config) + >>> ray.init(**ray_init_config) + + """ + default_runtime_env = get_ray_default_runtime_env() + runtime_env = ray_init_config.setdefault("runtime_env", {}) + + if not isinstance(runtime_env, dict): + runtime_env = dict(runtime_env) + ray_init_config["runtime_env"] = runtime_env + + # Merge excludes lists + excludes = runtime_env.get("excludes", []) + runtime_env["excludes"] = list(set(default_runtime_env["excludes"] + excludes)) + + # Ensure env_vars exists + if "env_vars" not in runtime_env: + runtime_env["env_vars"] = {} + elif not isinstance(runtime_env["env_vars"], dict): + runtime_env["env_vars"] = dict(runtime_env["env_vars"]) + + return ray_init_config From 5d400b46dfa7c69f73f5b906bac552d807c45750 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 19 Oct 2025 20:56:28 -0700 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- test/llm/test_updaters.py | 2 +- test/test_rb.py | 19 ++++++++++++++++ torchrl/data/replay_buffers/ray_buffer.py | 1 + torchrl/testing/__init__.py | 24 ++++++++++++++++++++ {test/llm => torchrl/testing}/ray_helpers.py | 4 ++-- 5 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 torchrl/testing/__init__.py rename {test/llm => torchrl/testing}/ray_helpers.py (98%) diff --git a/test/llm/test_updaters.py b/test/llm/test_updaters.py index 2ed746aca8e..02e2efed163 100644 --- a/test/llm/test_updaters.py +++ b/test/llm/test_updaters.py @@ -41,7 +41,7 @@ def get_open_port(): if _has_ray: import ray - from ray_helpers import ( + from torchrl.testing import ( WorkerTransformerDoubleBuffer, WorkerTransformerNCCL, WorkerVLLMDoubleBuffer, diff --git a/test/test_rb.py b/test/test_rb.py index f3abf85f9a0..bc6acaeb3be 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -4188,6 +4188,25 @@ def test_ray_rb_iter(self): finally: rb.close() + def test_ray_rb_serialization(self): + import ray + + class Worker: + def __init__(self, rb): + self.rb = rb + + def run(self): + self.rb.extend(TensorDict({"x": torch.ones(100)}, batch_size=100)) + + rb = RayReplayBuffer( + storage=partial(LazyTensorStorage, 100), ray_init_config={"num_cpus": 1} + ) + try: + remote_worker = ray.remote(Worker).remote(rb) + ray.get(remote_worker.run.remote()) + finally: + rb.close() + class TestSharedStorageInit: def worker(self, rb, worker_id, queue): diff --git 
a/torchrl/data/replay_buffers/ray_buffer.py b/torchrl/data/replay_buffers/ray_buffer.py index e18b8650f8d..d5cf0474b8b 100644 --- a/torchrl/data/replay_buffers/ray_buffer.py +++ b/torchrl/data/replay_buffers/ray_buffer.py @@ -147,6 +147,7 @@ def __init__( else: self.has_gpu = False self._rb = remote_cls(*args, **kwargs) + self._delayed_init = False def close(self): """Terminates the Ray actor associated with this replay buffer.""" diff --git a/torchrl/testing/__init__.py b/torchrl/testing/__init__.py new file mode 100644 index 00000000000..0f942e4c85f --- /dev/null +++ b/torchrl/testing/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Testing utilities for TorchRL. + +This module provides helper classes and utilities for testing TorchRL functionality, +particularly for distributed and Ray-based tests that require importable classes. +""" + +from torchrl.testing.ray_helpers import ( + WorkerTransformerDoubleBuffer, + WorkerTransformerNCCL, + WorkerVLLMDoubleBuffer, + WorkerVLLMNCCL, +) + +__all__ = [ + "WorkerVLLMNCCL", + "WorkerTransformerNCCL", + "WorkerVLLMDoubleBuffer", + "WorkerTransformerDoubleBuffer", +] diff --git a/test/llm/ray_helpers.py b/torchrl/testing/ray_helpers.py similarity index 98% rename from test/llm/ray_helpers.py rename to torchrl/testing/ray_helpers.py index fc4c448606d..b9d6f84a293 100644 --- a/test/llm/ray_helpers.py +++ b/torchrl/testing/ray_helpers.py @@ -6,8 +6,8 @@ """Helper classes for Ray-based weight synchronization tests. This module contains Ray actor classes that need to be importable by Ray workers. -These classes are used in test_updaters.py but must be defined at module level -so Ray can serialize and import them on remote workers. +These classes are used in tests but must be defined at module level in a proper +Python package (not in test files) so Ray can serialize and import them on remote workers. 
""" import torch From a2c297367190dc3837d2cebbc58721f96f4844da Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 19 Oct 2025 21:16:04 -0700 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- test/llm/test_envs.py | 364 +-------------------------------- torchrl/testing/ray_helpers.py | 1 - 2 files changed, 2 insertions(+), 363 deletions(-) diff --git a/test/llm/test_envs.py b/test/llm/test_envs.py index 91876b44c11..831b1d79d5e 100644 --- a/test/llm/test_envs.py +++ b/test/llm/test_envs.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse -import contextlib import importlib.util import random import re @@ -14,27 +13,15 @@ import pytest import torch -from mocking_classes_llm import DummyStrDataLoader, DummyTensorDataLoader - -from tensordict import ( - lazy_stack, - NonTensorData, - NonTensorStack, - set_capture_non_tensor_stack, - set_list_to_stack, - TensorDict, -) + +from tensordict import lazy_stack, set_list_to_stack, TensorDict from torchrl._utils import logger as torchrl_logger from torchrl.data.llm.history import History -from torchrl.envs import StepCounter from torchrl.envs.llm import ( - as_padded_tensor, ChatEnv, - DataLoadingPrimer, GSM8KEnv, KLRewardTransform, - LLMEnv, make_gsm8k_env, RetrieveKL, ) @@ -82,353 +69,6 @@ def set_list_to_stack_for_test(): return -class TestLLMEnv: - @pytest.fixture(scope="class", autouse=True) - def set_capture(self): - with set_capture_non_tensor_stack(False): - yield None - return - - @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - # TODO: a bit experimental, fails with check_env_specs - # [False, "as_nested_tensor"], - [False, None], - ], - ) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - @pytest.mark.parametrize("device", [None, "cpu"]) - def test_llm_env( - self, from_text, stack_method, device, dl_batch_size, env_batch_size - ): - if from_text: - primer = DataLoadingPrimer( - dataloader=DummyStrDataLoader(batch_size=dl_batch_size), - batch_size=env_batch_size, - ) - else: - if stack_method is None: - stack_method = as_padded_tensor - primer = DataLoadingPrimer( - dataloader=DummyTensorDataLoader( - batch_size=dl_batch_size, padding=True - ), - stack_method=stack_method, - batch_size=env_batch_size, - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv( - from_text=from_text, - device=device, - batch_size=primer.batch_size, - ) - env = env.append_transform(primer) - if env_batch_size is None: - assert env.batch_size == torch.Size((dl_batch_size,)) - else: - if not isinstance(env_batch_size, tuple): - env_batch_size = ( - torch.Size(()) - if env_batch_size == 0 - else torch.Size((env_batch_size,)) - ) - assert env.batch_size == env_batch_size - - env.check_env_specs(break_when_any_done="both") - - @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") - @pytest.mark.parametrize("tokenizer", [True, False]) - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - [False, None], - ], - ) - @pytest.mark.parametrize("device", [None, "cpu"]) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - def test_llm_from_dataloader( - self, - from_text, - stack_method, - device, - dl_batch_size, - env_batch_size, - tokenizer, - ): - from transformers import AutoTokenizer - - if 
tokenizer and from_text: - tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - else: - tokenizer = None - if from_text: - kwargs = { - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), - } - else: - if stack_method is None: - stack_method = as_padded_tensor - kwargs = { - "dataloader": DummyTensorDataLoader( - padding=True, batch_size=dl_batch_size - ), - "stack_method": stack_method, - } - kwargs.update( - { - "batch_size": env_batch_size, - "from_text": from_text, - "device": device, - "has_attention": False, - "tokenizer": tokenizer, - } - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv.from_dataloader(**kwargs) - if env_batch_size is None: - assert env.batch_size == torch.Size((dl_batch_size,)) - else: - if not isinstance(env_batch_size, tuple): - env_batch_size = ( - torch.Size(()) - if env_batch_size == 0 - else torch.Size((env_batch_size,)) - ) - assert env.batch_size == env_batch_size - env.check_env_specs(break_when_any_done="both") - - def policy(td): - if from_text and tokenizer is None: - if not td.shape: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData( - "", device=device - ) - else: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( - *[ - NonTensorData("", device=device) - for _ in range(td.shape[0]) - ] - ) - else: - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( - td.shape + (1,), dtype=torch.int64 - ) - return td - - r = env.rollout(10, policy) - if env.batch_size == (): - assert r.ndim == 1 - r = r.unsqueeze(0) - else: - assert r.ndim == 2 - if from_text and tokenizer is None: - assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str) - assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str) - assert ( - r[0, 0][LLMEnv._DEFAULT_STR_KEY] - == r[0, 1][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ), ( - r[0, 0][LLMEnv._DEFAULT_STR_KEY], - r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY], - r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY], - r[0, 1][LLMEnv._DEFAULT_STR_KEY], - ) - assert ( - r[0, 1][LLMEnv._DEFAULT_STR_KEY] - == r[0, 2][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - assert ( - r[-1, 0][LLMEnv._DEFAULT_STR_KEY] - == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - assert ( - r[-1, 1][LLMEnv._DEFAULT_STR_KEY] - == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - elif tokenizer is None: - assert ( - r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] - == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - assert ( - r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY] - == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - assert ( - r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY] - == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - assert ( - r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY] - == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - # TODO: a bit experimental, fails with check_env_specs - # [False, "as_nested_tensor"], - [False, None], - ], - ) - @pytest.mark.parametrize("device", [None, "cpu"]) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - @pytest.mark.parametrize("repeats", [3]) - def test_llm_from_dataloader_repeats( - self, from_text, stack_method, device, env_batch_size, dl_batch_size, repeats - ): - if from_text: - kwargs = { - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), - "repeats": 
repeats, - } - else: - if stack_method is None: - stack_method = as_padded_tensor - kwargs = { - "dataloader": DummyTensorDataLoader( - padding=True, batch_size=dl_batch_size - ), - "stack_method": stack_method, - "repeats": repeats, - } - kwargs.update( - { - "batch_size": env_batch_size, - "from_text": from_text, - "device": device, - "has_attention": False, - } - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv.from_dataloader(**kwargs) - assert env.transform.repeats == repeats - - max_steps = 3 - env.append_transform(StepCounter(max_steps=max_steps)) - - def policy(td): - if from_text: - if not td.shape: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "" - else: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( - *["" for _ in range(td.shape[0])] - ) - else: - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( - td.shape + (1,), dtype=torch.int64 - ) - return td - - r = env.rollout(100, policy, break_when_any_done=False) - # check that r at reset is always the same - r_reset = r[..., ::max_steps] - if from_text: - all_strings = r_reset.view(-1)[LLMEnv._DEFAULT_STR_KEY] - assert sum(s == all_strings[0] for s in all_strings) == repeats - assert sum(s == all_strings[repeats] for s in all_strings) == repeats - assert sum(s == all_strings[repeats * 2] for s in all_strings) == repeats - else: - all_tokens = r_reset.view(-1)[LLMEnv._DEFAULT_TOKEN_KEY] - assert sum((s == all_tokens[0]).all() for s in all_tokens) == repeats - assert sum((s == all_tokens[repeats]).all() for s in all_tokens) == repeats - assert ( - sum((s == all_tokens[repeats * 2]).all() for s in all_tokens) == repeats - ) - - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - ], - ) - @pytest.mark.parametrize("device", [None]) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - @pytest.mark.parametrize("repeats", [3]) - @pytest.mark.parametrize( - "assign_reward,assign_done", [[True, False], [True, True], [False, True]] - ) - def test_done_and_reward( - self, - from_text, - stack_method, - device, - env_batch_size, - dl_batch_size, - repeats, - assign_reward, - assign_done, - ): - with pytest.raises( - ValueError, match="from_text" - ) if from_text else contextlib.nullcontext(): - if from_text: - kwargs = { - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), - "repeats": repeats, - "assign_reward": assign_reward, - "assign_done": assign_done, - } - else: - if stack_method is None: - stack_method = as_padded_tensor - kwargs = { - "dataloader": DummyTensorDataLoader( - padding=True, batch_size=dl_batch_size - ), - "stack_method": stack_method, - "repeats": repeats, - "assign_reward": assign_reward, - "assign_done": assign_done, - } - kwargs.update( - { - "batch_size": env_batch_size, - "from_text": from_text, - "device": device, - "has_attention": False, - } - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv.from_dataloader(**kwargs) - # We want to make sure that transforms that rely on the done state work appropriately - env.append_transform(StepCounter(max_steps=10)) - - def policy(td): - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( - td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64 - ) - return td - - r = env.rollout(100, policy, break_when_any_done=False) - if assign_done: - assert "terminated" in r - assert "done" in r - - class TestChatEnv: @pytest.fixture def tokenizer(self): diff --git a/torchrl/testing/ray_helpers.py 
b/torchrl/testing/ray_helpers.py index b9d6f84a293..758f5d858a6 100644 --- a/torchrl/testing/ray_helpers.py +++ b/torchrl/testing/ray_helpers.py @@ -159,7 +159,6 @@ def update_weights(self, modify_weights: bool = False): Returns: str: "updated" status message """ - # Optionally modify weights for testing if modify_weights: with torch.no_grad(): From ba0f22184e84debd0a3856ea8143647df0c8d99c Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 20 Oct 2025 21:18:08 -0700 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- torchrl/envs/transforms/ray_service.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/envs/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py index 5eca9d1e66d..8a40fe20fa7 100644 --- a/torchrl/envs/transforms/ray_service.py +++ b/torchrl/envs/transforms/ray_service.py @@ -169,6 +169,10 @@ def _ray(self): ) return ray + @_ray.setter + def _ray(self, value): + self._ray = value + def __init__( self, *, From 675f0d96a6ec0693bee3aa2eeb080b267eb749f1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 20 Oct 2025 21:21:09 -0700 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- torchrl/envs/transforms/ray_service.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py index 8a40fe20fa7..0da863863fa 100644 --- a/torchrl/envs/transforms/ray_service.py +++ b/torchrl/envs/transforms/ray_service.py @@ -160,6 +160,9 @@ def _create_actor(self, **kwargs): @property def _ray(self): + ray = self.__dict__.get("_ray_val", None) + if ray is not None: + return ray # Import ray here to avoid requiring it as a dependency try: import ray @@ -167,11 +170,17 @@ def _ray(self): raise ImportError( "Ray is required for RayTransform. Install with: pip install ray" ) + self.__dict__["_ray_val"] = ray return ray @_ray.setter def _ray(self, value): - self._ray = value + self.__dict__["_ray_val"] = value + + def __getstate__(self): + state = super().__getstate__() + state.pop("_ray_val", None) + return state def __init__( self, From 58f68072e253ede8429280b2ff364c0f26013984 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 21 Oct 2025 20:48:33 -0700 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- test/llm/test_data.py | 2 +- test/llm/test_envs.py | 8 ++++---- torchrl/envs/llm/transforms/tools.py | 5 ++++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/llm/test_data.py b/test/llm/test_data.py index caf654ec91d..53b0438a844 100644 --- a/test/llm/test_data.py +++ b/test/llm/test_data.py @@ -13,7 +13,7 @@ import torch from tensordict import lazy_stack, set_list_to_stack, TensorDict -from torchrl import torchrl_logger +from torchrl import logger as torchrl_logger from torchrl.data import ( History, diff --git a/test/llm/test_envs.py b/test/llm/test_envs.py index 831b1d79d5e..cc06aa33d56 100644 --- a/test/llm/test_envs.py +++ b/test/llm/test_envs.py @@ -375,7 +375,7 @@ def test_python_interpreter_single_batch(self): "```python\n" "print(1 + 1)\n" "```<|im_end|>\n" - " <|im_start|>user\n" + " <|im_start|>tool\n" "\n" "Code block 1 executed successfully:\n" "2\n" @@ -395,7 +395,7 @@ def test_python_interpreter_single_batch(self): content="Here is a python code to execute:\n```python\nprint(1 + 1)\n```", ), History( - role="user", + role="tool", content="\nCode block 1 executed successfully:\n2\n\n", tool_responses=["Code block 1 executed successfully:\n2\n"], ), @@ -478,7 +478,7 @@ def test_python_interpreter_persistent(self): "```python\n" "a=1\n" "```<|im_end|>\n" 
- " <|im_start|>user\n" + " <|im_start|>tool\n" "\n" "Code block 1 executed successfully:\n" "\n" @@ -489,7 +489,7 @@ def test_python_interpreter_persistent(self): "a+=1\n" "assert a == 2\n" "```<|im_end|>\n" - " <|im_start|>user\n" + " <|im_start|>tool\n" "\n" "Code block 1 executed successfully:\n" "\n" diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 22fccb27cba..f0940b1b1aa 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -530,7 +530,10 @@ def _step( procs = [] # Iterate over env batch-size - for i, t in enumerate(local_history.content): + content = local_history.content + if isinstance(content, str): + content = [content] + for i, t in enumerate(content): results = self._process_llm_response(t, i) if len(results) == 0: procs.append(None)