Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,27 @@ def maybe_fork_ParallelEnv(request):
):
return functools.partial(ParallelEnv, mp_start_method="fork")
return ParallelEnv


# LLM testing fixtures
@pytest.fixture
def mock_transformer_model():
"""Fixture that provides a mock transformer model factory."""
from torchrl.testing import MockTransformerModel

def _make_model(
vocab_size: int = 1024, device: torch.device | str | int = "cpu"
) -> MockTransformerModel:
"""Make a mock transformer model."""
device = torch.device(device)
return MockTransformerModel(vocab_size, device)

return _make_model


@pytest.fixture
def mock_tokenizer():
"""Fixture that provides a mock tokenizer."""
from transformers import AutoTokenizer

return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
2 changes: 2 additions & 0 deletions test/llm/libs/test_mlgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from torchrl.envs.llm import make_mlgym
from torchrl.modules.llm import TransformersWrapper

pytest.importorskip("mlgym")


class TestMLGYM:
def test_mlgym_specs(self):
Expand Down
270 changes: 270 additions & 0 deletions test/services/test_python_executor_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
"""Tests for PythonExecutorService with Ray service registry."""
from __future__ import annotations

import pytest

# Skip all tests if Ray is not available
pytest.importorskip("ray")

import ray
from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter
from torchrl.services import get_services


@pytest.fixture
def ray_init():
"""Initialize Ray for tests."""
if not ray.is_initialized():
ray.init()
yield
if ray.is_initialized():
ray.shutdown()


class TestPythonExecutorService:
"""Test suite for PythonExecutorService."""

def test_service_initialization(self, ray_init):
"""Test that the service can be created and registered."""
namespace = "test_executor_init"
services = get_services(backend="ray", namespace=namespace)

try:
services.register(
"python_executor",
PythonExecutorService,
pool_size=2,
timeout=5.0,
num_cpus=2,
max_concurrency=2,
)

# Verify it was registered
assert "python_executor" in services

# Get the service
executor = services["python_executor"]
assert executor is not None

finally:
services.reset()

def test_service_execution(self, ray_init):
"""Test that the service can execute Python code."""
namespace = "test_executor_exec"
services = get_services(backend="ray", namespace=namespace)

try:
services.register(
"python_executor",
PythonExecutorService,
pool_size=2,
timeout=5.0,
num_cpus=2,
max_concurrency=2,
)

executor = services["python_executor"]

# Execute simple code
code = """
x = 10
y = 20
result = x + y
print(f"Result: {result}")
"""
result = ray.get(executor.execute.remote(code), timeout=2)

assert result["success"] is True
assert "Result: 30" in result["stdout"]
assert result["returncode"] == 0

finally:
services.reset()

def test_service_execution_error(self, ray_init):
"""Test that the service handles execution errors."""
namespace = "test_executor_error"
services = get_services(backend="ray", namespace=namespace)

try:
services.register(
"python_executor",
PythonExecutorService,
pool_size=2,
timeout=5.0,
num_cpus=2,
max_concurrency=2,
)

executor = services["python_executor"]

# Execute code with an error
code = "raise ValueError('Test error')"
result = ray.get(executor.execute.remote(code), timeout=2)

assert result["success"] is False
assert "ValueError: Test error" in result["stderr"]

finally:
services.reset()

def test_multiple_executions(self, ray_init):
"""Test multiple concurrent executions."""
namespace = "test_executor_multi"
services = get_services(backend="ray", namespace=namespace)

try:
services.register(
"python_executor",
PythonExecutorService,
pool_size=4,
timeout=5.0,
num_cpus=4,
max_concurrency=4,
)

executor = services["python_executor"]

# Submit multiple executions
futures = []
for i in range(8):
code = f"print('Execution {i}')"
futures.append(executor.execute.remote(code))

# Wait for all to complete
results = ray.get(futures, timeout=5)

# All should succeed
assert len(results) == 8
for i, result in enumerate(results):
assert result["success"] is True
assert f"Execution {i}" in result["stdout"]

finally:
services.reset()


class TestPythonInterpreterWithService:
"""Test suite for PythonInterpreter using the service."""

def test_interpreter_with_service(self, ray_init):
"""Test that PythonInterpreter can use the service."""
namespace = "test_interp_service"
services = get_services(backend="ray", namespace=namespace)

try:
# Register service
services.register(
"python_executor",
PythonExecutorService,
pool_size=2,
timeout=5.0,
num_cpus=2,
max_concurrency=2,
)

# Create interpreter with service
interpreter = PythonInterpreter(
services="ray",
service_name="python_executor",
namespace=namespace,
)

# Verify it's using the service
assert interpreter.python_service is not None
assert interpreter.processes is None
assert interpreter.services == "ray"

finally:
services.reset()

def test_interpreter_without_service(self):
"""Test that PythonInterpreter works without service."""
# Create interpreter without service
interpreter = PythonInterpreter(
services=None,
persistent=True,
)

# Verify it's using local processes
assert interpreter.python_service is None
assert interpreter.processes is not None
assert interpreter.services is None

def test_interpreter_execution_with_service(self, ray_init):
"""Test code execution through interpreter with service."""
namespace = "test_interp_exec"
services = get_services(backend="ray", namespace=namespace)

try:
# Register service
services.register(
"python_executor",
PythonExecutorService,
pool_size=2,
timeout=5.0,
num_cpus=2,
max_concurrency=2,
)

# Create interpreter with service
interpreter = PythonInterpreter(services="ray", namespace=namespace)

# Execute code
code = "print('Hello from service')"
result = interpreter._execute_python_code(code, 0)

assert result["success"] is True
assert "Hello from service" in result["stdout"]

finally:
services.reset()

def test_interpreter_clone_preserves_service(self, ray_init):
"""Test that cloning an interpreter preserves service settings."""
namespace = "test_interp_clone"
services = get_services(backend="ray", namespace=namespace)

try:
# Register service
services.register(
"python_executor",
PythonExecutorService,
pool_size=2,
timeout=5.0,
num_cpus=2,
max_concurrency=2,
)

# Create interpreter with service
interpreter1 = PythonInterpreter(
services="ray",
service_name="python_executor",
namespace=namespace,
)

# Clone it
interpreter2 = interpreter1.clone()

# Verify clone has same settings
assert interpreter2.services == "ray"
assert interpreter2.service_name == "python_executor"
assert interpreter2.python_service is not None

finally:
services.reset()

def test_interpreter_invalid_service_backend(self):
"""Test that invalid service backend raises error."""
with pytest.raises(ValueError, match="Invalid services backend"):
PythonInterpreter(services="invalid")

def test_interpreter_missing_service(self, ray_init):
"""Test that missing service raises error."""
with pytest.raises(RuntimeError, match="Failed to get Ray service"):
PythonInterpreter(services="ray", service_name="nonexistent_service")


if __name__ == "__main__":
pytest.main([__file__, "-v"])
2 changes: 2 additions & 0 deletions test/test_services.py → test/services/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import pytest

pytest.importorskip("ray")
Expand Down
48 changes: 48 additions & 0 deletions test/services/test_services_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Test fixtures for service tests that need to be importable by Ray workers."""

from __future__ import annotations

from typing import Any


class SimpleService:
"""A simple service for testing."""

def __init__(self, value: int = 0):
self.value = value

def get_value(self):
return self.value

def set_value(self, value: int):
self.value = value

def getattr(self, val: str, **kwargs) -> Any:
if "default" in kwargs:
default = kwargs["default"]
return getattr(self, val, default)
return getattr(self, val)


class TokenizerService:
"""Mock tokenizer service."""

def __init__(self, vocab_size: int = 1000):
self.vocab_size = vocab_size

def encode(self, text: str):
return [hash(c) % self.vocab_size for c in text]

def decode(self, tokens: list):
return "".join([str(t) for t in tokens])

def getattr(self, val: str, **kwargs) -> Any:
if "default" in kwargs:
default = kwargs["default"]
return getattr(self, val, default)
return getattr(self, val)
1 change: 1 addition & 0 deletions torchrl/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
>>> tokenizer = services["tokenizer"]
>>> result = tokenizer.encode.remote(text)
"""
from __future__ import annotations

from torchrl.services.base import ServiceBase
from torchrl.services.ray_service import RayService
Expand Down
1 change: 1 addition & 0 deletions torchrl/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any
Expand Down
1 change: 1 addition & 0 deletions torchrl/services/ray_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Any

Expand Down
Loading
Loading