From 14eeb4a67642669f8645ef6bd6cbca4841c141db Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Nov 2025 16:58:35 +0000 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/conftest.py | 24 ++ test/llm/libs/test_mlgym.py | 2 + test/services/test_python_executor_service.py | 269 ++++++++++++++++++ test/{ => services}/test_services.py | 0 test/services/test_services_fixtures.py | 46 +++ torchrl/testing/__init__.py | 8 + .../testing/llm_mocks.py | 23 -- 7 files changed, 349 insertions(+), 23 deletions(-) create mode 100644 test/services/test_python_executor_service.py rename test/{ => services}/test_services.py (100%) create mode 100644 test/services/test_services_fixtures.py rename test/llm/conftest.py => torchrl/testing/llm_mocks.py (76%) diff --git a/test/conftest.py b/test/conftest.py index b2adfc3d984..62f8010e74d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -155,3 +155,27 @@ def maybe_fork_ParallelEnv(request): ): return functools.partial(ParallelEnv, mp_start_method="fork") return ParallelEnv + + +# LLM testing fixtures +@pytest.fixture +def mock_transformer_model(): + """Fixture that provides a mock transformer model factory.""" + from torchrl.testing import MockTransformerModel + + def _make_model( + vocab_size: int = 1024, device: torch.device | str | int = "cpu" + ) -> MockTransformerModel: + """Make a mock transformer model.""" + device = torch.device(device) + return MockTransformerModel(vocab_size, device) + + return _make_model + + +@pytest.fixture +def mock_tokenizer(): + """Fixture that provides a mock tokenizer.""" + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") diff --git a/test/llm/libs/test_mlgym.py b/test/llm/libs/test_mlgym.py index 2632c439fdd..a86732e622f 100644 --- a/test/llm/libs/test_mlgym.py +++ b/test/llm/libs/test_mlgym.py @@ -16,6 +16,8 @@ from torchrl.envs.llm import make_mlgym from torchrl.modules.llm import TransformersWrapper +pytest.importorskip("mlgym") + class TestMLGYM: def test_mlgym_specs(self): diff --git a/test/services/test_python_executor_service.py b/test/services/test_python_executor_service.py new file mode 100644 index 00000000000..43e85d3f0aa --- /dev/null +++ b/test/services/test_python_executor_service.py @@ -0,0 +1,269 @@ +"""Tests for PythonExecutorService with Ray service registry.""" + +import pytest + +# Skip all tests if Ray is not available +pytest.importorskip("ray") + +import ray +from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter +from torchrl.services import get_services + + +@pytest.fixture +def ray_init(): + """Initialize Ray for tests.""" + if not ray.is_initialized(): + ray.init() + yield + if ray.is_initialized(): + ray.shutdown() + + +class TestPythonExecutorService: + """Test suite for PythonExecutorService.""" + + def test_service_initialization(self, ray_init): + """Test that the service can be created and registered.""" + namespace = "test_executor_init" + services = get_services(backend="ray", namespace=namespace) + + try: + services.register( + "python_executor", + PythonExecutorService, + pool_size=2, + timeout=5.0, + num_cpus=2, + max_concurrency=2, + ) + + # Verify it was registered + assert "python_executor" in services + + # Get the service + executor = services["python_executor"] + assert executor is not None + + finally: + services.reset() + + def test_service_execution(self, ray_init): + """Test that the service can execute Python code.""" + namespace = "test_executor_exec" + services = get_services(backend="ray", namespace=namespace) + + try: + services.register( + "python_executor", + PythonExecutorService, + pool_size=2, + timeout=5.0, + num_cpus=2, + max_concurrency=2, + ) + + executor = services["python_executor"] + + # Execute simple code + code = """ +x = 10 +y = 20 +result = x + y +print(f"Result: {result}") +""" + result = ray.get(executor.execute.remote(code), timeout=2) + + assert result["success"] is True + assert "Result: 30" in result["stdout"] + assert result["returncode"] == 0 + + finally: + services.reset() + + def test_service_execution_error(self, ray_init): + """Test that the service handles execution errors.""" + namespace = "test_executor_error" + services = get_services(backend="ray", namespace=namespace) + + try: + services.register( + "python_executor", + PythonExecutorService, + pool_size=2, + timeout=5.0, + num_cpus=2, + max_concurrency=2, + ) + + executor = services["python_executor"] + + # Execute code with an error + code = "raise ValueError('Test error')" + result = ray.get(executor.execute.remote(code), timeout=2) + + assert result["success"] is False + assert "ValueError: Test error" in result["stderr"] + + finally: + services.reset() + + def test_multiple_executions(self, ray_init): + """Test multiple concurrent executions.""" + namespace = "test_executor_multi" + services = get_services(backend="ray", namespace=namespace) + + try: + services.register( + "python_executor", + PythonExecutorService, + pool_size=4, + timeout=5.0, + num_cpus=4, + max_concurrency=4, + ) + + executor = services["python_executor"] + + # Submit multiple executions + futures = [] + for i in range(8): + code = f"print('Execution {i}')" + futures.append(executor.execute.remote(code)) + + # Wait for all to complete + results = ray.get(futures, timeout=5) + + # All should succeed + assert len(results) == 8 + for i, result in enumerate(results): + assert result["success"] is True + assert f"Execution {i}" in result["stdout"] + + finally: + services.reset() + + +class TestPythonInterpreterWithService: + """Test suite for PythonInterpreter using the service.""" + + def test_interpreter_with_service(self, ray_init): + """Test that PythonInterpreter can use the service.""" + namespace = "test_interp_service" + services = get_services(backend="ray", namespace=namespace) + + try: + # Register service + services.register( + "python_executor", + PythonExecutorService, + pool_size=2, + timeout=5.0, + num_cpus=2, + max_concurrency=2, + ) + + # Create interpreter with service + interpreter = PythonInterpreter( + services="ray", + service_name="python_executor", + namespace=namespace, + ) + + # Verify it's using the service + assert interpreter.python_service is not None + assert interpreter.processes is None + assert interpreter.services == "ray" + + finally: + services.reset() + + def test_interpreter_without_service(self): + """Test that PythonInterpreter works without service.""" + # Create interpreter without service + interpreter = PythonInterpreter( + services=None, + persistent=True, + ) + + # Verify it's using local processes + assert interpreter.python_service is None + assert interpreter.processes is not None + assert interpreter.services is None + + def test_interpreter_execution_with_service(self, ray_init): + """Test code execution through interpreter with service.""" + namespace = "test_interp_exec" + services = get_services(backend="ray", namespace=namespace) + + try: + # Register service + services.register( + "python_executor", + PythonExecutorService, + pool_size=2, + timeout=5.0, + num_cpus=2, + max_concurrency=2, + ) + + # Create interpreter with service + interpreter = PythonInterpreter(services="ray", namespace=namespace) + + # Execute code + code = "print('Hello from service')" + result = interpreter._execute_python_code(code, 0) + + assert result["success"] is True + assert "Hello from service" in result["stdout"] + + finally: + services.reset() + + def test_interpreter_clone_preserves_service(self, ray_init): + """Test that cloning an interpreter preserves service settings.""" + namespace = "test_interp_clone" + services = get_services(backend="ray", namespace=namespace) + + try: + # Register service + services.register( + "python_executor", + PythonExecutorService, + pool_size=2, + timeout=5.0, + num_cpus=2, + max_concurrency=2, + ) + + # Create interpreter with service + interpreter1 = PythonInterpreter( + services="ray", + service_name="python_executor", + namespace=namespace, + ) + + # Clone it + interpreter2 = interpreter1.clone() + + # Verify clone has same settings + assert interpreter2.services == "ray" + assert interpreter2.service_name == "python_executor" + assert interpreter2.python_service is not None + + finally: + services.reset() + + def test_interpreter_invalid_service_backend(self): + """Test that invalid service backend raises error.""" + with pytest.raises(ValueError, match="Invalid services backend"): + PythonInterpreter(services="invalid") + + def test_interpreter_missing_service(self, ray_init): + """Test that missing service raises error.""" + with pytest.raises(RuntimeError, match="Failed to get Ray service"): + PythonInterpreter(services="ray", service_name="nonexistent_service") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_services.py b/test/services/test_services.py similarity index 100% rename from test/test_services.py rename to test/services/test_services.py diff --git a/test/services/test_services_fixtures.py b/test/services/test_services_fixtures.py new file mode 100644 index 00000000000..e703ff1715f --- /dev/null +++ b/test/services/test_services_fixtures.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Test fixtures for service tests that need to be importable by Ray workers.""" + +from typing import Any + + +class SimpleService: + """A simple service for testing.""" + + def __init__(self, value: int = 0): + self.value = value + + def get_value(self): + return self.value + + def set_value(self, value: int): + self.value = value + + def getattr(self, val: str, **kwargs) -> Any: + if "default" in kwargs: + default = kwargs["default"] + return getattr(self, val, default) + return getattr(self, val) + + +class TokenizerService: + """Mock tokenizer service.""" + + def __init__(self, vocab_size: int = 1000): + self.vocab_size = vocab_size + + def encode(self, text: str): + return [hash(c) % self.vocab_size for c in text] + + def decode(self, tokens: list): + return "".join([str(t) for t in tokens]) + + def getattr(self, val: str, **kwargs) -> Any: + if "default" in kwargs: + default = kwargs["default"] + return getattr(self, val, default) + return getattr(self, val) diff --git a/torchrl/testing/__init__.py b/torchrl/testing/__init__.py index 0f942e4c85f..6a7be0f118e 100644 --- a/torchrl/testing/__init__.py +++ b/torchrl/testing/__init__.py @@ -9,6 +9,11 @@ particularly for distributed and Ray-based tests that require importable classes. """ +from torchrl.testing.llm_mocks import ( + MockTransformerConfig, + MockTransformerModel, + MockTransformerOutput, +) from torchrl.testing.ray_helpers import ( WorkerTransformerDoubleBuffer, WorkerTransformerNCCL, @@ -21,4 +26,7 @@ "WorkerTransformerNCCL", "WorkerVLLMDoubleBuffer", "WorkerTransformerDoubleBuffer", + "MockTransformerConfig", + "MockTransformerModel", + "MockTransformerOutput", ] diff --git a/test/llm/conftest.py b/torchrl/testing/llm_mocks.py similarity index 76% rename from test/llm/conftest.py rename to torchrl/testing/llm_mocks.py index a69bdce6700..96889384a3a 100644 --- a/test/llm/conftest.py +++ b/torchrl/testing/llm_mocks.py @@ -5,7 +5,6 @@ """Shared test fixtures and mock infrastructure for LLM tests.""" from __future__ import annotations -import pytest import torch @@ -53,25 +52,3 @@ def get_tokenizer(self): from transformers import AutoTokenizer return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") - - -@pytest.fixture -def mock_transformer_model(): - """Fixture that provides a mock transformer model factory.""" - - def _make_model( - vocab_size: int = 1024, device: torch.device | str | int = "cpu" - ) -> MockTransformerModel: - """Make a mock transformer model.""" - device = torch.device(device) - return MockTransformerModel(vocab_size, device) - - return _make_model - - -@pytest.fixture -def mock_tokenizer(): - """Fixture that provides a mock tokenizer.""" - from transformers import AutoTokenizer - - return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") From 7fd8d0efbb8ba28af9a051103ddd07f2ebf9fa6c Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Nov 2025 17:29:00 +0000 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- test/services/test_python_executor_service.py | 1 + test/services/test_services.py | 2 ++ test/services/test_services_fixtures.py | 2 ++ 3 files changed, 5 insertions(+) diff --git a/test/services/test_python_executor_service.py b/test/services/test_python_executor_service.py index 43e85d3f0aa..cb55c0a6a10 100644 --- a/test/services/test_python_executor_service.py +++ b/test/services/test_python_executor_service.py @@ -1,4 +1,5 @@ """Tests for PythonExecutorService with Ray service registry.""" +from __future__ import annotations import pytest diff --git a/test/services/test_services.py b/test/services/test_services.py index a62ce9a9052..60e25547c62 100644 --- a/test/services/test_services.py +++ b/test/services/test_services.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import pytest pytest.importorskip("ray") diff --git a/test/services/test_services_fixtures.py b/test/services/test_services_fixtures.py index e703ff1715f..39420412716 100644 --- a/test/services/test_services_fixtures.py +++ b/test/services/test_services_fixtures.py @@ -5,6 +5,8 @@ """Test fixtures for service tests that need to be importable by Ray workers.""" +from __future__ import annotations + from typing import Any From 7522579ff3384d2928b6930ef30805e05fa4bd35 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 6 Nov 2025 18:05:25 +0000 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- torchrl/services/__init__.py | 1 + torchrl/services/base.py | 1 + torchrl/services/ray_service.py | 1 + 3 files changed, 3 insertions(+) diff --git a/torchrl/services/__init__.py b/torchrl/services/__init__.py index 0a78077e06e..2a93d20b557 100644 --- a/torchrl/services/__init__.py +++ b/torchrl/services/__init__.py @@ -20,6 +20,7 @@ >>> tokenizer = services["tokenizer"] >>> result = tokenizer.encode.remote(text) """ +from __future__ import annotations from torchrl.services.base import ServiceBase from torchrl.services.ray_service import RayService diff --git a/torchrl/services/base.py b/torchrl/services/base.py index 0c7ceefb198..40d2f72ad68 100644 --- a/torchrl/services/base.py +++ b/torchrl/services/base.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from abc import ABC, abstractmethod from typing import Any diff --git a/torchrl/services/ray_service.py b/torchrl/services/ray_service.py index 3fdd0f2b445..89b8bbd750f 100644 --- a/torchrl/services/ray_service.py +++ b/torchrl/services/ray_service.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations from typing import Any