From 7f168a74b7a9df141f35f0ac07287af5ee4cb8c2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 23 Oct 2025 10:46:57 -0700 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/llm/test_transforms.py | 441 ++++++++++++++++++++ torchrl/envs/llm/transforms/__init__.py | 27 +- torchrl/envs/llm/transforms/tools.py | 509 +++++++++++++++++++++++- 3 files changed, 970 insertions(+), 7 deletions(-) create mode 100644 test/llm/test_transforms.py diff --git a/test/llm/test_transforms.py b/test/llm/test_transforms.py new file mode 100644 index 00000000000..ab33315f59f --- /dev/null +++ b/test/llm/test_transforms.py @@ -0,0 +1,441 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import json + +import pytest +from tensordict import set_list_to_stack, TensorDict + +from torchrl.data.llm import History +from torchrl.envs.llm import ChatEnv +from torchrl.envs.llm.transforms import ( + ExecuteToolsInOrder, + JSONCallParser, + ToolCall, + ToolRegistry, + XMLBlockParser, +) +from torchrl.envs.transforms import TransformedEnv + + +@pytest.fixture(scope="module", autouse=True) +def list_to_stack_fixture(): + with set_list_to_stack(True): + yield + + +class TestToolRegistry: + """Test the ToolRegistry class.""" + + def test_register_and_get(self): + """Test basic registration and retrieval.""" + + class DummyService: + name = "dummy" + schema_in = {"x": int} + schema_out = {"y": int} + + def __call__(self, x, **kwargs): + return {"y": x * 2} + + service = DummyService() + registry = ToolRegistry([service]) + + retrieved = registry.get("dummy") + assert retrieved.name == "dummy" + result = retrieved(x=5) + assert result["y"] == 10 + + def test_register_after_init(self): + """Test registering services after initialization.""" + + class ServiceA: + name = "a" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + return {"result": "a"} + + class ServiceB: + name = "b" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + return {"result": "b"} + + registry = ToolRegistry([ServiceA()]) + assert "a" in registry + assert "b" not in registry + + registry.register(ServiceB()) + assert "b" in registry + + def test_unknown_tool_raises_error(self): + """Test that requesting an unknown tool raises KeyError.""" + registry = ToolRegistry() + with pytest.raises(KeyError, match="Unknown tool: nonexistent"): + registry.get("nonexistent") + + +class TestXMLBlockParser: + """Test the XMLBlockParser.""" + + def test_parse_single_tool(self): + """Test parsing a single tool call.""" + parser = XMLBlockParser() + response = '{"query": "test"}' + result = parser(response) + + assert result["text"] == "" + assert len(result["calls"]) == 1 + assert result["calls"][0].tool == "search" + assert result["calls"][0].args == {"query": "test"} + assert result["calls"][0].tag is None + + def test_parse_multiple_tools(self): + """Test parsing multiple tool calls.""" + parser = XMLBlockParser() + response = """Some text +{"query": "first"} +More text +{"expr": "1+1"} +Final text""" + result = parser(response) + + assert "Some text" in result["text"] + assert "More text" in result["text"] + assert "Final text" in result["text"] + assert len(result["calls"]) == 2 + assert result["calls"][0].tool == "search" + assert result["calls"][0].tag == "A" + assert result["calls"][1].tool == "calculate" + + def test_parse_with_tag(self): + """Test 
parsing with optional tag attribute.""" + parser = XMLBlockParser() + response = '{"a": 1}' + result = parser(response) + + assert result["calls"][0].tag == "my_tag" + + def test_parse_invalid_json(self): + """Test handling of invalid JSON in tool body.""" + parser = XMLBlockParser() + response = '{invalid json}' + result = parser(response) + + # Should fall back to raw body + assert len(result["calls"]) == 1 + assert "raw" in result["calls"][0].args + + def test_parse_empty_body(self): + """Test parsing with empty tool body.""" + parser = XMLBlockParser() + response = '' + result = parser(response) + + assert len(result["calls"]) == 1 + assert result["calls"][0].args == {} + + +class TestJSONCallParser: + """Test the JSONCallParser.""" + + def test_parse_dict_response(self): + """Test parsing a dictionary response.""" + parser = JSONCallParser() + response = { + "message": "Here's the result", + "tools": [{"tool": "search", "args": {"query": "test"}}], + } + result = parser(response) + + assert result["text"] == "Here's the result" + assert len(result["calls"]) == 1 + assert result["calls"][0].tool == "search" + + def test_parse_json_string(self): + """Test parsing a JSON string.""" + parser = JSONCallParser() + response = json.dumps( + { + "message": "Processing", + "tools": [ + {"tool": "calc", "args": {"expr": "1+1"}}, + {"tool": "search", "args": {"q": "test"}, "tag": "T1"}, + ], + } + ) + result = parser(response) + + assert len(result["calls"]) == 2 + assert result["calls"][1].tag == "T1" + + def test_parse_invalid_json_string(self): + """Test handling of invalid JSON string.""" + parser = JSONCallParser() + response = "Not valid JSON" + result = parser(response) + + assert result["text"] == "Not valid JSON" + assert len(result["calls"]) == 0 + + def test_parse_no_tools(self): + """Test parsing response with no tools.""" + parser = JSONCallParser() + response = {"message": "Just a message"} + result = parser(response) + + assert result["text"] == "Just a message" + assert len(result["calls"]) == 0 + + +class TestExecuteToolsInOrder: + """Test the ExecuteToolsInOrder transform.""" + + def test_basic_tool_execution(self): + """Test basic tool execution with XML parser.""" + + class AddService: + name = "add" + schema_in = {"a": int, "b": int} + schema_out = {"result": int} + + def __call__(self, a, b, **kwargs): + return {"result": a + b} + + registry = ToolRegistry([AddService()]) + parser = XMLBlockParser() + + env = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv(env, ExecuteToolsInOrder(registry=registry, parser=parser)) + + # Reset + reset_data = TensorDict({"query": "Calculate something"}, batch_size=(1,)) + obs = env.reset(reset_data) + + # Simulate LLM response with tool call + llm_response = '{"a": 3, "b": 5}' + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=llm_response).view(1, 1), dim=-1 + ) + + # Step + next_obs = env.step(obs) + + # Check that tool result is in history + final_history = next_obs[("next", "history")].prompt + assert len(final_history[0]) > len(obs["history"].prompt[0]) + + # Find the tool result message + tool_result_found = False + for msg in final_history[0]: + if msg.role == "tool" and "result" in msg.content: + tool_result_found = True + assert "8" in msg.content or "result" in msg.content + break + assert tool_result_found + + def test_multiple_tools_in_order(self): + """Test that multiple tools execute in order of appearance.""" + execution_order = [] + + class FirstService: + name = 
"first" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + execution_order.append("first") + return {"executed": "first"} + + class SecondService: + name = "second" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + execution_order.append("second") + return {"executed": "second"} + + class ThirdService: + name = "third" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + execution_order.append("third") + return {"executed": "third"} + + registry = ToolRegistry([FirstService(), SecondService(), ThirdService()]) + parser = XMLBlockParser() + + env = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv(env, ExecuteToolsInOrder(registry=registry, parser=parser)) + + reset_data = TensorDict({"query": "Test"}, batch_size=(1,)) + obs = env.reset(reset_data) + + # Tools in specific order + llm_response = """ +{} +{} +{} +""" + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=llm_response).view(1, 1), dim=-1 + ) + + env.step(obs) + + # Check execution order matches appearance order + assert execution_order == ["second", "first", "third"] + + def test_error_handling_continue(self): + """Test error handling with stop_on_error=False.""" + + class WorkingService: + name = "working" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + return {"status": "ok"} + + class FailingService: + name = "failing" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + raise ValueError("Intentional failure") + + registry = ToolRegistry([WorkingService(), FailingService()]) + parser = XMLBlockParser() + + env = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv( + env, + ExecuteToolsInOrder(registry=registry, parser=parser, stop_on_error=False), + ) + + reset_data = TensorDict({"query": "Test"}, batch_size=(1,)) + obs = env.reset(reset_data) + + llm_response = """ +{} +{} +{} +""" + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=llm_response).view(1, 1), dim=-1 + ) + + # Should not raise, continues execution + next_obs = env.step(obs) + + # Check that result contains both successes and failure + final_history = next_obs[("next", "history")].prompt + tool_msg = None + for msg in final_history[0]: + if msg.role == "tool": + tool_msg = msg.content + break + + assert tool_msg is not None + assert "succeeded" in tool_msg + assert "failed" in tool_msg or "error" in tool_msg.lower() + + def test_state_passing_to_tools(self): + """Test that state is passed to tools when enabled.""" + received_state = {} + + class StateCheckService: + name = "check" + schema_in = {} + schema_out = {} + + def __call__(self, _state=None, **kwargs): + received_state["state"] = _state + return {"received_state": _state is not None} + + registry = ToolRegistry([StateCheckService()]) + parser = XMLBlockParser() + + env = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv( + env, + ExecuteToolsInOrder( + registry=registry, parser=parser, pass_state_to_tools=True + ), + ) + + reset_data = TensorDict({"query": "Check state"}, batch_size=(1,)) + obs = env.reset(reset_data) + + llm_response = '{}' + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=llm_response).view(1, 1), dim=-1 + ) + + env.step(obs) + + # Check that state was received + assert received_state["state"] is not None + assert isinstance(received_state["state"], dict) + + def test_no_tools_in_response(self): + """Test behavior 
when response contains no tool calls.""" + + class DummyService: + name = "dummy" + schema_in = {} + schema_out = {} + + def __call__(self, **kwargs): + return {} + + registry = ToolRegistry([DummyService()]) + parser = XMLBlockParser() + + env = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv(env, ExecuteToolsInOrder(registry=registry, parser=parser)) + + reset_data = TensorDict({"query": "Just chat"}, batch_size=(1,)) + obs = env.reset(reset_data) + + # No tool calls in response + llm_response = "Just a normal response without tools." + obs["history"].full = obs["history"].prompt.extend( + History(role="assistant", content=llm_response).view(1, 1), dim=-1 + ) + + # Should work fine without errors + next_obs = env.step(obs) + + # History should not contain tool messages + final_history = next_obs[("next", "history")].prompt + for msg in final_history[0]: + assert msg.role != "tool" + + +class TestToolCall: + """Test the ToolCall dataclass.""" + + def test_creation(self): + """Test creating ToolCall instances.""" + call = ToolCall(tool="test", args={"x": 1}) + assert call.tool == "test" + assert call.args == {"x": 1} + assert call.tag is None + + def test_with_tag(self): + """Test ToolCall with optional tag.""" + call = ToolCall(tool="test", args={}, tag="my_tag") + assert call.tag == "my_tag" diff --git a/torchrl/envs/llm/transforms/__init__.py b/torchrl/envs/llm/transforms/__init__.py index 0a90f86daeb..5ab6b0a8378 100644 --- a/torchrl/envs/llm/transforms/__init__.py +++ b/torchrl/envs/llm/transforms/__init__.py @@ -15,22 +15,37 @@ from .policy_version import PolicyVersion from .reason import AddThinkingPrompt from .tokenizer import Tokenizer -from .tools import MCPToolTransform, PythonInterpreter +from .tools import ( + ExecuteToolsInOrder, + JSONCallParser, + MCPToolTransform, + PythonInterpreter, + ToolCall, + ToolRegistry, + ToolService, + XMLBlockParser, +) __all__ = [ + "AddThinkingPrompt", "BrowserTransform", "DataLoadingPrimer", - "RayDataLoadingPrimer", + "ExecuteToolsInOrder", + "JSONCallParser", + "KLComputation", "KLRewardTransform", - "RetrieveLogProb", - "RetrieveKL", "MCPToolTransform", "PolicyVersion", "PythonInterpreter", - "AddThinkingPrompt", - "KLComputation", + "RayDataLoadingPrimer", + "RetrieveKL", + "RetrieveLogProb", "TemplateTransform", "Tokenizer", + "ToolCall", + "ToolRegistry", + "ToolService", + "XMLBlockParser", "as_nested_tensor", "as_padded_tensor", ] diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 22fccb27cba..dfbb5071872 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -5,6 +5,7 @@ from __future__ import annotations +import json import os import queue import re @@ -12,7 +13,9 @@ import tempfile import threading import time -from typing import TextIO +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Protocol, TextIO import torch @@ -23,6 +26,510 @@ from torchrl.envs import Transform +# --- Tool Service Library: Pluggable Services & Parsers --- + + +class ToolService(Protocol): + """Protocol for side-effecting service callable with structured IO. + + A tool service is a callable that can be invoked with keyword arguments + and returns a dictionary of results. It has a name and input/output schemas. + + Attributes: + name (str): The name of the tool service. + schema_in (dict[str, Any]): Input schema describing expected parameters. 
+ schema_out (dict[str, Any]): Output schema describing returned data. + """ + + name: str + schema_in: dict[str, Any] + schema_out: dict[str, Any] + + def __call__(self, **kwargs) -> dict[str, Any]: + """Execute the tool service. + + Args: + **kwargs: Keyword arguments matching the input schema. + + Returns: + dict[str, Any]: Results matching the output schema. + """ + ... + + +class ToolRegistry: + """Registry for managing available tool services. + + This class maintains a collection of tool services that can be looked up + by name for execution. + + Args: + services (Sequence[ToolService], optional): Initial services to register. + Defaults to an empty sequence. + + Examples: + >>> class AddService: + ... name = "add" + ... schema_in = {"a": int, "b": int} + ... schema_out = {"result": int} + ... def __call__(self, a, b, **kwargs): + ... return {"result": a + b} + >>> registry = ToolRegistry([AddService()]) + >>> service = registry.get("add") + >>> result = service(a=1, b=2) + >>> print(result) + {"result": 3} + """ + + def __init__(self, services: Sequence[ToolService] = ()): + self._svc: dict[str, ToolService] = {s.name: s for s in services} + + def register(self, service: ToolService) -> None: + """Register a new service. + + Args: + service (ToolService): The service to register. + """ + self._svc[service.name] = service + + def get(self, name: str) -> ToolService: + """Retrieve a service by name. + + Args: + name (str): The name of the service to retrieve. + + Returns: + ToolService: The requested service. + + Raises: + KeyError: If the service is not found. + """ + if name not in self._svc: + raise KeyError(f"Unknown tool: {name}") + return self._svc[name] + + def __contains__(self, name: str) -> bool: + """Check if a service is registered. + + Args: + name (str): The name to check. + + Returns: + bool: True if the service exists, False otherwise. + """ + return name in self._svc + + +@dataclass +class ToolCall: + """Representation of a parsed tool call from LLM output. + + Attributes: + tool (str): The name of the tool to call. + args (dict[str, Any]): Arguments to pass to the tool. + tag (str | None): Optional user-visible label or correlation ID. + """ + + tool: str + args: dict[str, Any] + tag: str | None = None + + +class ParseResult(dict): + """Result of parsing an LLM response for tool calls. + + This is a TypedDict-style class that contains: + text (str): The final message to user (post tool blocks removal). + calls (list[ToolCall]): Ordered tool calls as they appear. + meta (dict[str, Any]): Optional parser metadata. + """ + + +class LLMToolParser(Protocol): + """Protocol for parsing LLM responses into ordered tool calls. + + A tool parser takes the LLM's response (as string or structured data) + and extracts ordered tool calls, along with the cleaned user-facing text. + """ + + def __call__(self, response: str | dict[str, Any]) -> ParseResult: + """Parse an LLM response. + + Args: + response (str | dict[str, Any]): The LLM's response to parse. + + Returns: + ParseResult: Parsed result with text, calls, and metadata. + """ + ... + + +class XMLBlockParser: + """Parser for XML-style tool blocks in LLM responses. + + Parses tool calls in the format: + {"arg": "value"} + + Examples: + >>> parser = XMLBlockParser() + >>> response = '{"query": "torchrl"}\\nSome text.' + >>> result = parser(response) + >>> print(result["text"]) + Some text. 
+ >>> print(result["calls"][0].tool) + search + >>> print(result["calls"][0].args) + {"query": "torchrl"} + """ + + _re = re.compile( + r'[^"]+)"(?:\s+tag="(?P[^"]+)")?\s*>\s*(?P.*?)\s*', + re.DOTALL, + ) + + def __call__(self, response: str | dict[str, Any]) -> ParseResult: + """Parse XML-style tool blocks from response. + + Args: + response (str | dict[str, Any]): The response to parse. + + Returns: + ParseResult: Parsed result with cleaned text and tool calls. + """ + text = response if isinstance(response, str) else response.get("text", "") + calls: list[ToolCall] = [] + + def repl(m: re.Match) -> str: + name = m.group("name") + tag = m.group("tag") + body = m.group("body") + try: + args = json.loads(body) if body.strip() else {} + except json.JSONDecodeError: + # If JSON parsing fails, pass the raw body as a "raw" argument + args = {"raw": body} + calls.append(ToolCall(tool=name, args=args, tag=tag)) + return "" # Remove block from final user-visible message + + cleaned = self._re.sub(repl, text).strip() + result = ParseResult() + result["text"] = cleaned + result["calls"] = calls + result["meta"] = {"count": len(calls)} + return result + + +class JSONCallParser: + """Parser for JSON-style function-calling responses. + + Expects responses in the format: + { + "message": "...", + "tools": [ + {"tool": "search", "args": {"query": "..."}, "tag": "A"}, + {"tool": "summarize", "args": {"text": "..."}} + ] + } + + Examples: + >>> parser = JSONCallParser() + >>> response = { + ... "message": "Let me search for that.", + ... "tools": [{"tool": "search", "args": {"query": "torchrl"}}] + ... } + >>> result = parser(response) + >>> print(result["text"]) + Let me search for that. + >>> print(result["calls"][0].tool) + search + """ + + def __call__(self, response: str | dict[str, Any]) -> ParseResult: + """Parse JSON-style function calls from response. + + Args: + response (str | dict[str, Any]): The response to parse. + + Returns: + ParseResult: Parsed result with message and tool calls. + """ + if isinstance(response, str): + try: + response = json.loads(response) + except json.JSONDecodeError: + # If it's not valid JSON, treat as plain text with no tools + result = ParseResult() + result["text"] = response + result["calls"] = [] + result["meta"] = {"count": 0} + return result + + tools_data = response.get("tools", []) + calls = [ToolCall(**c) for c in tools_data] + + result = ParseResult() + result["text"] = response.get("message", "") + result["calls"] = calls + result["meta"] = {"count": len(calls)} + return result + + +class ExecuteToolsInOrder(Transform): + """A Transform that executes tools in the order they appear in LLM output. + + This transform reads the LLM response, parses ordered tool blocks using a + pluggable parser, and executes tools via a ToolRegistry strictly in the + order they appear in the response (independent of transform stacking order). + + The transform integrates naturally with TorchRL's LLM environments and can + read/write conversation history alongside other transforms. + + Args: + registry (ToolRegistry): Registry containing available tool services. + parser (LLMToolParser): Parser for extracting tool calls from LLM output. + in_keys (tuple[str, ...], optional): Key where LLM response is read. + Defaults to ``("history", "prompt")``. + out_keys (tuple[str, ...], optional): Key where tool results are written. + Defaults to ``("tools", "results")``. + message_key (tuple[str, ...], optional): Key for cleaned message text. + Defaults to ``("llm", "message")``. 
+ history_key (tuple[str, ...], optional): Key for conversation history. + Defaults to ``("history", "prompt")``. + write_calls_key (tuple[str, ...], optional): Key for storing parsed calls. + Defaults to ``("tools", "calls")``. + stop_on_error (bool, optional): Whether to stop execution on first error. + Defaults to ``False``. + pass_state_to_tools (bool, optional): Whether to pass TD state to tools. + Defaults to ``True``. + + Examples: + >>> from torchrl.envs.llm import ChatEnv + >>> from torchrl.envs.transforms import TransformedEnv, Compose + >>> from torchrl.envs.llm.transforms import ExecuteToolsInOrder, ToolRegistry, XMLBlockParser + >>> + >>> # Define a simple service + >>> class WebSearch: + ... name = "search" + ... schema_in = {"query": str} + ... schema_out = {"results": list} + ... def __call__(self, query: str, **kwargs): + ... return {"results": [{"title": "TorchRL docs", "url": "https://..."}]} + >>> + >>> # Create registry and parser + >>> registry = ToolRegistry([WebSearch()]) + >>> parser = XMLBlockParser() + >>> + >>> # Create environment with transform + >>> env = ChatEnv(batch_size=(1,)) + >>> env = TransformedEnv( + ... env, + ... ExecuteToolsInOrder(registry=registry, parser=parser) + ... ) + + .. note:: + This transform operates in the forward direction only; inverse is a no-op. + Tool execution order is determined by appearance in the LLM output, + not by the order of transforms in the Compose stack. + """ + + def __init__( + self, + registry: ToolRegistry, + parser: LLMToolParser, + in_keys: tuple[str, ...] | None = None, + out_keys: tuple[str, ...] | None = None, + message_key: tuple[str, ...] | None = None, + history_key: tuple[str, ...] | None = None, + write_calls_key: tuple[str, ...] | None = None, + stop_on_error: bool = False, + pass_state_to_tools: bool = True, + ): + # Set defaults + if in_keys is None: + in_keys = ("history", "prompt") + if out_keys is None: + out_keys = ("tools", "results") + if message_key is None: + message_key = ("llm", "message") + if history_key is None: + history_key = ("history", "prompt") + if write_calls_key is None: + write_calls_key = ("tools", "calls") + + super().__init__(in_keys=[in_keys], out_keys=[out_keys]) + self.registry = registry + self.parser = parser + self._in = in_keys + self._out = out_keys + self._msg = message_key + self._hist = history_key + self._calls = write_calls_key + self.stop_on_error = stop_on_error + self.pass_state_to_tools = pass_state_to_tools + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + """Execute tools during environment step. + + Args: + tensordict (TensorDictBase): Input tensordict before step. + next_tensordict (TensorDictBase): Output tensordict after step. + + Returns: + TensorDictBase: Modified next_tensordict with tool results. 
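+
+        .. note::
+            Calls are executed in their order of appearance in the LLM
+            response; their results (or error messages) are appended to the
+            conversation history as a single message with role ``"tool"``.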
+ """ + if next_tensordict.batch_dims > 1: + with next_tensordict.view(-1) as next_tensordict_flat: + next_tensordict_flat = self._step(tensordict, next_tensordict_flat) + return next_tensordict + + # Check that we're in history mode + parent = self.parent + if parent is None: + raise RuntimeError("ExecuteToolsInOrder must be used with a ChatEnv") + base_env = parent.base_env + if base_env.input_mode != "history": + raise RuntimeError( + "ExecuteToolsInOrder must be used with a ChatEnv in history mode" + ) + + # Get the history and extract the last message (LLM response) + history = next_tensordict["history"].prompt + local_history = history[..., -1] + + procs = [] + # Iterate over batch + for i, response_text in enumerate(local_history.content): + # Parse the response for tool calls + parse: ParseResult = self.parser(response_text) + ordered_calls = parse["calls"] + tool_outputs: list[dict[str, Any]] = [] + + # Execute tools IN ORDER OF APPEARANCE + for j, call in enumerate(ordered_calls): + try: + service = self.registry.get(call.tool) + kwargs = dict(call.args) + if self.pass_state_to_tools: + kwargs["_state"] = self._export_state_for_tool(next_tensordict) + + out = service(**kwargs) + out["_tool"] = call.tool + out["_index"] = j + if call.tag: + out["_tag"] = call.tag + tool_outputs.append(out) + except Exception as e: + err = {"_tool": call.tool, "_index": j, "error": str(e)} + tool_outputs.append(err) + if self.stop_on_error: + break + + # Store results and calls in tensordict + if tool_outputs: + # Format tool results as history entries + results_text = self._format_tool_results(tool_outputs) + if results_text: + procs.append([History(role="tool", content=results_text)]) + else: + procs.append(None) + else: + procs.append(None) + + # Add tool results to history if any tools were executed + if not all(p is None for p in procs): + if any(p is None for p in procs): + procs = [p if p is not None else [] for p in procs] + + # Ensure all batch elements have same length + if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): + + def fill_procs(proc: list[History], max_len: int) -> list[History]: + if len(proc) == max_len: + return proc + return proc + [History(role="tool", content="")] * ( + max_len - len(proc) + ) + + max_len = max(len(p) for p in procs) + procs = [fill_procs(p, max_len) for p in procs] + + # Stack and extend history + procs = lazy_stack([lazy_stack(p) for p in procs]) + history.extend(procs, dim=-1) + next_tensordict["history"].prompt = history + + return next_tensordict + + def _format_tool_results(self, tool_outputs: list[dict[str, Any]]) -> str: + """Format tool execution results as text. + + Args: + tool_outputs (list[dict[str, Any]]): List of tool execution results. + + Returns: + str: Formatted text representation of results. + """ + if not tool_outputs: + return "" + + lines = [""] + for output in tool_outputs: + tool_name = output.pop("_tool", "unknown") + index = output.pop("_index", 0) + tag = output.pop("_tag", None) + + if "error" in output: + lines.append(f"Tool {tool_name} (call {index + 1}) failed:") + lines.append(f" Error: {output['error']}") + else: + header = f"Tool {tool_name} (call {index + 1})" + if tag: + header += f" [tag: {tag}]" + header += " succeeded:" + lines.append(header) + lines.append(f" Result: {json.dumps(output, indent=2)}") + + lines.append("") + return "\n".join(lines) + + def _export_state_for_tool(self, td: TensorDictBase) -> dict[str, Any]: + """Export a filtered, read-only view of TD state for tools. 
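+
+        Only an allow-listed subset of entries is exported (the conversation
+        history, the environment step counter and the episode id, when
+        present); tensor values are converted to Python scalars or lists and
+        nested keys are flattened into ``"/"``-joined strings.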
+ + Args: + td (TensorDictBase): The tensordict to export from. + + Returns: + dict[str, Any]: Filtered state dictionary. + """ + # Minimal, safe view; customize as needed + keys_for_tools = [("history", "prompt"), ("env", "step"), ("episode", "id")] + out = {} + for k in keys_for_tools: + if td.get(k, None) is not None: + value = td.get(k) + # Convert to Python types if needed + if isinstance(value, torch.Tensor): + value = value.tolist() if value.numel() > 1 else value.item() + out["/".join(k if isinstance(k, tuple) else (k,))] = value + return out + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + """Handle reset (no-op for this transform). + + Args: + tensordict (TensorDictBase): Input tensordict. + tensordict_reset (TensorDictBase): Reset tensordict. + + Returns: + TensorDictBase: Unchanged reset tensordict. + """ + return tensordict_reset + + class PersistentPythonProcess: """A persistent Python process that can execute code blocks.""" From fd322d9826cba0282aa1d7aa8aca920bb52507da Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 25 Oct 2025 15:26:17 -0700 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- .../unittest/llm/scripts_llm/environment.yml | 1 + .github/unittest/llm/scripts_llm/install.sh | 14 + examples/llm/python_mcp_tool.py | 96 ++ examples/llm/tool_service_example.py | 356 +++++ examples/llm/web_search_tool.py | 144 ++ test/llm/test_envs.py | 78 +- torchrl/envs/llm/transforms/__init__.py | 2 + torchrl/envs/llm/transforms/tools.py | 1243 +++++++++++------ 8 files changed, 1536 insertions(+), 398 deletions(-) create mode 100644 examples/llm/python_mcp_tool.py create mode 100644 examples/llm/tool_service_example.py create mode 100644 examples/llm/web_search_tool.py diff --git a/.github/unittest/llm/scripts_llm/environment.yml b/.github/unittest/llm/scripts_llm/environment.yml index 2e6b3c6e173..f1b0131d40d 100644 --- a/.github/unittest/llm/scripts_llm/environment.yml +++ b/.github/unittest/llm/scripts_llm/environment.yml @@ -22,3 +22,4 @@ dependencies: - transformers - datasets - vllm + - mcp diff --git a/.github/unittest/llm/scripts_llm/install.sh b/.github/unittest/llm/scripts_llm/install.sh index f5d23d2afbd..cf4f2372899 100644 --- a/.github/unittest/llm/scripts_llm/install.sh +++ b/.github/unittest/llm/scripts_llm/install.sh @@ -61,3 +61,17 @@ python -m pip install -e . --no-build-isolation # smoke test python -c "import torchrl" + +# Install MCP dependencies for tool execution tests +printf "* Installing MCP dependencies (uvx, Deno)\n" + +# Install uvx (universal package runner) +pip install uvx + +# Install Deno (required by mcp-run-python) +curl -fsSL https://deno.land/install.sh | sh +export PATH="$HOME/.deno/bin:$PATH" + +# Verify installations +uvx --version || echo "Warning: uvx not installed" +deno --version || echo "Warning: Deno not installed" diff --git a/examples/llm/python_mcp_tool.py b/examples/llm/python_mcp_tool.py new file mode 100644 index 00000000000..d9845152005 --- /dev/null +++ b/examples/llm/python_mcp_tool.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
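+#
+# Prerequisites (assumed setup, mirroring .github/unittest/llm/scripts_llm/install.sh):
+# the `mcp` Python package, `uvx` and Deno must be installed, with Deno's bin
+# directory on PATH so that `uvx mcp-run-python stdio` can start the server.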
+ +"""Execute Python code using MCP server with mcp-run-python.""" + +import json +import os + +from tensordict import set_list_to_stack, TensorDict + +from torchrl.data.llm import History +from torchrl.envs.llm import ChatEnv +from torchrl.envs.llm.transforms import MCPToolTransform + +set_list_to_stack(True).set() + +deno_path = os.path.expanduser("~/.deno/bin") +if deno_path not in os.environ.get("PATH", ""): + os.environ["PATH"] = f"{deno_path}:{os.environ['PATH']}" + +servers = { + "python": { + "command": "uvx", + "args": ["mcp-run-python", "stdio"], + "env": os.environ.copy(), + } +} + +env = ChatEnv(batch_size=(1,)) +env = env.append_transform(MCPToolTransform(servers=servers)) + +reset_data = TensorDict(query="You are a helpful assistant", batch_size=(1,)) +td = env.reset(reset_data) + +history = td.get("history") + +code = """ +import math +result = math.sqrt(144) + math.pi +print(f"Result: {result}") +result +""" + +response = ( + History( + role="assistant", + content=f'Let me calculate that.\npython.run_python_code\n{json.dumps({"python_code": code})}', + ) + .unsqueeze(0) + .unsqueeze(0) +) + +history.full = history.prompt.extend(response, inplace=True, dim=-1) +history.response = response + +result = env.step(td.set("history", history)) + +print("Python code executed via MCP!") +print("\nTool response:") +tool_response = result["next", "history"].prompt[0, -1] +print(f"Role: {tool_response.role}") +print(f"Content: {tool_response.content}") + +fibonacci_code = """ +def fibonacci(n): + if n <= 1: + return n + return fibonacci(n-1) + fibonacci(n-2) + +result = [fibonacci(i) for i in range(10)] +print(f"Fibonacci sequence: {result}") +result +""" + +history = result["next", "history"] +response2 = ( + History( + role="assistant", + content=f'Now calculating Fibonacci.\npython.run_python_code\n{json.dumps({"python_code": fibonacci_code})}', + ) + .unsqueeze(0) + .unsqueeze(0) +) + +history.full = history.prompt.extend(response2, inplace=True, dim=-1) +history.response = response2 + +result2 = env.step(result["next"].set("history", history)) + +print("\n\nSecond execution:") +print("\nTool response:") +tool_response2 = result2["next", "history"].prompt[0, -1] +print(f"Role: {tool_response2.role}") +print(f"Content: {tool_response2.content[:500]}...") diff --git a/examples/llm/tool_service_example.py b/examples/llm/tool_service_example.py new file mode 100644 index 00000000000..36a50a4ea77 --- /dev/null +++ b/examples/llm/tool_service_example.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Example demonstrating the ExecuteToolsInOrder transform with pluggable services. + +This example shows how to: +1. Define tool services with schemas +2. Register them in a ToolRegistry +3. Use different parsers (XML and JSON style) +4. 
Integrate with ChatEnv and LLM wrappers +""" + +from tensordict import TensorDict +from torchrl.data.llm import History +from torchrl.envs.llm import ChatEnv +from torchrl.envs.llm.transforms import ( + ExecuteToolsInOrder, + JSONCallParser, + ToolRegistry, + XMLBlockParser, +) +from torchrl.envs.transforms import TransformedEnv + + +# --- Define Example Tool Services --- + + +class WebSearchService: + """Example web search service.""" + + name = "search" + schema_in = {"query": str} + schema_out = {"results": list} + + def __call__(self, query: str, _state=None, **kwargs): + """Simulate a web search.""" + # In a real implementation, this would call an actual search API + return { + "results": [ + { + "title": "TorchRL Documentation", + "url": "https://pytorch.org/rl/", + "snippet": "TorchRL is a PyTorch library for reinforcement learning", + }, + { + "title": f"Results for: {query}", + "url": "https://example.com", + "snippet": f"Search results for query: {query}", + }, + ] + } + + +class SummarizeService: + """Example text summarization service.""" + + name = "summarize" + schema_in = {"text": str, "max_length": int} + schema_out = {"summary": str} + + def __call__(self, text: str, max_length: int = 200, _state=None, **kwargs): + """Summarize text to max_length characters.""" + if len(text) <= max_length: + return {"summary": text} + return {"summary": text[:max_length] + "..."} + + +class CalculatorService: + """Example calculator service.""" + + name = "calculate" + schema_in = {"expression": str} + schema_out = {"result": float} + + def __call__(self, expression: str, _state=None, **kwargs): + """Safely evaluate a mathematical expression.""" + try: + # Only allow safe math operations + allowed_chars = set("0123456789+-*/()., ") + if not all(c in allowed_chars for c in expression): + raise ValueError("Invalid characters in expression") + result = eval(expression) + return {"result": float(result)} + except Exception as e: + raise ValueError(f"Invalid expression: {str(e)}") + + +# --- Example 1: Using XML-style parser --- + + +def example_xml_parser(): + """Demonstrate using XML-style tool blocks.""" + print("\n" + "=" * 60) + print("Example 1: XML-style Parser") + print("=" * 60 + "\n") + + # Create registry with services + registry = ToolRegistry( + [WebSearchService(), SummarizeService(), CalculatorService()] + ) + + # Create parser + parser = XMLBlockParser() + + # Create environment + env = ChatEnv( + batch_size=(1,), + system_prompt="You are a helpful assistant with access to tools.", + input_mode="history", + ) + + # Add transform + env = TransformedEnv(env, ExecuteToolsInOrder(registry=registry, parser=parser)) + + # Reset with initial query + reset_data = TensorDict({"query": ["Hello, can you help me?"]}, batch_size=(1,)) + obs = env.reset(reset_data) + + # Simulate LLM response with tool calls + llm_response = """I'll search for information about TorchRL and calculate something. 
+ +{"query": "TorchRL transforms"} + +{"expression": "42 * 3.14"} + +Let me process these results for you.""" + + # Create response tensordict + history = obs["history"].prompt + response = History(role="assistant", content=llm_response).view(1, 1) + full = history.extend( + response, + dim=-1, + inplace=False, + ) + obs["history"].full = full + obs["history"].response = response + + # Step with the response + next_obs = env.step(obs) + + # Print the conversation + print("Final conversation history:") + final_history = next_obs[("next", "history")].prompt + for i, msg in enumerate(final_history[0]): + print(f"\n[{i}] Role: {msg.role}") + print(f"Content: {msg.content[:200]}{'...' if len(msg.content) > 200 else ''}") + + +# --- Example 2: Using JSON-style parser --- + + +def example_json_parser(): + """Demonstrate using JSON-style function calling.""" + print("\n" + "=" * 60) + print("Example 2: JSON-style Parser") + print("=" * 60 + "\n") + + # Create registry with services + registry = ToolRegistry([WebSearchService(), CalculatorService()]) + + # Create JSON parser + parser = JSONCallParser() + + # Create environment + env = ChatEnv( + batch_size=(1,), + system_prompt="You are a helpful assistant.", + input_mode="history", + ) + + # Add transform + env = TransformedEnv(env, ExecuteToolsInOrder(registry=registry, parser=parser)) + + # Reset with initial query + reset_data = TensorDict({"query": ["Calculate 15 + 27"]}, batch_size=(1,)) + obs = env.reset(reset_data) + + # Simulate LLM response in JSON format + # Note: In practice, the LLM would generate this + import json + + llm_response_dict = { + "message": "I'll calculate that for you.", + "tools": [{"tool": "calculate", "args": {"expression": "15 + 27"}}], + } + + # For the JSON parser, we need to pass the content as JSON string + llm_response = json.dumps(llm_response_dict) + + # Create response tensordict + history = obs["history"].prompt + response = History(role="assistant", content=llm_response).view(1, 1) + full = history.extend( + response, + dim=-1, + inplace=False, + ) + obs["history"].full = full + obs["history"].response = response + + # Step with the response + next_obs = env.step(obs) + + # Print the conversation + print("Final conversation history:") + final_history = next_obs[("next", "history")].prompt + for i, msg in enumerate(final_history[0]): + print(f"\n[{i}] Role: {msg.role}") + print(f"Content: {msg.content[:200]}{'...' 
if len(msg.content) > 200 else ''}") + + +# --- Example 3: State passing to tools --- + + +class StateAwareService: + """Example service that uses the environment state.""" + + name = "check_step" + schema_in = {} + schema_out = {"message": str} + + def __call__(self, _state=None, **kwargs): + """Return information about the current environment state.""" + if _state: + step_info = _state.get("env/step", "unknown") + return { + "message": f"Current environment state: step={step_info}", + "state_keys": list(_state.keys()), + } + return {"message": "No state information available"} + + +def example_state_passing(): + """Demonstrate state passing to tools.""" + print("\n" + "=" * 60) + print("Example 3: State Passing to Tools") + print("=" * 60 + "\n") + + # Create registry with state-aware service + registry = ToolRegistry([StateAwareService()]) + parser = XMLBlockParser() + + # Create environment + env = ChatEnv(batch_size=(1,), input_mode="history") + + # Add transform with state passing enabled (default) + env = TransformedEnv( + env, + ExecuteToolsInOrder(registry=registry, parser=parser, pass_state_to_tools=True), + ) + + # Reset and step + reset_data = TensorDict({"query": ["Check the state"]}, batch_size=(1,)) + obs = env.reset(reset_data) + + llm_response = '{}' + history = obs["history"].prompt + response = History(role="assistant", content=llm_response).view(1, 1) + full = history.extend( + response, + dim=-1, + inplace=False, + ) + obs["history"].full = full + obs["history"].response = response + + next_obs = env.step(obs) + + print("Tool received state information:") + final_history = next_obs[("next", "history")].prompt + print(final_history[0][-1].content) + + +# --- Example 4: Error handling --- + + +def example_error_handling(): + """Demonstrate error handling with stop_on_error.""" + print("\n" + "=" * 60) + print("Example 4: Error Handling") + print("=" * 60 + "\n") + + # Create registry + registry = ToolRegistry([CalculatorService()]) + parser = XMLBlockParser() + + # Create environment with stop_on_error=True + env = ChatEnv(batch_size=(1,), input_mode="history") + env = TransformedEnv( + env, + ExecuteToolsInOrder(registry=registry, parser=parser, stop_on_error=False), + ) + + # Reset + reset_data = TensorDict({"query": ["Do some calculations"]}, batch_size=(1,)) + obs = env.reset(reset_data) + + # Response with both valid and invalid calculations + llm_response = """Let me calculate: + +{"expression": "2 + 2"} +{"expression": "invalid + syntax"} +{"expression": "5 * 5"}""" + + history = obs["history"].prompt + response = History(role="assistant", content=llm_response).view(1, 1) + full = history.extend( + response, + dim=-1, + inplace=False, + ) + obs["history"].full = full + obs["history"].response = response + + next_obs = env.step(obs) + + print("Results (note: execution continues after error):") + final_history = next_obs[("next", "history")].prompt + for msg in final_history[0]: + if msg.role == "tool": + print(msg.content) + + +if __name__ == "__main__": + print("\n" + "=" * 60) + print("Tool Service Transform Examples") + print("=" * 60) + + try: + example_xml_parser() + except Exception as e: + print(f"Example 1 failed: {e}") + + try: + example_json_parser() + except Exception as e: + print(f"Example 2 failed: {e}") + + try: + example_state_passing() + except Exception as e: + print(f"Example 3 failed: {e}") + + try: + example_error_handling() + except Exception as e: + print(f"Example 4 failed: {e}") + + print("\n" + "=" * 60) + print("All examples completed!") 
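+
+    # Registries are mutable, so extra services can be registered after
+    # construction; a quick sanity check using the services defined above:
+    extra_registry = ToolRegistry([CalculatorService()])
+    extra_registry.register(WebSearchService())
+    assert "search" in extra_registry and "calculate" in extra_registry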
+ print("=" * 60 + "\n") diff --git a/examples/llm/web_search_tool.py b/examples/llm/web_search_tool.py new file mode 100644 index 00000000000..6152bf0f100 --- /dev/null +++ b/examples/llm/web_search_tool.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Simple web search tool using DuckDuckGo and SimpleToolTransform.""" + +import urllib.parse +import urllib.request + +from tensordict import set_list_to_stack, TensorDict + +from torchrl.data.llm import History +from torchrl.envs.llm import ChatEnv +from torchrl.envs.llm.transforms import SimpleToolTransform + +set_list_to_stack(True).set() + + +def web_search(query: str) -> dict: + """Search DuckDuckGo and return results.""" + encoded_query = urllib.parse.quote(query) + url = f"https://html.duckduckgo.com/html/?q={encoded_query}" + + try: + headers = {"User-Agent": "Mozilla/5.0"} + req = urllib.request.Request(url, headers=headers) + + with urllib.request.urlopen(req, timeout=10) as response: + html = response.read().decode("utf-8") + + results = [] + for line in html.split("\n"): + if 'class="result__a"' in line and "href=" in line: + start = line.find('href="') + 6 + end = line.find('"', start) + href = line[start:end] + + title_start = line.find(">", end) + 1 + title_end = line.find("<", title_start) + title = line[title_start:title_end] + + if href and title and len(results) < 5: + results.append({"title": title, "url": href}) + + return {"success": True, "query": query, "results": results} + + except Exception as e: + return {"success": False, "error": str(e)} + + +def fetch_webpage(url: str) -> dict: + """Fetch webpage content.""" + try: + headers = {"User-Agent": "Mozilla/5.0"} + req = urllib.request.Request(url, headers=headers) + + with urllib.request.urlopen(req, timeout=10) as response: + html = response.read().decode("utf-8") + + title = "" + if "" in html: + start = html.find("<title>") + 7 + end = html.find("", start) + title = html[start:end] + + text = html + for tag in ["", start) + 1 + close_tag = f""]: + text = text.replace(tag, " ") + + lines = [line.strip() for line in text.split("\n") if line.strip()] + content = " ".join(lines)[:3000] + + return { + "success": True, + "url": url, + "title": title, + "content": content, + } + + except Exception as e: + return {"success": False, "error": str(e)} + + +if __name__ == "__main__": + tools = {"search": web_search, "fetch": fetch_webpage} + + env = ChatEnv(batch_size=(1,)) + env = env.append_transform(SimpleToolTransform(tools=tools)) + + reset_data = TensorDict(query="You are a helpful assistant", batch_size=(1,)) + td = env.reset(reset_data) + + history = td.get("history") + + assistant_response = ( + History( + role="assistant", + content='Let me search for PyTorch tutorials.\nsearch\n{"query": "pytorch tutorial"}', + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + history.full = history.prompt.extend(assistant_response, inplace=True, dim=-1) + history.response = assistant_response + + result = env.step(td.set("history", history)) + + print("Search executed successfully!") + print("\nTool response:") + tool_response = result["next", "history"].prompt[-1] + print(f"Role: {tool_response.role}") + print(f"Content: {tool_response.content[:500]}...") + + fetch_response = ( + History( + role="assistant", + content='fetch\n{"url": "https://pytorch.org"}', + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + history = result["next", 
"history"] + history.full = history.prompt.extend(fetch_response, inplace=True, dim=-1) + history.response = fetch_response + + result2 = env.step(result["next"].set("history", history)) + + print("\n\nFetch executed successfully!") + print("\nTool response:") + fetch_tool_response = result2["next", "history"].prompt[-1] + print(f"Role: {fetch_tool_response.role}") + print(f"Content: {fetch_tool_response.content[:500]}...") diff --git a/test/llm/test_envs.py b/test/llm/test_envs.py index cc06aa33d56..fa8b03db774 100644 --- a/test/llm/test_envs.py +++ b/test/llm/test_envs.py @@ -6,6 +6,8 @@ import argparse import importlib.util +import json +import os import random import re import time @@ -33,6 +35,7 @@ _has_transformers = importlib.util.find_spec("transformers") is not None _has_datasets = importlib.util.find_spec("datasets") is not None _has_vllm = importlib.util.find_spec("vllm") is not None +_has_mcp = importlib.util.find_spec("mcp") is not None _has_ifeval = ( _has_datasets and (importlib.util.find_spec("langdetect") is not None) @@ -629,9 +632,9 @@ def test_python_interpreter_persistent_reset(self): @pytest.mark.skipif(not _has_transformers, reason="requires transformers") def test_mcp_tool_transform(self): - """Test the MCPToolTransform with a simple calculator tool.""" + """Test the SimpleToolTransform with a simple calculator tool.""" from torchrl.envs.llm import ChatEnv - from torchrl.envs.llm.transforms.tools import MCPToolTransform + from torchrl.envs.llm.transforms.tools import SimpleToolTransform from transformers import AutoTokenizer # Define a simple calculator tool @@ -669,7 +672,7 @@ def calculator(operation: str, a: float, b: float) -> dict: system_prompt="You are a helpful assistant that uses a calculator.", tokenizer=tokenizer, ) - transform = MCPToolTransform(tools, schemas) + transform = SimpleToolTransform(tools, schemas) env = base_env.append_transform(transform) # Test single tool call @@ -785,7 +788,7 @@ def delayed_calculator(cls, operation: str, a: float, b: float) -> dict: # Create environment factory @classmethod def make_env(cls): - from torchrl.envs.llm.transforms.tools import MCPToolTransform + from torchrl.envs.llm.transforms.tools import SimpleToolTransform tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B") env = ChatEnv( @@ -795,7 +798,7 @@ def make_env(cls): ) tools = {"calculator": cls.delayed_calculator} schemas = {"calculator": cls.calculator_schema} - return env.append_transform(MCPToolTransform(tools, schemas)) + return env.append_transform(SimpleToolTransform(tools, schemas)) @pytest.mark.skipif(not _has_transformers, reason="requires transformers") def test_async_mcp_tools(self): @@ -863,6 +866,71 @@ def test_async_mcp_tools(self): finally: env_pool.close() + @pytest.mark.skipif(not _has_mcp, reason="requires mcp library") + @pytest.mark.skipif(not _has_transformers, reason="requires transformers") + def test_mcp_python_execution(self): + """Test actual MCP Python execution with mcp-run-python server.""" + from torchrl.envs.llm.transforms import MCPToolTransform + + # Setup environment for MCP (Deno needs to be in PATH) + environ = os.environ.copy() + deno_path = os.path.expanduser("~/.deno/bin") + if deno_path not in os.environ.get("PATH", ""): + environ["PATH"] = f"{deno_path}:{os.environ['PATH']}" + + # Configure MCP server + servers = { + "python": { + "command": "uvx", + "args": ["mcp-run-python", "stdio"], + "env": environ, + } + } + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B") + env = ChatEnv( + 
batch_size=(1,), + system_prompt="You are a helpful assistant", + tokenizer=tokenizer, + ) + env = env.append_transform(MCPToolTransform(servers=servers)) + + # Reset environment + reset_data = TensorDict(query=["You are a useful assistant"], batch_size=(1,)) + td = env.reset(reset_data) + history = td.get("history") + + # Execute Python code via MCP + code = """ +import math +result = math.sqrt(144) + math.pi +print(f"Result: {result}") +result +""" + response = ( + History( + role="assistant", + content=f'Let me calculate that.\npython.run_python_code\n{json.dumps({"python_code": code})}', + ) + .unsqueeze(0) + .unsqueeze(0) + ) + + history.full = history.prompt.extend(response, inplace=True, dim=-1) + history.response = response + + result = env.step(td.set("history", history)) + + # Check that tool was executed + final_history = result["next", "history"].prompt + assert len(final_history[0]) == 4 # system, user, assistant, tool response + assert final_history[0, -1].role == "tool" + + # Check that result contains expected output + tool_content = final_history[0, -1].content + assert "python.run_python_code executed successfully" in tool_content + assert "15.141592653589793" in tool_content or "Result: 15.14" in tool_content + class TestThinkingPrompt: @pytest.fixture(autouse=True, scope="class") diff --git a/torchrl/envs/llm/transforms/__init__.py b/torchrl/envs/llm/transforms/__init__.py index 5ab6b0a8378..e4502042e72 100644 --- a/torchrl/envs/llm/transforms/__init__.py +++ b/torchrl/envs/llm/transforms/__init__.py @@ -20,6 +20,7 @@ JSONCallParser, MCPToolTransform, PythonInterpreter, + SimpleToolTransform, ToolCall, ToolRegistry, ToolService, @@ -40,6 +41,7 @@ "RayDataLoadingPrimer", "RetrieveKL", "RetrieveLogProb", + "SimpleToolTransform", "TemplateTransform", "Tokenizer", "ToolCall", diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 2256af1e3c0..c4f5ab97397 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -5,15 +5,17 @@ from __future__ import annotations +import asyncio import json import os import queue import re import subprocess +import sys import tempfile import threading import time -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Any, Protocol, TextIO @@ -24,6 +26,241 @@ from torchrl.data.llm import History from torchrl.envs import Transform +from typing_extensions import TypedDict + + +# --- Base Class for Tool Transforms --- + + +class ToolTransformBase(Transform): + """Base class for tool transforms that parse and execute tools from LLM output. + + This class handles all the common boilerplate for tool transforms: + - History extraction and validation + - Batch dimension flattening + - Result collection and padding + - History extension with tool results + + Subclasses only need to implement: + - :meth:`_process_batch_item`: Extract and execute tools from one response + - :meth:`_format_result`: Format one tool result as string (optional) + + Attributes: + use_step (bool): Whether to use _step() vs _call(). Defaults to True. + tool_role (str): Role name for results in history. Defaults to "tool". + + Examples: + >>> class SimpleCalculator(ToolTransformBase): + ... tool_role = "calculator" + ... + ... def _process_batch_item(self, content: str, index: int): + ... # Extract math expressions and evaluate + ... if "2+2" in content: + ... return ["2+2=4"] + ... 
return None + """ + + use_step: bool = True # Use _step() vs _call() + tool_role: str = "tool" # Role name for results in history + + def _validate_and_extract_history( + self, next_tensordict: TensorDictBase + ) -> tuple[History, History]: + """Validate environment and extract history. + + Args: + next_tensordict: The tensordict containing history. + + Returns: + tuple: (full_history, local_history) where local_history is the last message. + + Raises: + RuntimeError: If parent env doesn't exist or isn't in history mode. + """ + # Check that base_env is in history mode + parent = self.parent + if parent is None: + raise RuntimeError(f"{self.__class__.__name__} must be used with a ChatEnv") + base_env = parent.base_env + if base_env.input_mode != "history": + raise RuntimeError( + f"{self.__class__.__name__} must be used with a ChatEnv in history mode" + ) + + # Get history and isolate last element (the LLM's response) + history = next_tensordict["history"].prompt + local_history = history[..., -1] + + return history, local_history + + def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one item in the batch to extract and execute tools. + + This is the main method subclasses must implement. + + Args: + content: The text content from the LLM response. + index: The index of this item in the batch. + + Returns: + list[str] or None: List of result strings for each tool executed, + or None if no tools were found/executed. + """ + raise NotImplementedError( + f"{self.__class__.__name__} must implement _process_batch_item()" + ) + + def _format_result(self, result: str) -> str: + """Format a single result string. + + Override this to customize result formatting. Default is identity. + + Args: + result: Raw result string from tool execution. + + Returns: + str: Formatted result string. + """ + return result + + def _inject_results_to_history( + self, + history: History, + results: list[list[str] | None], + next_tensordict: TensorDictBase, + ) -> TensorDictBase: + """Inject tool results back into history with proper batching. + + Args: + history: The full conversation history. + results: List of results per batch item (can contain None). + next_tensordict: The tensordict to update. + + Returns: + TensorDictBase: Updated tensordict with results in history. 
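+
+        .. note::
+            When only some batch items produce tool results, the shorter
+            entries are padded with empty ``History(role="", content="")``
+            messages so that the stacked results keep a rectangular shape.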
+ """ + # Convert string results to History objects + procs = [] + for batch_results in results: + if batch_results is None or len(batch_results) == 0: + procs.append(None) + else: + formatted_results = [self._format_result(r) for r in batch_results] + procs.append( + [ + History(role=self.tool_role, content=result) + for result in formatted_results + ] + ) + + # If there are no tool responses, skip + if all(p is None for p in procs): + return next_tensordict + + # Fill None entries with empty lists for consistent batching + if any(p is None for p in procs): + procs = [p if p is not None else [] for p in procs] + + # Pad all results to same length (required for batching) + if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): + + def fill_procs(proc: list[History], max_len: int) -> list[History]: + if len(proc) == max_len: + return proc + return proc + [History(role="", content="")] * ( + max_len - len(proc) + ) + + max_len = max(len(p) for p in procs) + procs = [fill_procs(p, max_len) for p in procs] + + # Stack and extend history + procs = lazy_stack([lazy_stack(p) for p in procs]) + history.extend(procs, dim=-1) + next_tensordict["history"].prompt = history + + return next_tensordict + + def _process_tensordict(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Main processing logic for tool transforms. + + Handles batch flattening, history extraction, tool processing, and result injection. + + Args: + next_tensordict: The tensordict to process. + + Returns: + TensorDictBase: Updated tensordict with tool results. + """ + # Flatten batch dimensions if needed + if next_tensordict.batch_dims > 1: + with next_tensordict.view(-1) as next_tensordict_flat: + next_tensordict_flat = self._process_tensordict(next_tensordict_flat) + return next_tensordict + + # Extract and validate history + history, local_history = self._validate_and_extract_history(next_tensordict) + + # Handle content as string or list + content = local_history.content + if isinstance(content, str): + content = [content] + + # Process each batch item + results = [] + for i, text in enumerate(content): + batch_results = self._process_batch_item(text, i) + results.append(batch_results) + + # Inject results back into history + return self._inject_results_to_history(history, results, next_tensordict) + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + """Handle step with tool processing. + + Args: + tensordict: Input tensordict. + next_tensordict: Output tensordict. + + Returns: + TensorDictBase: Updated next_tensordict. + """ + if not self.use_step: + raise RuntimeError( + f"{self.__class__.__name__} uses _call(), not _step(). Set use_step=False." + ) + return self._process_tensordict(next_tensordict) + + def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Handle call with tool processing. + + Args: + next_tensordict: The tensordict to process. + + Returns: + TensorDictBase: Updated tensordict. + """ + if self.use_step: + raise RuntimeError( + f"{self.__class__.__name__} uses _step(), not _call(). Set use_step=True." + ) + return self._process_tensordict(next_tensordict) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + """Handle reset (no-op for base class). + + Args: + tensordict (TensorDictBase): Input tensordict. + tensordict_reset (TensorDictBase): Reset tensordict. + + Returns: + TensorDictBase: Unchanged reset tensordict. 
+ """ + return tensordict_reset # --- Tool Service Library: Pluggable Services & Parsers --- @@ -135,7 +372,7 @@ class ToolCall: tag: str | None = None -class ParseResult(dict): +class ParseResult(TypedDict): """Result of parsing an LLM response for tool calls. This is a TypedDict-style class that contains: @@ -144,6 +381,10 @@ class ParseResult(dict): meta (dict[str, Any]): Optional parser metadata. """ + text: str + calls: list[ToolCall] + meta: dict[str, Any] + class LLMToolParser(Protocol): """Protocol for parsing LLM responses into ordered tool calls. @@ -165,7 +406,7 @@ def __call__(self, response: str | dict[str, Any]) -> ParseResult: class XMLBlockParser: - """Parser for XML-style tool blocks in LLM responses. + r"""Parser for XML-style tool blocks in LLM responses. Parses tool calls in the format: {"arg": "value"} @@ -274,7 +515,7 @@ def __call__(self, response: str | dict[str, Any]) -> ParseResult: return result -class ExecuteToolsInOrder(Transform): +class ExecuteToolsInOrder(ToolTransformBase): """A Transform that executes tools in the order they appear in LLM output. This transform reads the LLM response, parses ordered tool blocks using a @@ -287,16 +528,6 @@ class ExecuteToolsInOrder(Transform): Args: registry (ToolRegistry): Registry containing available tool services. parser (LLMToolParser): Parser for extracting tool calls from LLM output. - in_keys (tuple[str, ...], optional): Key where LLM response is read. - Defaults to ``("history", "prompt")``. - out_keys (tuple[str, ...], optional): Key where tool results are written. - Defaults to ``("tools", "results")``. - message_key (tuple[str, ...], optional): Key for cleaned message text. - Defaults to ``("llm", "message")``. - history_key (tuple[str, ...], optional): Key for conversation history. - Defaults to ``("history", "prompt")``. - write_calls_key (tuple[str, ...], optional): Key for storing parsed calls. - Defaults to ``("tools", "calls")``. stop_on_error (bool, optional): Whether to stop execution on first error. Defaults to ``False``. pass_state_to_tools (bool, optional): Whether to pass TD state to tools. @@ -332,135 +563,71 @@ class ExecuteToolsInOrder(Transform): not by the order of transforms in the Compose stack. """ + use_step = True # Use _step() method + def __init__( self, registry: ToolRegistry, parser: LLMToolParser, - in_keys: tuple[str, ...] | None = None, - out_keys: tuple[str, ...] | None = None, - message_key: tuple[str, ...] | None = None, - history_key: tuple[str, ...] | None = None, - write_calls_key: tuple[str, ...] | None = None, stop_on_error: bool = False, pass_state_to_tools: bool = True, ): - # Set defaults - if in_keys is None: - in_keys = ("history", "prompt") - if out_keys is None: - out_keys = ("tools", "results") - if message_key is None: - message_key = ("llm", "message") - if history_key is None: - history_key = ("history", "prompt") - if write_calls_key is None: - write_calls_key = ("tools", "calls") - - super().__init__(in_keys=[in_keys], out_keys=[out_keys]) + super().__init__() self.registry = registry self.parser = parser - self._in = in_keys - self._out = out_keys - self._msg = message_key - self._hist = history_key - self._calls = write_calls_key self.stop_on_error = stop_on_error self.pass_state_to_tools = pass_state_to_tools + self.tool_role = "tool" - def _step( - self, tensordict: TensorDictBase, next_tensordict: TensorDictBase - ) -> TensorDictBase: - """Execute tools during environment step. 
+ def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one batch item to extract and execute tools. + + This is the main method required by ToolTransformBase. Args: - tensordict (TensorDictBase): Input tensordict before step. - next_tensordict (TensorDictBase): Output tensordict after step. + content: The text content from the LLM response. + index: The index of this item in the batch. Returns: - TensorDictBase: Modified next_tensordict with tool results. + list[str] or None: List of result strings for each tool executed, + or None if no tools were found. """ - if next_tensordict.batch_dims > 1: - with next_tensordict.view(-1) as next_tensordict_flat: - next_tensordict_flat = self._step(tensordict, next_tensordict_flat) - return next_tensordict + # Parse the response for tool calls + parse: ParseResult = self.parser(content) + ordered_calls = parse["calls"] - # Check that we're in history mode - parent = self.parent - if parent is None: - raise RuntimeError("ExecuteToolsInOrder must be used with a ChatEnv") - base_env = parent.base_env - if base_env.input_mode != "history": - raise RuntimeError( - "ExecuteToolsInOrder must be used with a ChatEnv in history mode" - ) - - # Get the history and extract the last message (LLM response) - history = next_tensordict["history"].prompt - local_history = history[..., -1] - - procs = [] - # Iterate over batch - for i, response_text in enumerate(local_history.content): - # Parse the response for tool calls - parse: ParseResult = self.parser(response_text) - ordered_calls = parse["calls"] - tool_outputs: list[dict[str, Any]] = [] - - # Execute tools IN ORDER OF APPEARANCE - for j, call in enumerate(ordered_calls): - try: - service = self.registry.get(call.tool) - kwargs = dict(call.args) - if self.pass_state_to_tools: - kwargs["_state"] = self._export_state_for_tool(next_tensordict) - - out = service(**kwargs) - out["_tool"] = call.tool - out["_index"] = j - if call.tag: - out["_tag"] = call.tag - tool_outputs.append(out) - except Exception as e: - err = {"_tool": call.tool, "_index": j, "error": str(e)} - tool_outputs.append(err) - if self.stop_on_error: - break - - # Store results and calls in tensordict - if tool_outputs: - # Format tool results as history entries - results_text = self._format_tool_results(tool_outputs) - if results_text: - procs.append([History(role="tool", content=results_text)]) - else: - procs.append(None) - else: - procs.append(None) + if not ordered_calls: + return None - # Add tool results to history if any tools were executed - if not all(p is None for p in procs): - if any(p is None for p in procs): - procs = [p if p is not None else [] for p in procs] + tool_outputs: list[dict[str, Any]] = [] - # Ensure all batch elements have same length - if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): - - def fill_procs(proc: list[History], max_len: int) -> list[History]: - if len(proc) == max_len: - return proc - return proc + [History(role="tool", content="")] * ( - max_len - len(proc) - ) - - max_len = max(len(p) for p in procs) - procs = [fill_procs(p, max_len) for p in procs] - - # Stack and extend history - procs = lazy_stack([lazy_stack(p) for p in procs]) - history.extend(procs, dim=-1) - next_tensordict["history"].prompt = history + # Execute tools IN ORDER OF APPEARANCE + for j, call in enumerate(ordered_calls): + try: + service = self.registry.get(call.tool) + kwargs = dict(call.args) + if self.pass_state_to_tools: + # Get tensordict from parent context if available + 
# For now, pass empty state - can be enhanced later + kwargs["_state"] = {} + + out = service(**kwargs) + out["_tool"] = call.tool + out["_index"] = j + if call.tag: + out["_tag"] = call.tag + tool_outputs.append(out) + except Exception as e: + err = {"_tool": call.tool, "_index": j, "error": str(e)} + tool_outputs.append(err) + if self.stop_on_error: + break - return next_tensordict + # Format tool results as a single string + # Format tool results as a single string + if tool_outputs: + results_text = self._format_tool_results(tool_outputs) + return [results_text] if results_text else None def _format_tool_results(self, tool_outputs: list[dict[str, Any]]) -> str: """Format tool execution results as text. @@ -494,41 +661,6 @@ def _format_tool_results(self, tool_outputs: list[dict[str, Any]]) -> str: lines.append("") return "\n".join(lines) - def _export_state_for_tool(self, td: TensorDictBase) -> dict[str, Any]: - """Export a filtered, read-only view of TD state for tools. - - Args: - td (TensorDictBase): The tensordict to export from. - - Returns: - dict[str, Any]: Filtered state dictionary. - """ - # Minimal, safe view; customize as needed - keys_for_tools = [("history", "prompt"), ("env", "step"), ("episode", "id")] - out = {} - for k in keys_for_tools: - if td.get(k, None) is not None: - value = td.get(k) - # Convert to Python types if needed - if isinstance(value, torch.Tensor): - value = value.tolist() if value.numel() > 1 else value.item() - out["/".join(k if isinstance(k, tuple) else (k,))] = value - return out - - def _reset( - self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase - ) -> TensorDictBase: - """Handle reset (no-op for this transform). - - Args: - tensordict (TensorDictBase): Input tensordict. - tensordict_reset (TensorDictBase): Reset tensordict. - - Returns: - TensorDictBase: Unchanged reset tensordict. - """ - return tensordict_reset - class PersistentPythonProcess: """A persistent Python process that can execute code blocks.""" @@ -539,6 +671,7 @@ def __init__(self, timeout: float = 10.0): self._error_queue = queue.Queue() self._accumulated_errors = [] self._init_script = None + self.process = None # Initialize to None to avoid AttributeError in __del__ # Start the process self._start_process() @@ -618,7 +751,7 @@ def run_code(code_str): # Start the process try: self.process = subprocess.Popen( - ["python", "-u", self._init_script], # -u for unbuffered output + [sys.executable, "-u", self._init_script], # -u for unbuffered output stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -688,7 +821,7 @@ def _read_output(self, pipe: TextIO, q: queue.Queue, pipe_name: str) -> None: except Exception: pass - def execute(self, prompt: str) -> dict[str, any]: + def execute(self, prompt: str) -> dict[str, Any]: """Execute code in the persistent process.""" if not self.process or self.process.poll() is not None: # Get any accumulated errors @@ -839,9 +972,12 @@ def __del__(self): self.cleanup() -class PythonInterpreter(Transform): +class PythonInterpreter(ToolTransformBase): r"""A transform that executes Python code in the LLM response. + This transform inherits from :class:`ToolTransformBase` and handles all the + boilerplate for history extraction, batch processing, and result injection. + Args: tokenizer: The tokenizer to use. Defaults to `None` (no tokenizer). tool_name: The name of the tool in the chat history. Defaults to `"tool"`. 
@@ -885,6 +1021,8 @@ class PythonInterpreter(Transform): '<|im_start|>assistant\n'] """ + use_step = True # Use _step() method + def __init__( self, tokenizer=None, # type: ignore @@ -894,8 +1032,9 @@ def __init__( ): super().__init__() self.tokenizer = tokenizer - self.tool_name = tool_name + self.tool_role = tool_name # Set the role for history entries self.persistent = persistent + self.timeout = timeout # Initialize as empty list if persistent, None otherwise self.processes: list[PersistentPythonProcess | None] = [] if persistent else [] @@ -903,15 +1042,17 @@ def close(self): """Close the transform.""" if self.processes: for process in self.processes: - process.cleanup() + if process: + process.cleanup() self.processes = [] def clone(self): """Clone the transform.""" return self.__class__( tokenizer=self.tokenizer, - tool_name=self.tool_name, + tool_name=self.tool_role, # tool_role is the instance attribute persistent=self.persistent, + timeout=self.timeout, ) def _ensure_processes(self, batch_size: int): @@ -921,11 +1062,11 @@ def _ensure_processes(self, batch_size: int): # Create new processes if needed while len(self.processes) < batch_size: - self.processes.append(PersistentPythonProcess()) + self.processes.append(PersistentPythonProcess(timeout=self.timeout)) if any(p is None for p in self.processes): self.processes = [ - p if p is not None else PersistentPythonProcess() + p if p is not None else PersistentPythonProcess(timeout=self.timeout) for p in self.processes ] @@ -942,7 +1083,15 @@ def _execute_python_code(self, code: str, i: int) -> dict: if i >= len(self.processes): self._ensure_processes(i + 1) # Use persistent process - return self.processes[i].execute(code) + process = self.processes[i] + if process is None: + return { + "success": False, + "stdout": "", + "stderr": "Process not initialized", + "returncode": -1, + } + return process.execute(code) else: # Use temporary file approach try: @@ -953,10 +1102,10 @@ def _execute_python_code(self, code: str, i: int) -> dict: temp_file = f.name result = subprocess.run( - ["python", temp_file], + [sys.executable, temp_file], capture_output=True, text=True, - timeout=10, + timeout=self.timeout, ) os.unlink(temp_file) @@ -990,87 +1139,55 @@ def _extract_python_code(self, text: str) -> list[str]: matches = re.findall(pattern, text, re.DOTALL) return matches - def _process_llm_response(self, response: str, i: int) -> list[str]: - """Process LLM response and execute any Python code found. + def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one batch item to extract and execute Python code. + + This is the main method required by ToolTransformBase. Args: - response (str): The response from the LLM. - i (int): The index of the response in the batch. + content: The text content from the LLM response. + index: The index of this item in the batch. Returns: - list[str]: A list of strings, each containing the result of the execution of the code block. + list[str] or None: List of result strings for each code block executed, + or None if no code blocks were found. 
""" - code_blocks = self._extract_python_code(response) + # Ensure we have enough processes for persistent mode + if self.persistent: + if index >= len(self.processes): + self._ensure_processes(index + 1) + # Extract code blocks + code_blocks = self._extract_python_code(content) + if not code_blocks: + return None + + # Execute each code block results = [] - for i, code in enumerate(code_blocks): - result = self._execute_python_code(code, i) + for block_idx, code in enumerate(code_blocks): + result = self._execute_python_code(code, index) if result["success"]: results.append( - f"Code block {i + 1} executed successfully:\n{result['stdout']}" + f"Code block {block_idx + 1} executed successfully:\n{result['stdout']}" ) else: - results.append(f"Code block {i + 1} failed:\n{result['stderr']}") + results.append( + f"Code block {block_idx + 1} failed:\n{result['stderr']}" + ) - return results + return results if results else None def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - if next_tensordict.batch_dims > 1: - with next_tensordict.view(-1) as next_tensordict_flat, tensordict.view( - -1 - ) as tensordict_flat: - # Call the transform on the flattened tensordict - next_tensordict_flat = self._step(tensordict_flat, next_tensordict_flat) - return next_tensordict - - # Ensure we have enough processes for the batch - if self.persistent: + """Override to handle batch size management for persistent processes.""" + # Ensure we have enough processes for the entire batch + if self.persistent and next_tensordict.batch_dims == 1: self._ensure_processes(len(next_tensordict)) - # Convert text to a history - history = next_tensordict["history"].prompt - # Isolate last element, which should be our action - local_history = history[..., -1] - - procs = [] - # Iterate over env batch-size - content = local_history.content - if isinstance(content, str): - content = [content] - for i, t in enumerate(content): - results = self._process_llm_response(t, i) - if len(results) == 0: - procs.append(None) - continue - procs.append( - [History(role=self.tool_name, content=result) for result in results] - ) - - # If there is no tool response, just skip entire batch - if all(p is None for p in procs): - return next_tensordict - if any(p is None for p in procs): - procs = [p if p is not None else [] for p in procs] - # We need to have the same number of items for eache element in the batch - if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): - - def fill_procs(proc: list[History], max_len: int) -> list[History]: - if len(proc) == max_len: - return proc - return proc + [History(role="", content="")] * ( - max_len - len(proc) - ) - - max_len = max(len(p) for p in procs) - procs = [fill_procs(p, max_len) for p in procs] - # Procs has the shape of the batch-size. 
We can cat along dim=-1 - procs = lazy_stack([lazy_stack(p) for p in procs]) - history.extend(procs, dim=-1) - next_tensordict["history"].prompt = history - return next_tensordict + # Delegate to base class for all the heavy lifting + return super()._step(tensordict, next_tensordict) def __del__(self): """Cleanup persistent processes on deletion.""" @@ -1095,124 +1212,100 @@ def _reset( if reset[i] and process is not None: process.cleanup() self.processes = [ - process if not reset[i] else PersistentPythonProcess() + process + if not reset[i] + else PersistentPythonProcess(timeout=self.timeout) for i, process in enumerate(self.processes) ] return tensordict_reset -class MCPToolTransform(Transform): - r"""A transform that executes MCP-style tools in response to LLM actions. +class SimpleToolTransform(ToolTransformBase): + r"""A simple transform that executes tools from a dictionary of callables. - This transform allows execution of tools following the Mission Control Protocol pattern, - where tools are defined with clear input/output schemas and executed in a controlled manner. + This is a lightweight alternative to MCPToolTransform for simple use cases + where you don't need the full Model Context Protocol infrastructure. Args: - tools (dict[str, callable]): A dictionary mapping tool names to their implementation functions. - Each function should accept kwargs matching its schema and return a dict with results. - tool_schemas (dict[str, dict]): A dictionary mapping tool names to their JSON schemas. - Each schema should define the tool's parameters and return type. - tokenizer: The tokenizer to use. Defaults to `None` (no tokenizer). - tool_name: The name of the tool in the chat history. Defaults to `"tool"`. - timeout: The timeout for tool execution in seconds. Defaults to `10.0`. + tools (dict[str, Callable]): Dictionary mapping tool names to their implementation functions. + Each function should accept kwargs matching its expected parameters. + tool_schemas (dict[str, dict], optional): Dictionary mapping tool names to their schemas. + Used for documentation purposes only. + parser (LLMToolParser | None, optional): Parser for extracting tool calls. If None, + uses a simple XML-style parser. Defaults to None. + tool_call_pattern (str | None, optional): Regex pattern for extracting tool calls. + Only used if parser is None. Format should capture (tool_name, args_json). + Defaults to ``r"(.*?)\\n(.*?)"``. + tool_name (str, optional): Role name for tool results in history. Defaults to "tool". + timeout (float, optional): Timeout for tool execution in seconds. Defaults to 10.0. Examples: - >>> from torchrl.envs.llm.transforms import MCPToolTransform - >>> from transformers import AutoTokenizer - >>> from tensordict import TensorDict, set_list_to_stack + >>> from torchrl.envs.llm.transforms import SimpleToolTransform, XMLBlockParser >>> from torchrl.envs.llm import ChatEnv + >>> from tensordict import TensorDict, set_list_to_stack >>> set_list_to_stack(True).set() - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B") + >>> >>> # Define a simple tool - >>> def add_numbers(a: int, b: int) -> dict: - ... return {"result": a + b} - >>> # Define its schema - >>> add_schema = { - ... "name": "add_numbers", - ... "description": "Add two numbers", - ... "parameters": { - ... "type": "object", - ... "properties": { - ... "a": {"type": "integer"}, - ... "b": {"type": "integer"} - ... }, - ... "required": ["a", "b"] - ... } - ... 
} - >>> tools = {"add_numbers": add_numbers} - >>> schemas = {"add_numbers": add_schema} - >>> env = ChatEnv( - ... batch_size=(1,), - ... system_prompt="I'm the system, do as I say", - ... apply_template=True, - ... tokenizer=tokenizer, + >>> def calculator(operation: str, a: float, b: float): + ... if operation == "add": + ... return {"result": a + b} + ... return {"error": "unknown operation"} + >>> + >>> tools = {"calculator": calculator} + >>> env = ChatEnv(batch_size=(1,)) + >>> + >>> # Option 1: Use default XML-style pattern + >>> env = env.append_transform(SimpleToolTransform(tools=tools)) + >>> + >>> # Option 2: Use XMLBlockParser for more features + >>> parser = XMLBlockParser() + >>> env = env.append_transform(SimpleToolTransform(tools=tools, parser=parser)) + >>> + >>> # Option 3: Custom pattern + >>> env = env.append_transform( + ... SimpleToolTransform( + ... tools=tools, + ... tool_call_pattern=r"CALL\[(.*?)\]\((.*?)\)" + ... ) ... ) - >>> env = env.append_transform(MCPToolTransform(tools, schemas)) - >>> r = env.reset(TensorDict(text=["This is the user prompt"], batch_size=(1,))) - >>> r["text_response"] = ["Let me add two numbers:\nadd_numbers\n{\"a\": 1, \"b\": 2}"] - >>> s = env.step(r) - >>> print(s['next', 'history'].apply_chat_template(tokenizer=tokenizer)) - ['<|im_start|>system\n' - "I'm the system, do as I say<|im_end|>\n" - '<|im_start|>user\n' - 'This is the user prompt<|im_end|>\n' - '<|im_start|>assistant\n' - 'Let me add two numbers:\n' - 'add_numbers\n' - '{"a": 1, "b": 2}<|im_end|>\n' - '<|im_start|>user\n' - '\n' - 'Tool add_numbers executed successfully:\n' - '{"result": 3}\n' - '<|im_end|>\n' - '<|im_start|>assistant\n'] """ + use_step = True + def __init__( self, - tools: dict[str, callable], - tool_schemas: dict[str, dict], - tokenizer=None, # type: ignore + tools: dict[str, Callable], + tool_schemas: dict[str, dict] | None = None, + parser: LLMToolParser | None = None, + tool_call_pattern: str | None = None, tool_name: str = "tool", timeout: float = 10.0, ): super().__init__() self.tools = tools - self.tool_schemas = tool_schemas - self.tokenizer = tokenizer - self.tool_name = tool_name + self.tool_schemas = tool_schemas or {} + self.parser = parser + self.tool_call_pattern = tool_call_pattern or r"(.*?)\n(.*?)" + self.tool_role = tool_name self.timeout = timeout - def _extract_tool_calls( - self, text: str - ) -> list[tuple[str, str]]: # noqa: D415, D301, D209, D205 - """Extract tool calls from text in the format tool_name\nargs_json.""" - import re + def _extract_tool_calls(self, text: str) -> list[tuple[str, str]]: + """Extract tool calls from text. - pattern = r"(.*?)\n(.*?)" - matches = re.findall(pattern, text, re.DOTALL) - return matches + Uses parser if provided, otherwise falls back to regex pattern. 
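+
+        Returns:
+            list[tuple[str, str]]: ``(tool_name, args_json)`` pairs in the
+            order they appear in the text, whichever extraction path was used.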
+ """ + if self.parser is not None: + # Use the parser (e.g., XMLBlockParser) + result: ParseResult = self.parser(text) + calls = result.get("calls", []) + return [(call.tool, json.dumps(call.args)) for call in calls] + else: + # Use regex pattern + matches = re.findall(self.tool_call_pattern, text, re.DOTALL) + return matches def _execute_tool(self, tool_name: str, args_json: str) -> dict: """Execute a tool with the given arguments.""" - import json - import signal - from contextlib import contextmanager - - @contextmanager - def timeout_context(seconds): - def signal_handler(signum, frame): - raise TimeoutError("Tool execution timed out") - - # Set the signal handler and a timeout - signal.signal(signal.SIGALRM, signal_handler) - signal.alarm(int(seconds)) - try: - yield - finally: - # Disable the alarm - signal.alarm(0) - try: if tool_name not in self.tools: return { @@ -1222,25 +1315,18 @@ def signal_handler(signum, frame): # Parse arguments try: - args = json.loads(args_json) + args = json.loads(args_json) if args_json.strip() else {} except json.JSONDecodeError as e: return { "success": False, "error": f"Failed to parse tool arguments: {str(e)}", } - # Execute with timeout - with timeout_context(self.timeout): - result = self.tools[tool_name](**args) - return { - "success": True, - "result": result, - } - - except TimeoutError: + # Execute tool + result = self.tools[tool_name](**args) return { - "success": False, - "error": f"Tool execution timed out after {self.timeout} seconds", + "success": True, + "result": result, } except Exception as e: return { @@ -1248,16 +1334,11 @@ def signal_handler(signum, frame): "error": f"Tool execution failed: {str(e)}", } - def _process_llm_response(self, response: str) -> list[str]: - """Process LLM response and execute any tool calls found. - - Args: - response (str): The response from the LLM. - - Returns: - list[str]: A list of strings, each containing the result of a tool execution. - """ - tool_calls = self._extract_tool_calls(response) + def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one batch item to extract and execute simple tools.""" + tool_calls = self._extract_tool_calls(content) + if not tool_calls: + return None results = [] for tool_name, args_json in tool_calls: @@ -1270,65 +1351,441 @@ def _process_llm_response(self, response: str) -> list[str]: else: results.append(f"Tool {tool_name} failed:\n{result['error']}") - return results + return results if results else None - def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: - if next_tensordict.batch_dims > 1: - with next_tensordict.view(-1) as next_tensordict_flat: - # Call the transform on the flattened tensordict - next_tensordict_flat = self._call(next_tensordict_flat) - return next_tensordict - # Check that base_env is on history mode - parent = self.parent - if parent is None: - raise RuntimeError("MCPToolTransform must be used with a ChatEnv") - base_env = parent.base_env - if base_env.input_mode != "history": - raise RuntimeError( - "MCPToolTransform must be used with a ChatEnv in history mode" - ) +class MCPToolTransform(ToolTransformBase): + r"""A transform that executes tools via the Model Context Protocol (MCP). - # Convert text to a history - history = next_tensordict["history"].prompt - # Isolate last element, which should be our action - local_history = history[..., -1] + This transform connects to MCP servers and executes tools through the official + MCP library. 
It runs async operations in a background thread to work with + TorchRL's synchronous transform API. - procs = [] - # Iterate over env batch-size - for t in local_history.content: - results = self._process_llm_response(t) - if len(results) == 0: - procs.append(None) - continue - procs.append( - [History(role=self.tool_name, content=result) for result in results] + Args: + servers (dict[str, dict]): Dictionary mapping server names to their configurations. + Each config should have: + - "command" (str): Command to launch the server (e.g., "npx", "uvx") + - "args" (list[str]): Arguments for the command + Example: {"browser": {"command": "npx", "args": ["@browsermcp/mcp@latest"]}} + tool_call_pattern (str, optional): Regex pattern for extracting tool calls. + Should capture (tool_name_with_server, args_json). + Defaults to ``r"([\\w.]+)\\n(.*?)"``. + tool_name (str, optional): Role name for tool results in history. Defaults to "tool". + timeout (float, optional): Timeout for tool execution in seconds. Defaults to 10.0. + + Examples: + >>> import os + >>> import json + >>> from torchrl.envs.llm import ChatEnv + >>> from torchrl.envs.llm.transforms import MCPToolTransform + >>> from torchrl.data.llm import History + >>> from tensordict import TensorDict, set_list_to_stack + >>> set_list_to_stack(True).set() + >>> + >>> # Add Deno to PATH (required for mcp-run-python) + >>> environ = os.environ.copy() + >>> deno_path = os.path.expanduser("~/.deno/bin") + >>> if deno_path not in os.environ.get('PATH', ''): + ... environ['PATH'] = f"{deno_path}:{os.environ['PATH']}" + >>> + >>> # Define MCP servers + >>> servers = { + ... "browser": { + ... "command": "npx", + ... "args": ["@browsermcp/mcp@latest"] + ... }, + ... "python": { + ... "command": "uvx", + ... "args": ["mcp-run-python", "stdio"], + ... "env": environ + ... } + ... } + >>> + >>> # Create environment with MCP transform + >>> env = ChatEnv(batch_size=(1,)) + >>> env = env.append_transform(MCPToolTransform(servers=servers)) # doctest: +SKIP + [torchrl][INFO] Connecting to MCP server 'browser' (npx @browsermcp/mcp@latest) + [torchrl][INFO] Connected to MCP server 'browser' with 12 tools + [torchrl][INFO] Connecting to MCP server 'python' (uvx mcp-run-python stdio) + [torchrl][INFO] Connected to MCP server 'python' with 1 tools + >>> + >>> # Execute Python code via MCP + >>> reset_data = TensorDict(query="You are a useful assistant", batch_size=(1,)) + >>> td = env.reset(reset_data) + >>> history = td.get("history") + >>> code = ''' + ... import math + ... result = math.sqrt(144) + math.pi + ... print(f"Result: {result}") + ... result + ... ''' + >>> response = History( + ... role="assistant", + ... content=f'Let me calculate that.\npython.run_python_code\n{json.dumps({"python_code": code})}', + ... ).unsqueeze(0).unsqueeze(0) + >>> history.full = history.prompt.extend(response, inplace=True, dim=-1) + >>> history.response = response + >>> result = env.step(td.set("history", history)) # doctest: +SKIP + >>> print(result["next", "history", "prompt"][..., -1].content) # doctest: +SKIP + LinkedList(LinkedList(["Tool python.run_python_code executed successfully:\n[TextContent(type='text', text='success\\n\\nResult: 15.141592653589793\\n\\n\\n15.141592653589793\\n', annotations=None, meta=None)]"])) + + .. note:: + This requires the `mcp` package to be installed: `pip install mcp` + The transform manages async MCP connections in a background thread. + + .. 
note:: + Some MCP servers have additional requirements: + - `mcp-run-python` requires Deno: `curl -fsSL https://deno.land/install.sh | sh` + - Server-specific dependencies should be installed before use + """ + + use_step = True # Use _step() method + + def __init__( + self, + servers: dict[str, dict], + tool_call_pattern: str | None = None, + tool_name: str = "tool", + timeout: float = 10.0, + ): + super().__init__() + self.server_configs = servers + self.tool_call_pattern = tool_call_pattern or r"([\w.]+)\n(.*?)" + self.tool_role = tool_name + self.timeout = timeout + + # MCP session management + self._loop = None + self._thread = None + self._sessions = {} + self._tools_cache = {} + self._shutdown_event = threading.Event() + self._ready_event = threading.Event() + self._connection_error = None + + # Start the async event loop in a background thread + self._start_mcp_thread() + + def _start_mcp_thread(self): + """Start a background thread running an async event loop for MCP, since it's made of coroutines.""" + + def run_loop(): + try: + import asyncio + except ImportError: + self._connection_error = "asyncio not available for MCPToolTransform" + torchrl_logger.error(self._connection_error) + self._ready_event.set() + return + + try: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + # Connect to all MCP servers + self._loop.run_until_complete(self._connect_servers()) + + # Signal that initialization is complete + self._ready_event.set() + + # Keep loop running until shutdown + while not self._shutdown_event.is_set(): + self._loop.run_until_complete(asyncio.sleep(0.1)) + + # Cleanup + self._loop.run_until_complete(self._disconnect_servers()) + self._loop.close() + except Exception as e: + self._connection_error = f"MCP thread failed: {str(e)}" + torchrl_logger.error(self._connection_error) + self._ready_event.set() + + self._thread = threading.Thread(target=run_loop, daemon=True) + self._thread.start() + + # Wait for initialization to complete (with timeout) + if not self._ready_event.wait(timeout=10.0): + torchrl_logger.warning("MCP initialization timed out after 10 seconds") + + if self._connection_error: + torchrl_logger.warning( + f"MCP initialization had errors: {self._connection_error}" ) - # If there is no tool response, just skip entire batch - if all(p is None for p in procs): - return next_tensordict - if any(p is None for p in procs): - procs = [p if p is not None else [] for p in procs] - # We need to have the same number of items for each element in the batch - if len(procs) > 1 and not all(len(p) == len(procs[0]) for p in procs): + async def _connect_servers(self): + """Connect to all configured MCP servers.""" + try: + from mcp import ClientSession, StdioServerParameters + from mcp.client.stdio import stdio_client + except ImportError as e: + torchrl_logger.error( + f"MCP library not installed. Install with: pip install mcp\nError: {e}" + ) + return - def fill_procs(proc: list[History], max_len: int) -> list[History]: - if len(proc) == max_len: - return proc - return proc + [History(role="", content="")] * ( - max_len - len(proc) + for server_name, config in self.server_configs.items(): + try: + # Create stdio transport + server_params = StdioServerParameters( + command=config["command"], + args=config.get("args", []), + env=config.get("env", None), ) - max_len = max(len(p) for p in procs) - procs = [fill_procs(p, max_len) for p in procs] - # Procs has the shape of the batch-size. 
We can cat along dim=-1 - procs = lazy_stack([lazy_stack(p) for p in procs]) - history.extend(procs, dim=-1) - next_tensordict["history"].prompt = history - return next_tensordict + torchrl_logger.info( + f"Connecting to MCP server '{server_name}' ({config['command']} {' '.join(config.get('args', []))})" + ) - def _reset( - self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase - ) -> TensorDictBase: - return tensordict_reset + # Connect and initialize session + stdio = stdio_client(server_params) + try: + read, write = await stdio.__aenter__() + except Exception as e: + error_msg = str(e).lower() + if ( + "deno" in error_msg + or "no such file or directory: 'deno'" in error_msg + ): + torchrl_logger.error( + f"Failed to start stdio for '{server_name}': Deno is not installed.\n" + f" Install Deno: curl -fsSL https://deno.land/install.sh | sh\n" + f" After installing, restart your terminal/shell." + ) + else: + torchrl_logger.error( + f"Failed to start stdio for '{server_name}': {type(e).__name__}: {e}" + ) + raise + + session = ClientSession(read, write) + try: + await session.__aenter__() + except Exception as e: + error_msg = str(e).lower() + if "connection closed" in error_msg: + # Subprocess likely crashed - check for common issues + torchrl_logger.error( + f"Failed to initialize session for '{server_name}': Subprocess terminated.\n" + f" The MCP server '{config['command']}' started but immediately crashed.\n" + f" Common causes:\n" + f" - Missing dependencies (e.g., Deno for mcp-run-python)\n" + f" - Invalid server configuration\n" + f" Try running manually: {config['command']} {' '.join(config.get('args', []))}\n" + f" Error: {e}" + ) + else: + torchrl_logger.error( + f"Failed to initialize session for '{server_name}': {type(e).__name__}: {e}" + ) + # Try to close stdio + try: + await stdio.__aexit__(None, None, None) + except Exception: + pass + raise + + self._sessions[server_name] = { + "session": session, + "stdio": stdio, + } + + # Discover tools + try: + tools_response = await session.list_tools() + tools = {tool.name: tool for tool in tools_response.tools} + self._tools_cache[server_name] = tools + torchrl_logger.info( + f"Connected to MCP server '{server_name}' with {len(tools)} tools" + ) + except Exception as e: + error_msg = str(e).lower() + if "connection closed" in error_msg: + torchrl_logger.error( + f"Could not list tools for server '{server_name}': Connection closed.\n" + f" The MCP server started but crashed immediately.\n" + f" This often means missing dependencies (e.g., Deno for mcp-run-python).\n" + f" Test manually: {config['command']} {' '.join(config.get('args', []))}\n" + f" For mcp-run-python, install Deno: curl -fsSL https://deno.land/install.sh | sh" + ) + else: + torchrl_logger.warning( + f"Could not list tools for server '{server_name}': {e}" + ) + self._tools_cache[server_name] = {} + # Don't keep a session we can't list tools from + try: + await session.__aexit__(None, None, None) + await stdio.__aexit__(None, None, None) + except Exception: + pass + if server_name in self._sessions: + del self._sessions[server_name] + + except FileNotFoundError as e: + # Check if it's a Deno dependency issue + if "deno" in str(e).lower(): + torchrl_logger.error( + f"Failed to connect to MCP server '{server_name}': Deno is not installed.\n" + f" Install Deno: curl -fsSL https://deno.land/install.sh | sh\n" + f" Or use a different MCP server that doesn't require Deno.\n" + f" Error: {e}" + ) + else: + torchrl_logger.error( + f"Failed to connect to MCP server 
'{server_name}': Command not found.\n" + f" Make sure '{config['command']}' is installed and in your PATH.\n" + f" Error: {e}" + ) + except Exception as e: + torchrl_logger.error( + f"Failed to connect to MCP server '{server_name}': {type(e).__name__}: {e}" + ) + + async def _disconnect_servers(self): + """Disconnect from all MCP servers.""" + for server_name, server_data in self._sessions.items(): + try: + session = server_data["session"] + stdio = server_data["stdio"] + await session.__aexit__(None, None, None) + await stdio.__aexit__(None, None, None) + except Exception as e: + torchrl_logger.warning(f"Error disconnecting from '{server_name}': {e}") + + self._sessions.clear() + self._tools_cache.clear() + + def _extract_tool_calls(self, text: str) -> list[tuple[str, str, str]]: + r"""Extract tool calls from text in format server.tool_name\nargs_json.""" + matches = re.findall(self.tool_call_pattern, text, re.DOTALL) + + # Parse into (server_name, tool_name, args_json) + parsed = [] + for full_name, args_json in matches: + if "." in full_name: + server_name, tool_name = full_name.split(".", 1) + else: + # Default to first server if no prefix + server_name = next(iter(self.server_configs.keys()), None) + tool_name = full_name + + if server_name: + parsed.append((server_name, tool_name, args_json)) + + return parsed + + def _execute_tool_sync( + self, server_name: str, tool_name: str, args_json: str + ) -> dict: + """Execute a tool via MCP (blocking call that schedules async work).""" + if not self._loop or not self._thread or not self._thread.is_alive(): + return { + "success": False, + "error": "MCP thread not running", + } + + # Schedule the async call in the background thread + future = asyncio.run_coroutine_threadsafe( + self._execute_tool_async(server_name, tool_name, args_json), self._loop + ) + + try: + result = future.result(timeout=self.timeout) + return result + except TimeoutError: + return { + "success": False, + "error": f"Tool execution timed out after {self.timeout}s", + } + except Exception as e: + return { + "success": False, + "error": f"Tool execution failed: {str(e)}", + } + + async def _execute_tool_async( + self, server_name: str, tool_name: str, args_json: str + ) -> dict: + """Execute a tool via MCP (async implementation).""" + try: + # Check if server exists + if server_name not in self._sessions: + return { + "success": False, + "error": f"MCP server '{server_name}' not connected", + } + + session = self._sessions[server_name]["session"] + + # Parse arguments + try: + args = json.loads(args_json) if args_json.strip() else {} + except json.JSONDecodeError as e: + return { + "success": False, + "error": f"Failed to parse tool arguments: {str(e)}", + } + + # Call the tool via MCP + result = await session.call_tool(tool_name, arguments=args) + + return { + "success": True, + "result": result.content if hasattr(result, "content") else str(result), + } + + except Exception as e: + return { + "success": False, + "error": f"MCP tool call failed: {str(e)}", + } + + def _process_batch_item(self, content: str, index: int) -> list[str] | None: + """Process one batch item to extract and execute MCP tools. + + This is the main method required by ToolTransformBase. + + Args: + content: The text content from the LLM response. + index: The index of this item in the batch. + + Returns: + list[str] or None: List of result strings for each tool executed, + or None if no tools were found. 
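+
+        Each result string is prefixed with either
+        ``"Tool {server}.{tool} executed successfully:"`` or
+        ``"Tool {server}.{tool} failed:"``, followed by the tool output or
+        error message on the next line.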
+ """ + # Extract tool calls + tool_calls = self._extract_tool_calls(content) + if not tool_calls: + return None + + # Execute each tool via MCP + results = [] + for server_name, tool_name, args_json in tool_calls: + result = self._execute_tool_sync(server_name, tool_name, args_json) + + if result["success"]: + results.append( + f"Tool {server_name}.{tool_name} executed successfully:\n{result['result']}" + ) + else: + results.append( + f"Tool {server_name}.{tool_name} failed:\n{result['error']}" + ) + + return results if results else None + + def close(self): + """Shutdown the MCP connections and background thread.""" + if self._thread and self._thread.is_alive(): + self._shutdown_event.set() + self._thread.join(timeout=2.0) + + self._loop = None + self._thread = None + + def __del__(self): + """Ensure cleanup on deletion.""" + try: + self.close() + except Exception: + pass From d747f8e7eddf625482635bd51814623c24bff95f Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 27 Oct 2025 18:34:04 +0000 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- docs/source/reference/index.rst | 1 + docs/source/reference/llms.rst | 5 +- docs/source/reference/services.rst | 609 ++++++++++++++++++++++ examples/services/distributed_services.py | 245 +++++++++ test/test_services.py | 511 ++++++++++++++++++ torchrl/envs/llm/transforms/__init__.py | 2 + torchrl/envs/llm/transforms/tools.py | 185 ++++++- torchrl/services/__init__.py | 77 +++ torchrl/services/base.py | 108 ++++ torchrl/services/ray_service.py | 452 ++++++++++++++++ 10 files changed, 2181 insertions(+), 14 deletions(-) create mode 100644 docs/source/reference/services.rst create mode 100644 examples/services/distributed_services.py create mode 100644 test/test_services.py create mode 100644 torchrl/services/__init__.py create mode 100644 torchrl/services/base.py create mode 100644 torchrl/services/ray_service.py diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 53f4c246628..b1603fb6264 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -10,6 +10,7 @@ API Reference llms modules objectives + services trainers utils config diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst index 3a04bc5c95c..2a52e5f1160 100644 --- a/docs/source/reference/llms.rst +++ b/docs/source/reference/llms.rst @@ -930,7 +930,9 @@ Tools are usually implemented as transforms, and appended to a base environment such as :class:`~torchrl.envs.llm.ChatEnv`. An example of a tool transform is the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform, which is used -to execute Python code in the context of the LLM. +to execute Python code in the context of the LLM. The PythonInterpreter can optionally use a shared +:class:`~torchrl.envs.llm.transforms.PythonExecutorService` for efficient resource usage across multiple environments. +See :ref:`ref_services` for more details on the service registry system. >>> from torchrl.envs.llm.transforms import PythonInterpreter >>> from torchrl.envs.llm import ChatEnv @@ -1141,6 +1143,7 @@ By following these design principles, reward transforms can be effectively integ KLRewardTransform MCPToolTransform PolicyVersion + PythonExecutorService PythonInterpreter RayDataLoadingPrimer RetrieveKL diff --git a/docs/source/reference/services.rst b/docs/source/reference/services.rst new file mode 100644 index 00000000000..72d26fd3c54 --- /dev/null +++ b/docs/source/reference/services.rst @@ -0,0 +1,609 @@ +.. 
currentmodule:: torchrl + +Service Registry +================ + +.. _ref_services: + +TorchRL provides a service registry system for managing distributed services across workers in distributed applications. +This is particularly useful for sharing resources like tokenizers, replay buffers, or Python executor pools across +multiple environments or collectors. + +The service registry provides a **backend-agnostic API** for distributed service management. While the current +implementation focuses on Ray as the primary backend, the design allows for future backends (e.g., Monarch, +local multiprocessing) without changing the core API. + +Overview +-------- + +The service registry provides a centralized way to register and access distributed services that can be shared across +different parts of your application. Services are registered once and can be accessed by any worker, with the underlying +backend handling the distributed communication and resource management. + +**Current Backend Support:** + +- **Ray**: Full support for Ray-based distributed services (recommended for production use) +- **Other backends**: Planned for future releases (e.g., Monarch, local multiprocessing) + +Key Features +~~~~~~~~~~~~ + +- **Centralized Management**: Register services once and access them from anywhere in your distributed system +- **Namespace Isolation**: Services are isolated within namespaces for multi-tenant support +- **Type Safety**: Dict-like access with ``services["name"]`` syntax +- **Automatic Cleanup**: Reset all services in a namespace with a single call +- **Backend Flexibility**: Designed to support multiple distributed backends (currently Ray) + +Basic Usage +----------- + +Getting Started +~~~~~~~~~~~~~~~ + +The service registry API is backend-agnostic, but you need to specify which backend to use when getting the registry. +Currently, Ray is the only supported backend. + +.. code-block:: python + + import ray + from torchrl.services import get_services + + # Initialize your backend (Ray in this example) + ray.init() + + # Get the service registry for your chosen backend + services = get_services(backend="ray", namespace="my_namespace") + + # Register a service (the class will become a distributed service) + services.register( + "tokenizer", + TokenizerService, + vocab_size=50000, + num_cpus=1, # Backend-specific option (Ray) + ) + + # Access the service from any worker + # (other workers just need to call get_services with the same backend and namespace) + services = get_services(backend="ray", namespace="my_namespace") + tokenizer = services["tokenizer"] + + # Call the service (syntax depends on backend) + # For Ray, you need to use .remote() and ray.get() + result = ray.get(tokenizer.encode.remote("Hello world")) + + # Cleanup when done + services.reset() + ray.shutdown() + +Service Registration +~~~~~~~~~~~~~~~~~~~~ + +Services are registered by providing a name, a class (that will become a distributed service), and any initialization arguments. +The exact behavior depends on the backend being used. + +**Basic Registration (Backend-Agnostic):** + +.. 
code-block:: python + + # Register a service with constructor arguments + services.register( + "my_service", + MyServiceClass, + arg1="value1", + arg2="value2", + ) + +The ``register`` method accepts: + +- **name** (str): Unique identifier for the service +- **service_factory** (type): Class to instantiate as a distributed service +- **kwargs**: Arguments passed to the service constructor and/or backend-specific options + +**Backend-Specific Options (Ray):** + +When using the Ray backend, you can pass Ray actor options alongside constructor arguments: + +.. code-block:: python + + # Ray-specific: Mix actor options and constructor arguments + services.register( + "gpu_service", + GPUService, + model_name="gpt2", # Constructor argument + num_cpus=4, # Ray actor option + num_gpus=1, # Ray actor option + max_concurrency=16, # Ray actor option + ) + +For more explicit separation of backend options and constructor arguments, the Ray backend provides +``register_with_options`` (note that options are expected not to collide with constructor arguments): + +.. code-block:: python + + # Ray-specific: Explicit separation of options + services.register_with_options( + "my_service", + MyServiceClass, + actor_options={ + "num_cpus": 4, + "num_gpus": 1, + "max_concurrency": 16, + }, + model_name="gpt2", # Constructor argument + batch_size=32, # Constructor argument + ) + +.. note:: + The ``register_with_options`` method is specific to the Ray backend. Other backends may have + different mechanisms for separating backend options from constructor arguments. + +Service Access +~~~~~~~~~~~~~~ + +Services can be accessed using dict-like syntax: + +.. code-block:: python + + # Check if service exists + if "tokenizer" in services: + tokenizer = services["tokenizer"] + + # Get service (raises KeyError if not found) + tokenizer = services["tokenizer"] + + # Alternative: use get() method + tokenizer = services.get("tokenizer") + + # List all services + service_names = services.list() + print(f"Available services: {service_names}") + +Cross-Worker Visibility +~~~~~~~~~~~~~~~~~~~~~~~ + +Services registered by one worker are immediately visible to all other workers in the same namespace. +This is a core feature of the service registry, enabled by the underlying distributed backend. + +**Example with Ray Backend:** + +.. code-block:: python + + import ray + from torchrl.services import get_services + + @ray.remote + class Worker: + def register_service(self): + # Worker 1: Register a service + services = get_services(backend="ray", namespace="shared") + services.register("shared_tokenizer", TokenizerService, vocab_size=50000) + return "registered" + + def use_service(self): + # Worker 2: Use the service registered by Worker 1 + services = get_services(backend="ray", namespace="shared") + tokenizer = services["shared_tokenizer"] + return ray.get(tokenizer.encode.remote("Hello")) + + # Worker 1 registers the service + worker1 = Worker.remote() + ray.get(worker1.register_service.remote()) + + # Worker 2 can immediately use it + worker2 = Worker.remote() + result = ray.get(worker2.use_service.remote()) + +The key insight is that both workers access the same service registry by using the same ``backend`` and +``namespace`` parameters in ``get_services()``. The backend handles the distributed coordination. + +Namespace Isolation +~~~~~~~~~~~~~~~~~~~ + +Different namespaces provide complete isolation between service registries: + +.. 
code-block:: python + + # Training namespace + train_services = get_services(backend="ray", namespace="training") + train_services.register("tokenizer", TokenizerService, vocab_size=50000) + + # Evaluation namespace + eval_services = get_services(backend="ray", namespace="evaluation") + eval_services.register("tokenizer", TokenizerService, vocab_size=30000) + + # These are completely independent services + assert "tokenizer" in train_services + assert "tokenizer" in eval_services + # But they have different configurations + +Cleanup +~~~~~~~ + +Always clean up services when done to free resources: + +.. code-block:: python + + # Reset all services in a namespace + services.reset() + + # This terminates all service actors and clears the registry + # After reset(), the registry is empty + assert services.list() == [] + +Python Executor Service +----------------------- + +One of the most useful built-in services is the :class:`~torchrl.envs.llm.transforms.PythonExecutorService`, +which provides a shared pool of Python interpreter processes for executing code across multiple environments. +This service is designed to work with any backend, though it's currently optimized for Ray. + +Overview +~~~~~~~~ + +The Python Executor Service allows you to share a fixed pool of Python interpreters (e.g., 32 processes) across +many environments (e.g., 128 environments). This provides significant resource savings compared to giving each +environment its own interpreter process. The service is registered through the service registry and can be +accessed by any worker using the :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform. + +**Resource Efficiency:** + ++---------------------------+---------------+------------+------------------+ +| Configuration | Environments | Processes | Resource Usage | ++===========================+===============+============+==================+ +| Local (persistent) | 128 | 128 | 100% | ++---------------------------+---------------+------------+------------------+ +| Service (pool=32) | 128 | 32 | **25%** | ++---------------------------+---------------+------------+------------------+ +| Service (pool=64) | 128 | 64 | **50%** | ++---------------------------+---------------+------------+------------------+ + +Basic Usage +~~~~~~~~~~~ + +The Python Executor Service is registered like any other service, then accessed through the +:class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform by specifying ``services="ray"`` +(or the appropriate backend name). + +**Example with Ray Backend:** + +.. 
code-block:: python + + import ray + from torchrl.services import get_services + from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter + from torchrl.envs.llm import ChatEnv + + # Initialize your backend + ray.init() + + # Register the Python executor service + services = get_services(backend="ray", namespace="my_namespace") + services.register( + "python_executor", + PythonExecutorService, + pool_size=32, # 32 interpreter processes + timeout=10.0, # 10 second timeout + num_cpus=32, # Ray-specific: Allocate 32 CPUs + max_concurrency=32, # Ray-specific: Allow 32 concurrent executions + ) + + # Create environments that use the service + env = ChatEnv( + batch_size=(128,), # 128 parallel environments + system_prompt="Execute Python code when requested.", + ) + + # Add PythonInterpreter transform configured to use the service + env = env.append_transform( + PythonInterpreter( + services="ray", # Use Ray backend + namespace="my_namespace", # Same namespace as registration + ) + ) + + # All 128 environments now share the 32 interpreters! + # The backend (Ray) automatically queues requests when all interpreters are busy + +Optional Service Usage +~~~~~~~~~~~~~~~~~~~~~~ + +The :class:`~torchrl.envs.llm.transforms.PythonInterpreter` transform supports optional service usage. +You can easily switch between using a shared service or local processes: + +.. code-block:: python + + # Option 1: Use shared Ray service (recommended for many envs) + env = env.append_transform( + PythonInterpreter( + services="ray", + namespace="my_namespace", + ) + ) + + # Option 2: Use local persistent processes (good for few envs) + env = env.append_transform( + PythonInterpreter( + services=None, + persistent=True, + ) + ) + + # Option 3: Use temporary processes (good for infrequent use) + env = env.append_transform( + PythonInterpreter( + services=None, + persistent=False, + ) + ) + +Conditional Usage Pattern +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can decide at runtime whether to use a distributed service based on your configuration: + +.. code-block:: python + + import ray + from torchrl.services import get_services + from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter + + num_envs = 128 + use_distributed_service = ray.is_initialized() and num_envs > 16 + + if use_distributed_service: + # Use distributed service for efficient resource usage + services = get_services(backend="ray") # Could be other backends in the future + if "python_executor" not in services: + services.register( + "python_executor", + PythonExecutorService, + pool_size=32, + timeout=10.0, + num_cpus=32, # Backend-specific option + max_concurrency=32, # Backend-specific option + ) + + # Configure transform to use the service + interpreter = PythonInterpreter(services="ray") + else: + # Use local processes (no distributed backend) + interpreter = PythonInterpreter(services=None, persistent=True) + + env = env.append_transform(interpreter) + +How It Works +~~~~~~~~~~~~ + +The Python Executor Service uses a simple round-robin assignment strategy to distribute work across +a pool of interpreter processes. The backend handles concurrency control and request queuing. + +**Architecture:** + +1. **Pool of Interpreters**: The service maintains a fixed pool of ``PersistentPythonProcess`` instances +2. **Round-Robin Assignment**: Each request is assigned to the next interpreter in the pool +3. **Backend Queuing**: When all interpreters are busy, the backend queues additional requests +4. 
**Concurrent Execution**: The backend controls how many requests can execute simultaneously + +.. code-block:: python + + # Inside PythonExecutorService + def execute(self, code: str) -> dict: + # Simple round-robin assignment + with self._lock: + process = self.processes[self.next_idx] + self.next_idx = (self.next_idx + 1) % self.pool_size + + # Backend handles queuing (e.g., Ray's max_concurrency parameter) + return process.execute(code) + +**Backend-Specific Behavior:** + +- **Ray**: Uses the ``max_concurrency`` parameter to control concurrent executions. Requests beyond + this limit are automatically queued by Ray's actor system. +- **Other backends**: Will have their own mechanisms for concurrency control and queuing. + +Performance Considerations +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**When to Use Service Mode (Distributed):** + +- Running > 16 parallel environments +- Resource efficiency is important +- Code execution is frequent +- Have a distributed backend available (e.g., Ray) + +**When to Use Local Persistent Mode:** + +- Running < 16 environments +- Need strict isolation between environments +- Latency is critical +- Don't want distributed backend dependency + +**When to Use Local Temp File Mode:** + +- Code execution is infrequent +- Don't want persistent processes +- Memory is more important than speed + +Advanced Usage +-------------- + +Multiple Service Configurations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can register multiple services with different configurations: + +.. code-block:: python + + services = get_services(backend="ray") + + # Fast service for simple code + services.register( + "python_executor_fast", + PythonExecutorService, + pool_size=16, + timeout=5.0, + num_cpus=16, + max_concurrency=16, + ) + + # Heavy service for complex code + services.register( + "python_executor_heavy", + PythonExecutorService, + pool_size=64, + timeout=30.0, + num_cpus=64, + max_concurrency=64, + ) + + # Use different services for different environments + fast_env = env.append_transform( + PythonInterpreter(services="ray", service_name="python_executor_fast") + ) + heavy_env = env.append_transform( + PythonInterpreter(services="ray", service_name="python_executor_heavy") + ) + +Custom Services +~~~~~~~~~~~~~~~ + +You can create your own services by defining a class and registering it: + +.. code-block:: python + + class MyCustomService: + """A custom service for your application.""" + + def __init__(self, config: dict): + self.config = config + # Initialize your service + + def process(self, data: str) -> dict: + # Process data and return results + return {"result": f"Processed: {data}"} + + # Register the custom service + services = get_services(backend="ray") + services.register( + "my_service", + MyCustomService, + config={"param1": "value1"}, + num_cpus=2, + ) + + # Use the service + my_service = services["my_service"] + result = ray.get(my_service.process.remote("Hello")) + +API Reference +------------- + +Service Registry +~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.services + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + get_services + reset_services + ServiceBase + RayService + +Python Executor Service +~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torchrl.envs.llm.transforms + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + PythonExecutorService + PythonInterpreter + +Best Practices +-------------- + +1. 
**Specify Backend and Namespace**: Always explicitly specify both the backend and namespace when calling + ``get_services()`` to ensure services are registered and accessed from the correct location. + +2. **Clean Up**: Always call ``services.reset()`` when done to free resources and terminate distributed services. + +3. **Service Naming**: Use descriptive names that indicate the service's purpose (e.g., ``"python_executor"``, + ``"tokenizer_service"``). + +4. **Backend-Specific Options**: Understand which options are backend-specific (e.g., ``num_cpus``, ``num_gpus``, + ``max_concurrency`` for Ray) and which are constructor arguments for your service class. + +5. **Error Handling**: Check if services exist before accessing them: + + .. code-block:: python + + if "my_service" in services: + service = services["my_service"] + else: + # Register or handle missing service + +6. **Conditional Registration**: Only register services if they don't already exist: + + .. code-block:: python + + if "python_executor" not in services: + services.register("python_executor", PythonExecutorService, ...) + +7. **Context Managers**: Consider using context managers for automatic cleanup: + + .. code-block:: python + + class ServiceContext: + def __init__(self, backend, namespace): + self.services = get_services(backend=backend, namespace=namespace) + + def __enter__(self): + return self.services + + def __exit__(self, *args): + self.services.reset() + + with ServiceContext("ray", "my_namespace") as services: + services.register("my_service", MyService) + # Use services... + # Automatic cleanup + +8. **Backend Portability**: When writing code that should work with multiple backends, avoid using + backend-specific methods like ``register_with_options()`` (Ray-only). Stick to the common ``register()`` + API for maximum portability. + +Examples +-------- + +For complete examples, see: + +- ``examples/services/distributed_services.py`` - Basic service registry usage +- ``examples/llm/python_executor_service.py`` - Python executor service examples +- ``test/test_services.py`` - Comprehensive test suite +- ``test/test_python_executor_service.py`` - Python executor service tests + +See Also +-------- + +- :ref:`ref_llms` - LLM API documentation +- :ref:`ref_collectors` - Collector documentation +- `Ray Documentation `_ - Ray distributed framework documentation + +.. note:: + **Future Backend Support** + + The service registry is designed to be backend-agnostic. While Ray is currently the only supported + backend, the API is structured to easily accommodate additional backends in the future, such as: + + - **Monarch**: For specialized distributed computing scenarios + - **Local Multiprocessing**: For single-node parallelism without external dependencies + - **Custom Backends**: You can implement your own backend by subclassing :class:`~torchrl.services.ServiceBase` + + The core API (``get_services()``, ``register()``, ``get()``, ``list()``, ``reset()``) will remain + consistent across all backends, ensuring your code remains portable. diff --git a/examples/services/distributed_services.py b/examples/services/distributed_services.py new file mode 100644 index 00000000000..473d4508c01 --- /dev/null +++ b/examples/services/distributed_services.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
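# --- Illustrative aside (not part of this patch) -----------------------------
# Best Practice 8 above recommends the portable ``register()`` API over the
# Ray-only ``register_with_options()`` that Example 4 below uses. A minimal
# sketch of the same registration in portable form is shown here; the backend
# extracts the actor options it recognizes (e.g. ``num_cpus``, ``memory``,
# ``max_concurrency`` on Ray) from the keyword arguments and forwards the rest
# to the service constructor:
#
#     services = get_services(backend="ray", namespace="example_services")
#     if "tokenizer" not in services:
#         services.register(
#             "tokenizer",
#             TokenizerService,
#             vocab_size=50000,        # constructor argument
#             num_cpus=2,              # Ray actor option
#             max_concurrency=10,      # Ray actor option
#             memory=5 * 1024**3,      # Ray actor option
#         )
# ------------------------------------------------------------------------------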
+ +""" +Example: Distributed Service Registry with Ray + +This example demonstrates how to use TorchRL's service registry to share +actors (tokenizers, replay buffers, etc.) across distributed workers. + +Key features: +- Services registered by one worker are immediately visible to all workers +- Supports Ray's full options API for resource management +- Clean dict-like interface for service access + +Run this example: + python examples/services/distributed_services.py +""" + +import ray +from torchrl.services import get_services + + +# Example 1: Simple service class +class TokenizerService: + """A simple tokenizer service that can be shared across workers.""" + + def __init__(self, vocab_size: int = 10000): + self.vocab_size = vocab_size + print(f"TokenizerService initialized with vocab_size={vocab_size}") + + def encode(self, text: str) -> list[int]: + """Simple character-based encoding.""" + return [hash(c) % self.vocab_size for c in text] + + def decode(self, tokens: list[int]) -> str: + """Simple decoding.""" + return "".join([chr(65 + (t % 26)) for t in tokens]) + + +# Example 2: Stateful service +class CounterService: + """A stateful counter service.""" + + def __init__(self, initial_value: int = 0): + self.count = initial_value + + def increment(self) -> int: + self.count += 1 + return self.count + + def get_count(self) -> int: + return self.count + + def reset(self): + self.count = 0 + + +# Example 3: Resource-intensive service +class ModelService: + """Simulates a model service that needs GPU resources.""" + + def __init__(self, model_name: str): + self.model_name = model_name + print(f"ModelService '{model_name}' initialized") + + def predict(self, data: list) -> list: + """Simulate model inference.""" + return [x * 2 for x in data] # Simple transformation + + +@ray.remote +class Worker: + """Worker class that uses services. + + Separates initialization from execution for better control. 
+ """ + + def __init__(self, worker_id: int, namespace: str): + """Initialize worker and get service registry.""" + self.worker_id = worker_id + self.namespace = namespace + self.services = get_services(backend="ray", namespace=namespace) + print(f"\n[Worker {worker_id}] Initialized") + + def setup_services(self): + """Register shared services (typically done by worker 0).""" + print(f"[Worker {self.worker_id}] Registering services...") + + # Register tokenizer with specific resources + self.services.register( + "tokenizer", + TokenizerService, + vocab_size=5000, + num_cpus=1, + num_gpus=0, + ) + + # Register counter + self.services.register("counter", CounterService, initial_value=0, num_cpus=0.5) + + # Register model with GPU (if available) + self.services.register( + "model", + ModelService, + model_name="my_model", + num_cpus=2, + num_gpus=0, # Set to 1 if GPU available + ) + + print(f"[Worker {self.worker_id}] Services registered: {self.services.list()}") + return "setup_complete" + + def run(self): + """Execute worker's main task using services.""" + print(f"[Worker {self.worker_id}] Starting execution...") + print(f"[Worker {self.worker_id}] Available services: {self.services.list()}") + + # Use tokenizer + if "tokenizer" in self.services: + tokenizer = self.services["tokenizer"] + text = f"Hello from worker {self.worker_id}" + encoded = ray.get(tokenizer.encode.remote(text)) + decoded = ray.get(tokenizer.decode.remote(encoded)) + print(f"[Worker {self.worker_id}] Tokenizer - Encoded/decoded: '{decoded}'") + + # Use counter (demonstrates statefulness) + if "counter" in self.services: + counter = self.services["counter"] + count = ray.get(counter.increment.remote()) + print(f"[Worker {self.worker_id}] Counter incremented to: {count}") + + # Use model + if "model" in self.services: + model = self.services["model"] + result = ray.get(model.predict.remote([1, 2, 3])) + print(f"[Worker {self.worker_id}] Model prediction: {result}") + + print(f"[Worker {self.worker_id}] Finished!") + return f"Worker {self.worker_id} completed" + + +def main(): + """Main function demonstrating service registry usage.""" + print("=== TorchRL Distributed Service Registry Example ===\n") + + # Initialize Ray + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + + namespace = "example_services" + + # Example 1: Basic usage + print("--- Example 1: Basic Usage ---") + services = get_services(backend="ray", namespace=namespace) + + # Register a simple service + services.register("shared_tokenizer", TokenizerService, vocab_size=1000) + + # Access it + tokenizer = services["shared_tokenizer"] + result = ray.get(tokenizer.encode.remote("Hello")) + print(f"Encoded 'Hello': {result}\n") + + # Example 2: Conditional registration pattern + print("--- Example 2: Conditional Registration ---") + assert "shared_tokenizer" in services + try: + services.register("shared_tokenizer", TokenizerService, vocab_size=1000) + raise RuntimeError("Registed twice! 
Should not happen!") + except ValueError: + print("shared_tokenizer already registered") + + # Example 3: Multiple workers using same services + print("--- Example 3: Multiple Workers Sharing Services ---") + + # Create worker actors + num_workers = 3 + workers = [ + Worker.remote(worker_id, namespace) for worker_id in range(num_workers) # type: ignore[attr-defined] + ] + + # Worker 0 sets up services, others wait for completion + # The registered services are: (1) tokenizer, (2) counter, (3) model + print("Worker 0 setting up services...") + setup_complete = ray.get(workers[0].setup_services.remote()) + print(f"Setup complete: {setup_complete}") + + # note that in `main` the services are updated too! + assert "tokenizer" in services + assert "counter" in services + assert "model" in services + + # Now all workers can run in parallel + print("\nAll workers executing...") + run_futures = [worker.run.remote() for worker in workers] + + # Wait for all workers to complete + results = ray.get(run_futures) + print("\nAll workers completed:", results) + + # Example 4: Using register_with_options for clarity + print("\n--- Example 4: Register with Explicit Options ---") + services.reset() + # More explicit separation of Ray options vs constructor args + services.register_with_options( # type: ignore[attr-defined] + "tokenizer", + TokenizerService, + actor_options={ + "num_cpus": 2, + "num_gpus": 0, + "max_concurrency": 10, + "memory": 5 * 1024**3, + }, + vocab_size=50000, # Constructor argument + ) + + tokenizer = services["tokenizer"] + result = ray.get(tokenizer.encode.remote("test")) + print(f"Used tokenizer with explicit options: {result}") + + # Example 5: Listing all services + print("\n--- Example 5: Service Discovery ---") + all_services = services.list() + print(f"All registered services ({len(all_services)}): {all_services}") + + # Example 6: Resetting services + print("\n--- Example 6: Resetting Services ---") + print(f"Services before reset: {services.list()}") + + # Reset will terminate all actors and clear the registry + services.reset() + + print(f"Services after reset: {services.list()}") + assert len(services.list()) == 0, "Registry should be empty after reset" + + print("\n=== Example Complete ===") + + # Cleanup + ray.shutdown() + + +if __name__ == "__main__": + main() diff --git a/test/test_services.py b/test/test_services.py new file mode 100644 index 00000000000..a62ce9a9052 --- /dev/null +++ b/test/test_services.py @@ -0,0 +1,511 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
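# --- Illustrative aside (not part of this patch) -----------------------------
# The ``test_services_fixtures`` module imported below is not included in this
# diff. Inferred from how the tests use it, a minimal sketch of what it is
# assumed to provide (names and signatures are assumptions):


class SimpleService:
    """Assumed fixture: a tiny stateful service holding a single value."""

    def __init__(self, value: int = 0):
        self.value = value

    def get_value(self) -> int:
        return self.value

    def set_value(self, value: int) -> None:
        self.value = value

    def getattr(self, name: str):
        # Generic attribute accessor used by the integration tests
        # (e.g. ``actor.getattr.remote("value")``).
        return getattr(self, name)


class TokenizerService:
    """Assumed fixture: character-level tokenizer, one token per character."""

    def __init__(self, vocab_size: int = 10000):
        self.vocab_size = vocab_size

    def encode(self, text: str) -> list[int]:
        # Output length equals ``len(text)``, matching the assertions in the
        # tests below (e.g. "hello world" -> 11 tokens).
        return [hash(c) % self.vocab_size for c in text]

    def getattr(self, name: str):
        return getattr(self, name)
# ------------------------------------------------------------------------------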
+ +import pytest + +pytest.importorskip("ray") + +# Import from mocking_classes which is a proper module +import sys +from pathlib import Path + +import ray +from ray.util.state import get_actor as get_actor_by_id +from torchrl._utils import logger + +sys.path.insert(0, str(Path(__file__).parent)) + + +from test_services_fixtures import SimpleService, TokenizerService +from torchrl.services import get_services, RayService + + +@pytest.fixture(scope="module", autouse=True) +def ray_init(): + """Initialize Ray once for the entire test module.""" + import os + + if ray.is_initialized(): + ray.shutdown(raise_on_error=False) + + # Add test directory to Ray worker PYTHONPATH so they can import test_services_fixtures + test_dir = os.path.dirname(os.path.abspath(__file__)) + + ray.init( + ignore_reinit_error=True, + namespace="test_torchrl_services", + runtime_env={"env_vars": {"PYTHONPATH": test_dir}}, + ) + yield + # Cleanup once at the end + ray.shutdown() + + +@pytest.fixture(scope="function", autouse=True) +def kill_all_actors(): + """Kill all actors after each test.""" + yield + if not ray.is_initialized(): + return + # list actors + actors = ray.state.actors() + if not actors: + return + for actor_id, info in actors.items(): + if info["State"] == "ALIVE": + try: + actor = get_actor_by_id(actor_id) + ray.kill(actor, no_restart=True) + except Exception as e: + logger.warning(f"Error killing actor {actor_id}: {e}") + + +@pytest.fixture(scope="function", autouse=True) +def cleanup_services(): + """Clean up services after each test to avoid conflicts.""" + # Run the test + yield + + # After test: attempt to get and reset any services that might exist + # This prevents namespace pollution between tests + try: + # Each test uses a unique namespace, so this is mostly for safety + pass + except Exception: + pass + + +class TestRayService: + """Test suite for RayService.""" + + def test_initialization(self): + """Test that RayService initializes correctly.""" + services = RayService(namespace="test_init") + try: + assert isinstance(services, RayService) + # shutdown the Ray service + finally: + services.shutdown(raise_on_error=False) + + def test_initialization_with_existing_ray(self): + """Test RayService with already-initialized Ray.""" + # Ray is already initialized by fixture + services = RayService(namespace="test_torchrl_services") + try: + assert isinstance(services, RayService) + finally: + services.shutdown(raise_on_error=False) + + def test_register_service(self): + """Test registering a new service.""" + services = RayService(namespace="test_register") + try: + actor = services.register("simple", SimpleService, value=42) + assert actor is not None + + # Verify we can call methods on the actor + result = ray.get(actor.get_value.remote(), timeout=10) + assert result == 42 + finally: + services.shutdown(raise_on_error=False) + + def test_register_with_ray_options(self): + """Test registering a service with Ray options.""" + services = RayService(namespace="test_options") + try: + actor = services.register( + "tokenizer", TokenizerService, vocab_size=5000, num_cpus=1, num_gpus=0 + ) + + # Verify the service works + result = ray.get(actor.encode.remote("hello"), timeout=10) + assert isinstance(result, list) + assert len(result) == 5 + finally: + services.shutdown(raise_on_error=False) + + def test_register_duplicate_raises(self): + """Test that registering duplicate service raises ValueError.""" + services = RayService(namespace="test_duplicate") + + try: + services.register("simple", 
SimpleService) + + with pytest.raises(ValueError, match="already exists"): + services.register("simple", SimpleService) + finally: + services.shutdown(raise_on_error=False) + + def test_get_service(self): + """Test retrieving a registered service.""" + services = RayService(namespace="test_get") + + try: + # Register a service + original_actor = services.register("simple", SimpleService, value=100) + + # Retrieve the same service + retrieved_actor = services.get("simple") + + # Verify they reference the same actor + result = ray.get(retrieved_actor.get_value.remote(), timeout=10) + assert result == 100 + finally: + services.shutdown(raise_on_error=False) + + def test_get_nonexistent_raises(self): + """Test that getting a nonexistent service raises KeyError.""" + services = RayService(namespace="test_get_missing") + try: + with pytest.raises(KeyError, match="not found"): + services.get("nonexistent") + finally: + services.shutdown(raise_on_error=False) + + def test_getitem_access(self): + """Test dict-like access with [].""" + services = RayService(namespace="test_getitem") + + try: + services.register("tokenizer", TokenizerService, vocab_size=100) + + # Access using dict syntax + tokenizer = services["tokenizer"] + result = ray.get(tokenizer.encode.remote("test"), timeout=10) + assert isinstance(result, list) + finally: + services.shutdown(raise_on_error=False) + + def test_contains(self): + """Test checking if service exists with 'in'.""" + services = RayService(namespace="test_contains") + + try: + services.register("existing", SimpleService) + assert "existing" in services + assert "nonexistent" not in services + finally: + services.shutdown(raise_on_error=False) + + def test_list_services(self): + """Test listing all registered services.""" + services = RayService(namespace="test_list") + + try: + # Initially empty + assert services.list() == [] + + # Register multiple services + services.register("service1", SimpleService, value=1) + services.register("service2", SimpleService, value=2) + services.register("service3", TokenizerService) + + service_names = services.list() + assert sorted(service_names) == ["service1", "service2", "service3"] + finally: + services.shutdown(raise_on_error=False) + + def test_cross_worker_visibility(self): + """Test that services registered by one worker are visible to another.""" + namespace = "test_cross_worker" + + # Worker 1: Register a service + services1 = RayService(namespace=namespace) + services1.register("shared_service", SimpleService, value=999) + + # Worker 2: Should see the same service + services2 = RayService(namespace=namespace) + assert "shared_service" in services2 + try: + shared_actor = services2["shared_service"] + result = ray.get(shared_actor.get_value.remote(), timeout=10) + assert result == 999 + finally: + services2.shutdown(raise_on_error=False) + + def test_namespace_isolation(self): + """Test that different namespaces isolate services.""" + # Register in namespace A + services_a = RayService(namespace="namespace_a") + services_a.register("service", SimpleService, value=111) + + # Register different service with same name in namespace B + services_b = RayService(namespace="namespace_b") + services_b.register("service", SimpleService, value=222) + + try: + # Verify they're isolated + actor_a = services_a["service"] + actor_b = services_b["service"] + + result_a = ray.get(actor_a.get_value.remote(), timeout=10) + result_b = ray.get(actor_b.get_value.remote(), timeout=10) + + assert result_a == 111 + assert result_b == 222 + 
finally: + services_a.shutdown(raise_on_error=False) + services_b.shutdown(raise_on_error=False) + + def test_options_method(self): + """Test the register_with_options() method for explicit configuration.""" + services = RayService(namespace="test_options_method") + + try: + # Register with explicit actor options + services.register_with_options( + "simple", + SimpleService, + actor_options={"num_cpus": 1, "max_concurrency": 5}, + value=42, + ) + + # Verify it works + actor = services["simple"] + result = ray.get(actor.get_value.remote(), timeout=10) + assert result == 42 + finally: + services.shutdown(raise_on_error=False) + + def test_service_persistence(self): + """Test that services persist across RayService instances.""" + namespace = "test_persistence" + + # Create first instance and register + services1 = RayService(namespace=namespace) + services1.register("persistent", SimpleService, value=777) + + # Create second instance + services2 = RayService(namespace=namespace) + + try: + # Should be able to access the service + assert "persistent" in services2 + actor = services2["persistent"] + result = ray.get(actor.get_value.remote(), timeout=10) + assert result == 777 + finally: + services1.shutdown(raise_on_error=False) + services2.shutdown(raise_on_error=False) + + def test_setitem_registration(self): + """Test registering with dict-like syntax.""" + services = RayService(namespace="test_setitem") + + try: + # Register using setitem + services["simple"] = SimpleService + + # Verify it was registered + assert "simple" in services + finally: + services.shutdown(raise_on_error=False) + + def test_reset(self): + """Test resetting the service registry.""" + services = RayService(namespace="test_reset") + + try: + # Register multiple services + services.register("service1", SimpleService, value=1) + services.register("service2", SimpleService, value=2) + services.register("service3", TokenizerService) + + # Verify they exist + assert len(services.list()) == 3 + assert "service1" in services + assert "service2" in services + assert "service3" in services + + # Reset + services.reset() + + # Verify all services are gone + assert len(services.list()) == 0 + assert "service1" not in services + assert "service2" not in services + assert "service3" not in services + finally: + services.shutdown(raise_on_error=False) + + def test_reset_multiple_namespaces(self): + """Test that reset only affects the specific namespace.""" + # Create services in different namespaces + services_a = RayService(namespace="namespace_a") + services_b = RayService(namespace="namespace_b") + + try: + # Register services in both namespaces + services_a.register("service", SimpleService, value=1) + services_b.register("service", SimpleService, value=2) + + # Verify both exist + assert "service" in services_a + assert "service" in services_b + + # Reset namespace A + services_a.reset() + + # Verify namespace A is empty but B is not + assert "service" not in services_a + assert "service" in services_b + + # Cleanup namespace B + services_b.reset() + finally: + services_a.shutdown(raise_on_error=False) + services_b.shutdown(raise_on_error=False) + + +class TestGetServices: + """Test suite for get_services() function.""" + + def test_get_services_ray(self): + """Test get_services with ray backend.""" + services = get_services(backend="ray", namespace="test_get_services") + try: + assert isinstance(services, RayService) + finally: + services.shutdown(raise_on_error=False) + + def test_get_services_invalid_backend(self): + 
"""Test that invalid backend raises ValueError.""" + with pytest.raises(ValueError, match="Unsupported backend"): + get_services(backend="invalid") + + def test_get_services_with_ray_config(self): + """Test get_services with Ray configuration.""" + services = get_services( + backend="ray", + namespace="test_with_config", + ) + try: + assert isinstance(services, RayService) + finally: + services.shutdown(raise_on_error=False) + + +class TestIntegrationScenarios: + """Integration tests for realistic usage scenarios.""" + + def test_tokenizer_sharing(self): + """Test sharing a tokenizer across workers.""" + namespace = "test_tokenizer_integration" + + # Setup: Register tokenizer + services = get_services(backend="ray", namespace=namespace) + try: + services.register( + "tokenizer", TokenizerService, vocab_size=10000, num_cpus=1 + ) + + # Worker 1: Use tokenizer + tokenizer = services["tokenizer"] + encoded = ray.get(tokenizer.encode.remote("hello world"), timeout=10) + assert isinstance(encoded, list) + assert len(encoded) == 11 # "hello world" has 11 characters + + # Worker 2: Access same tokenizer + services2 = get_services(backend="ray", namespace=namespace) + tokenizer2 = services2["tokenizer"] + encoded2 = ray.get(tokenizer2.encode.remote("hello world"), timeout=10) + + # Should produce same results + assert encoded == encoded2 + finally: + services.shutdown(raise_on_error=False) + + def test_stateful_service(self): + """Test that services maintain state across calls.""" + services = RayService(namespace="test_stateful") + + services.register("counter", SimpleService, value=0) + + try: + counter = services["counter"] + + # Modify state + ray.get(counter.set_value.remote(10), timeout=10) + + # Access from "different worker" + services2 = RayService(namespace="test_stateful") + counter2 = services2["counter"] + + # Should see the modified state + result = ray.get(counter2.get_value.remote(), timeout=10) + assert result == 10 + finally: + services.shutdown(raise_on_error=False) + # services2.shutdown(raise_on_error=False) + + def test_conditional_registration(self): + """Test pattern: register only if not exists.""" + namespace = "test_conditional" + + services1 = get_services(backend="ray", namespace=namespace) + + # Worker 2: Try to register same service + services2 = get_services(backend="ray", namespace=namespace) + + try: + # Worker 1: Register if not exists + if "shared_tokenizer" not in services1: + services1.register( + "shared_tokenizer", TokenizerService, vocab_size=5000 + ) + + if "shared_tokenizer" not in services2: + services2.register( + "shared_tokenizer", TokenizerService, vocab_size=10000 + ) + else: + # Should take this branch + pass + + # Both should see the same tokenizer (first one registered) + tok1 = services1["shared_tokenizer"] + tok2 = services2["shared_tokenizer"] + + # Verify they're the same actor + vocab1 = ray.get(tok1.getattr.remote("vocab_size"), timeout=10) + vocab2 = ray.get(tok2.getattr.remote("vocab_size"), timeout=10) + assert vocab1 == vocab2 == 5000 + finally: + services1.shutdown(raise_on_error=False) + services2.shutdown(raise_on_error=False) + + def test_multiple_services_management(self): + """Test managing multiple different services.""" + services = RayService(namespace="test_multiple") + + try: + # Register various services + services.register( + "tokenizer", TokenizerService, vocab_size=1000, num_cpus=1 + ) + services.register("counter1", SimpleService, value=1, num_cpus=0.5) + services.register("counter2", SimpleService, value=2, 
num_cpus=0.5) + + # Verify all are accessible + assert len(services.list()) == 3 + + tok = services["tokenizer"] + c1 = services["counter1"] + c2 = services["counter2"] + + # Use them + assert ray.get(tok.getattr.remote("vocab_size"), timeout=10) == 1000 + assert ray.get(c1.getattr.remote("value"), timeout=10) == 1 + assert ray.get(c2.getattr.remote("value"), timeout=10) == 2 + finally: + services.shutdown(raise_on_error=False) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + _, args = parser.parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + args) diff --git a/torchrl/envs/llm/transforms/__init__.py b/torchrl/envs/llm/transforms/__init__.py index e4502042e72..5591a81b8fb 100644 --- a/torchrl/envs/llm/transforms/__init__.py +++ b/torchrl/envs/llm/transforms/__init__.py @@ -19,6 +19,7 @@ ExecuteToolsInOrder, JSONCallParser, MCPToolTransform, + PythonExecutorService, PythonInterpreter, SimpleToolTransform, ToolCall, @@ -37,6 +38,7 @@ "KLRewardTransform", "MCPToolTransform", "PolicyVersion", + "PythonExecutorService", "PythonInterpreter", "RayDataLoadingPrimer", "RetrieveKL", diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index c4f5ab97397..6a17125b1d4 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -972,6 +972,77 @@ def __del__(self): self.cleanup() +class PythonExecutorService: + """Ray actor that manages a pool of persistent Python interpreters. + + This service allows multiple environments to share a pool of Python + interpreters, reducing resource usage and improving efficiency. + + Args: + pool_size (int): Number of Python interpreter processes to maintain. + timeout (float): Timeout for code execution in seconds. + + Examples: + >>> # Register the service + >>> from torchrl.services import get_services + >>> services = get_services(backend="ray") + >>> services.register( + ... "python_executor", + ... PythonExecutorService, + ... pool_size=32, + ... timeout=10.0, + ... num_cpus=32, + ... max_concurrency=32 + ... ) + >>> + >>> # Use in transform + >>> env = env.append_transform( + ... PythonInterpreter(services="ray") + ... ) + """ + + def __init__(self, pool_size: int = 32, timeout: float = 10.0): + self.pool_size = pool_size + self.timeout = timeout + self.processes = [ + PersistentPythonProcess(timeout=timeout) for _ in range(pool_size) + ] + self.next_idx = 0 + self._lock = threading.Lock() + + def execute(self, code: str) -> dict: + """Execute Python code using next available process (round-robin). + + Args: + code: Python code to execute. + + Returns: + dict: Execution result with keys 'success', 'stdout', 'stderr', 'returncode'. + """ + # Simple round-robin - Ray handles the queuing via max_concurrency + with self._lock: + process = self.processes[self.next_idx] + self.next_idx = (self.next_idx + 1) % self.pool_size + + return process.execute(code) + + def cleanup(self): + """Cleanup all processes in the pool.""" + if hasattr(self, "processes"): + for process in self.processes: + if process: + process.cleanup() + self.processes = [] + + def __del__(self): + """Ensure cleanup on deletion.""" + try: + self.cleanup() + except Exception: + # Ignore errors during cleanup - we might be in Ray actor context + pass + + class PythonInterpreter(ToolTransformBase): r"""A transform that executes Python code in the LLM response. 
@@ -983,6 +1054,13 @@ class PythonInterpreter(ToolTransformBase): tool_name: The name of the tool in the chat history. Defaults to `"tool"`. persistent: Whether to use persistent processes. Defaults to `False`. timeout: The timeout for the persistent processes. Defaults to `10.0`. + services: Backend for shared Python executor service. If `"ray"`, uses + a shared Ray actor service for execution. If `None`, uses local + processes. Defaults to `None`. + service_name: Name of the service in the registry. Only used if + `services="ray"`. Defaults to `"python_executor"`. + namespace: Ray namespace for the service. Only used if `services="ray"`. + If `None`, uses the default namespace. Defaults to `None`. Examples: >>> from torchrl.envs.llm.transforms import PythonInterpreter @@ -1019,6 +1097,24 @@ class PythonInterpreter(ToolTransformBase): '\n' '<|im_end|>\n' '<|im_start|>assistant\n'] + + Using shared Ray service: + >>> from torchrl.services import get_services + >>> + >>> # Register service once (e.g., in main process) + >>> services = get_services(backend="ray") + >>> if "python_executor" not in services: + ... services.register( + ... "python_executor", + ... PythonExecutorService, + ... pool_size=32, + ... timeout=10.0, + ... num_cpus=32, + ... max_concurrency=32 + ... ) + >>> + >>> # Use in transform (all 128 envs share the 32 interpreters) + >>> env = env.append_transform(PythonInterpreter(services="ray")) """ use_step = True # Use _step() method @@ -1029,22 +1125,56 @@ def __init__( tool_name: str = "tool", persistent: bool = False, timeout: float = 10.0, + services: str | None = None, + service_name: str = "python_executor", + namespace: str | None = None, ): super().__init__() self.tokenizer = tokenizer self.tool_role = tool_name # Set the role for history entries self.persistent = persistent self.timeout = timeout - # Initialize as empty list if persistent, None otherwise - self.processes: list[PersistentPythonProcess | None] = [] if persistent else [] + self.services = services + self.service_name = service_name + self.namespace = namespace + + # Initialize attributes to avoid AttributeError in __del__ + self.python_service = None + self.processes = None + + # Initialize based on service mode + if services == "ray": + # Use shared Ray service + try: + from torchrl.services import get_services + + service_registry = get_services(backend="ray", namespace=namespace) + self.python_service = service_registry[service_name] + self.processes = None + torchrl_logger.info( + f"PythonInterpreter using Ray service '{service_name}'" + ) + except Exception as e: + raise RuntimeError( + f"Failed to get Ray service '{service_name}'. " + f"Make sure the service is registered. Error: {e}" + ) from e + elif services is None: + # Use local processes + self.python_service = None + self.processes = [] if persistent else [] + else: + raise ValueError( + f"Invalid services backend: {services}. Must be 'ray' or None." 
+ ) def close(self): """Close the transform.""" - if self.processes: + if self.python_service is None and self.processes: for process in self.processes: if process: process.cleanup() - self.processes = [] + self.processes = [] def clone(self): """Clone the transform.""" @@ -1053,6 +1183,9 @@ def clone(self): tool_name=self.tool_role, # tool_role is the instance attribute persistent=self.persistent, timeout=self.timeout, + services=self.services, + service_name=self.service_name, + namespace=self.namespace, ) def _ensure_processes(self, batch_size: int): @@ -1078,7 +1211,22 @@ def _ensure_processes(self, batch_size: int): def _execute_python_code(self, code: str, i: int) -> dict: """Safely execute Python code and return results.""" - if self.persistent: + if self.python_service is not None: + # Use shared Ray service + try: + import ray + + result = ray.get(self.python_service.execute.remote(code)) + return result + except Exception as e: + return { + "success": False, + "stdout": "", + "stderr": f"Ray service execution failed: {str(e)}", + "returncode": -1, + } + elif self.persistent: + # Use local persistent process # Ensure we have enough processes if i >= len(self.processes): self._ensure_processes(i + 1) @@ -1182,19 +1330,28 @@ def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: """Override to handle batch size management for persistent processes.""" - # Ensure we have enough processes for the entire batch - if self.persistent and next_tensordict.batch_dims == 1: + # Ensure we have enough processes for the entire batch (only for local persistent mode) + if ( + self.python_service is None + and self.persistent + and next_tensordict.batch_dims == 1 + ): self._ensure_processes(len(next_tensordict)) # Delegate to base class for all the heavy lifting return super()._step(tensordict, next_tensordict) def __del__(self): - """Cleanup persistent processes on deletion.""" - if self.processes: - for process in self.processes: - if process: - process.cleanup() + """Ensure cleanup on deletion.""" + try: + if hasattr(self, "python_service") and self.python_service is None: + if hasattr(self, "processes") and self.processes: + for process in self.processes: + if process: + process.cleanup() + except Exception: + # Ignore errors during cleanup + pass def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase @@ -1207,7 +1364,9 @@ def _reset( reset = torch.ones( tensordict.shape, device=tensordict.device, dtype=torch.bool ) - if self.persistent: + + # Only reset local persistent processes, not the shared service + if self.python_service is None and self.persistent: for i, process in enumerate(self.processes): if reset[i] and process is not None: process.cleanup() diff --git a/torchrl/services/__init__.py b/torchrl/services/__init__.py new file mode 100644 index 00000000000..0a78077e06e --- /dev/null +++ b/torchrl/services/__init__.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Distributed service registry for TorchRL. + +This module provides a service registry for managing distributed actors +(tokenizers, replay buffers, etc.) that can be accessed across workers. 
+ +Example: + >>> from torchrl.services import get_services + >>> + >>> # Worker 1: Register a tokenizer service + >>> services = get_services() + >>> services.register("tokenizer", TokenizerClass, num_cpus=1, num_gpus=0.1) + >>> + >>> # Worker 2: Access the same tokenizer + >>> services = get_services() + >>> tokenizer = services["tokenizer"] + >>> result = tokenizer.encode.remote(text) +""" + +from torchrl.services.base import ServiceBase +from torchrl.services.ray_service import RayService + +__all__ = ["ServiceBase", "RayService", "get_services"] + + +def get_services(backend: str = "ray", **init_kwargs) -> ServiceBase: + """Get a distributed service registry. + + This function creates or retrieves a service registry for managing distributed + actors across workers. Services registered by one worker are immediately visible + to all other workers in the cluster. + + Args: + backend: Service backend to use. Currently only "ray" is supported. + **init_kwargs: Backend-specific initialization arguments. + For Ray: + - ray_init_config (dict, optional): Arguments to pass to ray.init() + - namespace (str, optional): Ray namespace for service isolation. + Defaults to "torchrl_services". + + Returns: + ServiceBase: A service registry instance. + + Raises: + ValueError: If an unsupported backend is specified. + ImportError: If the required backend library is not installed. + + Examples: + >>> # Basic usage - register and access services + >>> services = get_services() + >>> services.register("tokenizer", TokenizerClass, num_cpus=1) + >>> tokenizer = services["tokenizer"] + >>> + >>> # With custom Ray initialization + >>> services = get_services( + ... backend="ray", + ... ray_init_config={"address": "auto"}, + ... namespace="my_experiment" + ... ) + >>> + >>> # Check if service exists + >>> if "tokenizer" in services: + ... tokenizer = services["tokenizer"] + >>> + >>> # List all registered services + >>> service_names = services.list() + """ + if backend != "ray": + raise ValueError( + f"Unsupported backend: {backend}. Currently only 'ray' is supported." + ) + + return RayService(**init_kwargs) diff --git a/torchrl/services/base.py b/torchrl/services/base.py new file mode 100644 index 00000000000..0c7ceefb198 --- /dev/null +++ b/torchrl/services/base.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Any + + +class ServiceBase(ABC): + """Base class for distributed service registries. + + A service registry manages distributed actors/services that can be accessed + across multiple workers. Common use cases include: + + - Tokenizers shared across inference workers + - Replay buffers for distributed training + - Model registries for centralized model storage + - Metrics aggregators + + The registry provides a dict-like interface for registering and accessing + services by name. + """ + + @abstractmethod + def register(self, name: str, service_factory: type, *args, **kwargs) -> Any: + """Register a service factory and create the service actor. + + This method registers a service with the given name and immediately + creates the corresponding actor. The service becomes globally visible + to all workers in the cluster. + + Args: + name: Unique identifier for the service. This name is used to + retrieve the service later. + service_factory: Class to instantiate as a remote actor. 
+ *args: Positional arguments to pass to the service constructor. + **kwargs: Keyword arguments for both actor configuration and + service constructor. Actor configuration options are backend-specific + (e.g., num_cpus, num_gpus for Ray). + + Returns: + The remote actor handle. + + Raises: + ValueError: If a service with this name already exists. + """ + + @abstractmethod + def get(self, name: str) -> Any: + """Get a service by name. + + Retrieves a previously registered service. If the service was registered + by another worker, this method will find it in the distributed registry. + + Args: + name: Service identifier. + + Returns: + The remote actor handle for the service. + + Raises: + KeyError: If the service is not found. + """ + + @abstractmethod + def __contains__(self, name: str) -> bool: + """Check if a service is registered. + + Args: + name: Service identifier. + + Returns: + True if the service exists, False otherwise. + """ + + @abstractmethod + def list(self) -> list[str]: + """List all registered service names. + + Returns: + List of service names currently registered in the cluster. + """ + + @abstractmethod + def reset(self) -> None: + """Reset the service registry. + + This removes all registered services and cleans up associated resources. + After calling reset(), the registry will be empty and all service actors + will be terminated. + + Warning: + This is a destructive operation. All services will be terminated and + any ongoing work will be interrupted. + """ + + def __getitem__(self, name: str) -> Any: + """Dict-like access: services["tokenizer"].""" + return self.get(name) + + def __setitem__(self, name: str, service_factory: type) -> None: + """Dict-like registration: services["tokenizer"] = TokenizerClass. + + Note: This only supports service_factory without additional arguments. + For full control, use register() method instead. + """ + self.register(name, service_factory) diff --git a/torchrl/services/ray_service.py b/torchrl/services/ray_service.py new file mode 100644 index 00000000000..3fdd0f2b445 --- /dev/null +++ b/torchrl/services/ray_service.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +from torchrl._utils import logger +from torchrl.services.base import ServiceBase + +RAY_ERR = None +try: + import ray + + _has_ray = True +except ImportError as err: + _has_ray = False + RAY_ERR = err + + +class _ServiceRegistryActor: + """Internal actor that maintains the list of registered services. + + This is a lightweight actor (1 CPU) that tracks which services have been + registered in a namespace. This ensures we only list our own services, + not other named actors in Ray. + """ + + def __init__(self): + self._services: set[str] = set() + + def add(self, name: str) -> None: + """Add a service to the registry.""" + self._services.add(name) + + def remove(self, name: str) -> None: + """Remove a service from the registry.""" + self._services.discard(name) + + def list(self) -> list[str]: + """List all registered services.""" + return sorted(self._services) + + def clear(self) -> None: + """Clear all registered services.""" + self._services.clear() + + def contains(self, name: str) -> bool: + """Check if a service is registered.""" + return name in self._services + + +class RayService(ServiceBase): + """Ray-based distributed service registry. 
+ + This class uses Ray's named actors feature to provide truly distributed + service discovery. When a service is registered by any worker, it becomes + immediately accessible to all other workers in the Ray cluster. + + Services are registered as Ray actors with globally unique names. This + ensures that: + 1. Services persist independently of the registering worker + 2. All workers see the same services instantly + 3. No custom synchronization is needed + + Args: + ray_init_config (dict, optional): Configuration for ray.init(). Only + used if Ray is not already initialized. Common options: + - address (str): Ray cluster address, or "auto" to auto-detect + - num_cpus (int): Number of CPUs to use + - num_gpus (int): Number of GPUs to use + namespace (str, optional): Ray namespace for service isolation. Services + in different namespaces are isolated from each other. Defaults to + "torchrl_services". + + Examples: + >>> # Basic usage + >>> services = RayService() + >>> services.register("tokenizer", TokenizerClass, num_cpus=1) + >>> tokenizer = services["tokenizer"] + >>> + >>> # With Ray options for dynamic configuration + >>> actor = services.register( + ... "model", + ... ModelClass, + ... num_cpus=2, + ... num_gpus=1, + ... memory=10 * 1024**3, + ... max_concurrency=4 + ... ) + >>> + >>> # Check and retrieve + >>> if "tokenizer" in services: + ... tok = services["tokenizer"] + >>> + >>> # List all services + >>> print(services.list()) + ['tokenizer', 'model'] + """ + + def __init__( + self, + ray_init_config: dict | None = None, + namespace: str = "torchrl_services", + ): + if not _has_ray: + raise ImportError( + "Ray is required for RayService. Install with: pip install ray" + ) from RAY_ERR + + self._namespace = namespace + self._ensure_ray_initialized(ray_init_config) + self._registry_actor = self._get_or_create_registry_actor() + + def _ensure_ray_initialized(self, ray_init_config: dict | None = None): + """Initialize Ray if not already initialized.""" + if not ray.is_initialized(): + config = ray_init_config or {} + # Ensure namespace is set + if "namespace" not in config: + config["namespace"] = self._namespace + + logger.info(f"Initializing Ray with namespace '{self._namespace}'") + ray.init(**config) + else: + # Ray already initialized - check if namespace matches + context = ray.get_runtime_context() + current_namespace = context.namespace + if current_namespace != self._namespace: + logger.warning( + f"Ray already initialized with namespace '{current_namespace}', " + f"but RayService is using namespace '{self._namespace}'. " + f"Services may not be visible across namespaces." 
+ ) + + def _make_service_name(self, name: str) -> str: + """Create the full actor name with namespace prefix.""" + return f"{self._namespace}::service::{name}" + + def _get_registry_actor_name(self) -> str: + """Get the name of the registry actor for this namespace.""" + return f"{self._namespace}::_registry" + + def _get_or_create_registry_actor(self): + """Get or create the registry actor for this namespace.""" + registry_name = self._get_registry_actor_name() + + try: + # Try to get existing registry + registry = ray.get_actor(registry_name, namespace=self._namespace) + return registry + except ValueError: + # Registry doesn't exist, create it + RemoteRegistry = ray.remote(_ServiceRegistryActor) + registry = RemoteRegistry.options( + name=registry_name, + namespace=self._namespace, + lifetime="detached", + num_cpus=1, + ).remote() + logger.info( + f"Created service registry actor for namespace '{self._namespace}'" + ) + return registry + + def register(self, name: str, service_factory: type, *args, **kwargs) -> Any: + """Register a service and create a named Ray actor. + + This method creates a Ray actor with a globally unique name. The actor + becomes immediately visible to all workers in the cluster. + + Args: + name: Service identifier. Must be unique within the namespace. + service_factory: Class to instantiate as a Ray actor. + *args: Positional arguments for the service constructor. + **kwargs: Both Ray actor options (num_cpus, num_gpus, memory, etc.) + and service constructor arguments. Ray will filter out the actor + options it recognizes. + + Returns: + The Ray actor handle. + + Raises: + ValueError: If a service with this name already exists. + + Examples: + >>> services = RayService() + >>> + >>> # Basic registration + >>> tokenizer = services.register("tokenizer", TokenizerClass) + >>> + >>> # With Ray resource specification + >>> buffer = services.register( + ... "buffer", + ... ReplayBuffer, + ... num_cpus=2, + ... num_gpus=0, + ... size=1000000 + ... ) + >>> + >>> # With advanced Ray options + >>> model = services.register( + ... "model", + ... ModelClass, + ... num_cpus=4, + ... num_gpus=1, + ... memory=20 * 1024**3, + ... max_concurrency=10, + ... max_restarts=3, + ... ) + """ + full_name = self._make_service_name(name) + + # Check if service already exists in our registry + if ray.get(self._registry_actor.contains.remote(name)): + raise ValueError( + f"Service '{name}' already exists in namespace '{self._namespace}'. " + f"Use a different name or retrieve the existing service with get()." 
+ ) + + # Create the Ray remote class + # First, make it a remote class + remote_cls = ray.remote(service_factory) + + # Then apply options including the name + options = { + "name": full_name, + "namespace": self._namespace, + "lifetime": "detached", + } + + # Extract Ray-specific options from kwargs + ray_options = [ + "num_cpus", + "num_gpus", + "memory", + "object_store_memory", + "resources", + "accelerator_type", + "max_concurrency", + "max_restarts", + "max_task_retries", + "max_pending_calls", + "scheduling_strategy", + ] + + for opt in ray_options: + if opt in kwargs: + options[opt] = kwargs.pop(opt) + + # Apply options and create the actor + remote_actor = remote_cls.options(**options).remote(*args, **kwargs) + + # Add to registry + ray.get(self._registry_actor.add.remote(name)) + + logger.info( + f"Registered service '{name}' as Ray actor '{full_name}' " + f"with options: {options}" + ) + + return remote_actor + + def get(self, name: str) -> Any: + """Get a service by name. + + Retrieves a service actor by name. The service can have been registered + by any worker in the cluster. + + Args: + name: Service identifier. + + Returns: + The Ray actor handle. + + Raises: + KeyError: If the service is not found. + + Examples: + >>> services = RayService() + >>> tokenizer = services.get("tokenizer") + >>> # Use the actor + >>> result = ray.get(tokenizer.encode.remote("Hello world")) + """ + # Check registry first + if not ray.get(self._registry_actor.contains.remote(name)): + raise KeyError( + f"Service '{name}' not found in namespace '{self._namespace}'. " + f"Available services: {self.list()}" + ) + + full_name = self._make_service_name(name) + + try: + actor = ray.get_actor(full_name, namespace=self._namespace) + return actor + except ValueError as e: + # Service in registry but actor missing - inconsistency + logger.warning( + f"Service '{name}' in registry but actor not found. " + f"Removing from registry." + ) + ray.get(self._registry_actor.remove.remote(name)) + raise KeyError( + f"Service '{name}' actor not found (removed from registry). " + f"Available services: {self.list()}" + ) from e + + def __contains__(self, name: str) -> bool: + """Check if a service is registered. + + Args: + name: Service identifier. + + Returns: + True if the service exists, False otherwise. + + Examples: + >>> services = RayService() + >>> if "tokenizer" in services: + ... tokenizer = services["tokenizer"] + ... else: + ... services.register("tokenizer", TokenizerClass) + """ + return ray.get(self._registry_actor.contains.remote(name)) + + def list(self) -> list[str]: + """List all registered service names. + + Returns a list of all services in the current namespace. This includes + services registered by any worker. + + Returns: + List of service names (without namespace prefix). + + Examples: + >>> services = RayService() + >>> services.register("tokenizer", TokenizerClass) + >>> services.register("buffer", ReplayBuffer) + >>> print(services.list()) + ['buffer', 'tokenizer'] + """ + return ray.get(self._registry_actor.list.remote()) + + def reset(self) -> None: + """Reset the service registry by terminating all actors. + + This method: + 1. Terminates all service actors in the current namespace + 2. Clears the registry actor's internal state + + After calling reset(), all services will be removed and their actors + will be killed. Any ongoing work will be interrupted. + + Warning: + This is a destructive operation that affects all workers in the + namespace. Use with caution. 
+ + Examples: + >>> services = RayService(namespace="experiment") + >>> services.register("tokenizer", TokenizerClass) + >>> print(services.list()) + ['tokenizer'] + >>> services.reset() + >>> print(services.list()) + [] + """ + service_names = self.list() + + for name in service_names: + full_name = self._make_service_name(name) + try: + actor = ray.get_actor(full_name, namespace=self._namespace) + ray.kill(actor) + logger.info(f"Terminated service '{name}' (actor: {full_name})") + except ValueError: + # Actor already gone or doesn't exist + logger.warning(f"Service '{name}' not found during reset") + except Exception as e: + logger.warning(f"Failed to terminate service '{name}': {e}") + + # Clear the registry + ray.get(self._registry_actor.clear.remote()) + + logger.info( + f"Reset complete for namespace '{self._namespace}'. Terminated {len(service_names)} services." + ) + + def shutdown(self, raise_on_error: bool = True) -> None: + """Shutdown the RayService by shutting down the Ray cluster.""" + try: + self.reset() + # kill the registry actor + registry_actor = ray.get_actor( + self._get_registry_actor_name(), namespace=self._namespace + ) + ray.kill(registry_actor, no_restart=True) + except Exception as e: + if raise_on_error: + raise e + else: + logger.warning(f"Error shutting down RayService: {e}") + + def register_with_options( + self, + name: str, + service_factory: type, + actor_options: dict[str, Any], + **constructor_kwargs, + ) -> Any: + """Register a service with explicit separation of Ray options and constructor args. + + This is a convenience method that makes it explicit which arguments are for + Ray actor configuration vs. the service constructor. It's functionally + equivalent to `register()` but more readable for complex configurations. + + Args: + name: Service identifier. + service_factory: Class to instantiate as a Ray actor. + actor_options: Dictionary of Ray actor options (num_cpus, num_gpus, etc.). + **constructor_kwargs: Arguments to pass to the service constructor. + + Returns: + The Ray actor handle. + + Examples: + >>> services = RayService() + >>> + >>> # Explicit separation of concerns + >>> model = services.register_with_options( + ... "model", + ... ModelClass, + ... actor_options={ + ... "num_cpus": 4, + ... "num_gpus": 1, + ... "memory": 20 * 1024**3, + ... "max_concurrency": 10 + ... }, + ... model_path="/path/to/checkpoint", + ... batch_size=32 + ... ) + >>> + >>> # Equivalent to: + >>> # services.register( + >>> # "model", ModelClass, + >>> # num_cpus=4, num_gpus=1, memory=20*1024**3, max_concurrency=10, + >>> # model_path="/path/to/checkpoint", batch_size=32 + >>> # ) + """ + # Merge actor_options into kwargs for register() + merged_kwargs = {**actor_options, **constructor_kwargs} + return self.register(name, service_factory, **merged_kwargs)
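# --- Illustrative aside (not part of this patch) -----------------------------
# The documentation above notes that custom backends can be added by
# subclassing ``torchrl.services.ServiceBase``. A minimal sketch of a purely
# local, in-process backend implementing only the abstract API defined in
# ``torchrl/services/base.py`` (``LocalService`` is a hypothetical name, not
# part of TorchRL):

from typing import Any

from torchrl.services.base import ServiceBase


class LocalService(ServiceBase):
    """Toy backend that instantiates services in the current process."""

    def __init__(self):
        self._services: dict[str, Any] = {}

    def register(self, name: str, service_factory: type, *args, **kwargs) -> Any:
        if name in self._services:
            raise ValueError(f"Service '{name}' already exists.")
        # No actor options here: every keyword argument goes to the constructor.
        instance = service_factory(*args, **kwargs)
        self._services[name] = instance
        return instance

    def get(self, name: str) -> Any:
        if name not in self._services:
            raise KeyError(f"Service '{name}' not found.")
        return self._services[name]

    def __contains__(self, name: str) -> bool:
        return name in self._services

    def list(self) -> list[str]:
        return sorted(self._services)

    def reset(self) -> None:
        self._services.clear()


# Hypothetical usage (no Ray involved); the dict-like interface inherited from
# ServiceBase works unchanged:
#
#     services = LocalService()
#     services.register("counter", SimpleService, value=0)
#     assert "counter" in services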