Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify
from dspy.utils.saving import load
from dspy.utils.MCPTools import MCPTools
from dspy.streaming.streamify import streamify
from dspy.utils.usage_tracker import track_usage

Expand Down
24 changes: 20 additions & 4 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
from typing import Any, Callable, Literal

from litellm import ContextWindowExceededError

import asyncio
import dspy
from dspy.primitives.program import Module
from dspy.primitives.tool import Tool
Expand Down Expand Up @@ -67,15 +67,28 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
self.tools = tools
self.react = dspy.Predict(react_signature)
self.extract = dspy.ChainOfThought(fallback_signature)

def _format_trajectory(self, trajectory: dict[str, Any]):
    """Render the trajectory as user-message content via the active adapter.

    Args:
        trajectory: Mapping of trajectory field names to their values.

    Returns:
        Whatever the adapter's ``format_user_message_content`` returns for a
        throwaway signature whose inputs are the trajectory keys.
    """
    # Use the globally configured adapter, or a fresh ChatAdapter when none is set.
    active_adapter = dspy.settings.adapter or dspy.ChatAdapter()
    # Build a signature with the trajectory keys as inputs and a placeholder output `x`.
    input_fields = ", ".join(trajectory.keys())
    throwaway_signature = dspy.Signature(f"{input_fields} -> x")
    return active_adapter.format_user_message_content(throwaway_signature, trajectory)

def forward(self, **input_args):
    """Execute the ReAct agent with the provided input arguments.

    An optional ``max_iters`` keyword overrides ``self.max_iters`` for this
    call; it is removed from ``input_args`` before they are passed on.

    Returns:
        When an event loop is already running in this thread, an un-awaited
        coroutine that the caller must ``await``. Otherwise the result of
        running ``_forward_async`` to completion with ``asyncio.run``.
    """
    trajectory = {}
    max_iters = input_args.pop("max_iters", self.max_iters)

    # Probe for a running loop; get_running_loop() raises RuntimeError when there is none.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        inside_event_loop = False
    else:
        inside_event_loop = True

    pending = self._forward_async(trajectory, max_iters, **input_args)
    if inside_event_loop:
        # asyncio.run() cannot be nested, so hand the coroutine back to the async caller.
        return pending
    return asyncio.run(pending)

async def _forward_async(self, trajectory, max_iters, **input_args):
"""Async implementation of the ReAct forward method."""
for idx in range(max_iters):
try:
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
Expand All @@ -88,9 +101,12 @@ def forward(self, **input_args):
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
# Execute the tool using the aexecute method
trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].aexecute(**pred.next_tool_args)
except Exception as err:
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"
error_msg = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"
trajectory[f"observation_{idx}"] = error_msg
logger.warning(f"Tool execution error: {_fmt_exc(err)}")

if pred.next_tool_name == "finish":
break
Expand Down
15 changes: 14 additions & 1 deletion dspy/primitives/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,18 @@ async def acall(self, **kwargs):
parsed_kwargs = self._validate_and_parse_args(**kwargs)
result = self.func(**parsed_kwargs)
if not asyncio.iscoroutine(result):
raise ValueError("You are calling `acall` on a non-async tool, please use `__call__` instead.")
raise ValueError("You are calling `acall` on a sync tool, please use `__call__` instead.")
return await result

async def aexecute(self, **kwargs):
    """Call the wrapped function and return its final value, sync or async.

    The keyword arguments are validated and parsed first. The wrapped
    function is then called; if it hands back a coroutine, that coroutine is
    awaited and its result returned, otherwise the value is returned as-is.
    Being a coroutine method, this must itself be awaited by the caller.
    """
    outcome = self.func(**self._validate_and_parse_args(**kwargs))
    if not asyncio.iscoroutine(outcome):
        # Plain (sync) function: the value is already final.
        return outcome
    return await outcome
201 changes: 201 additions & 0 deletions dspy/utils/MCPTools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from typing import Any, Dict, List, Optional, Tuple, Type
import json
import logging
import anyio
from dspy.primitives.tool import Tool

logger = logging.getLogger(__name__)

def map_json_schema_to_tool_args(schema: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], Dict[str, Type], Dict[str, str]]:
    """Map a JSON schema to the argument dictionaries used to build a DSPy Tool.

    Args:
        schema: A JSON schema describing the tool's input parameters. ``None``
            or a schema without ``"properties"`` yields three empty dicts.

    Returns:
        A tuple ``(args, arg_types, arg_desc)``, each keyed by parameter name:
        the raw property schema, the mapped Python type (``Any`` when the JSON
        type is unknown), and the description, suffixed with " (Required)"
        for names listed under ``"required"``.
    """
    args, arg_types, arg_desc = {}, {}, {}
    if not schema or "properties" not in schema:
        return args, arg_types, arg_desc

    type_mapping = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}
    # `required` may be absent or explicitly null; both mean "nothing is required".
    required = schema.get("required") or []

    for name, prop in schema["properties"].items():
        args[name] = prop

        json_type = prop.get("type", "string")
        if isinstance(json_type, list):
            # JSON Schema permits a list of types (e.g. ["string", "null"]). A list is
            # unhashable and cannot be used as a lookup key, so take the first
            # non-null entry as the effective type.
            json_type = next((entry for entry in json_type if entry != "null"), None)
        arg_types[name] = type_mapping.get(json_type, Any) if isinstance(json_type, str) else Any

        description = prop.get("description")
        if description is None:
            description = "No description provided."
        if name in required:
            description += " (Required)"
        arg_desc[name] = description

    return args, arg_types, arg_desc

class MCPTool(Tool):
    """Wrapper for an MCP tool, compatible with DSPy agents.

    The wrapper reads the tool's name, description and input schema from the
    MCP tool description, registers ``call_tool_async`` as the tool function,
    and forwards each call to ``session.call_tool``.
    """

    def __init__(self, tool_info: Any, session: Any):
        """Create a DSPy Tool from an MCP tool description.

        Args:
            tool_info: Tool information from MCP (object, dict, or JSON string)
            session: MCP client session for making tool calls
        """
        self.session = session
        self._raw_tool_info = tool_info

        name, desc, input_schema = self._extract_tool_info(tool_info)
        self.name = name
        args, arg_types, arg_desc = map_json_schema_to_tool_args(input_schema)

        super().__init__(
            func=self.call_tool_async,
            name=name,
            desc=desc,
            args=args,
            arg_types=arg_types,
            arg_desc=arg_desc,
        )

    def _extract_tool_info(self, tool_info: Any) -> Tuple[str, str, Optional[Dict[str, Any]]]:
        """Extract name, description and input schema from tool info.

        Args:
            tool_info: Tool information as an object with ``name`` and
                ``description`` attributes, a dict with a ``"name"`` key, or a
                JSON string that decodes to such a dict.

        Returns:
            A tuple of (name, description, input_schema). Anything that matches
            none of the supported shapes falls back to ``str(tool_info)`` as the
            name, a placeholder description and no schema.
        """
        fallback_desc = "No description available."

        # Decode a JSON string so it can be handled like a dict; a string that
        # is not valid JSON is left alone and ends up as the tool name below.
        info = tool_info
        if isinstance(info, str):
            try:
                decoded = json.loads(info)
            except json.JSONDecodeError:
                decoded = None
            if isinstance(decoded, dict):
                info = decoded

        if isinstance(info, dict):
            if "name" in info:
                desc = info.get("description")
                return (
                    str(info["name"]),
                    desc if desc is not None else fallback_desc,
                    info.get("inputSchema"),
                )
        elif hasattr(info, 'name') and hasattr(info, 'description'):
            desc = info.description
            return (
                info.name,
                # Keep the declared `str` return type even when the description is None.
                desc if desc is not None else fallback_desc,
                getattr(info, 'inputSchema', None),
            )

        return str(tool_info), fallback_desc, None

    async def call_tool_async(self, **kwargs: Any) -> Any:
        """Execute the MCP tool asynchronously.

        Args:
            **kwargs: Arguments to pass to the MCP tool

        Returns:
            Processed result from the tool execution

        Raises:
            RuntimeError: If there's an error executing the tool. The original
                exception is attached as ``__cause__``.
        """
        try:
            logger.debug("Executing MCP tool %s with args: %s", self.name, kwargs)
            result = await self.session.call_tool(self.name, kwargs)
            logger.debug("MCP tool %s returned: %s", self.name, result)
            return self._process_result(result)
        except anyio.ClosedResourceError as e:
            # Special handling for closed resources (common during cleanup)
            logger.error(f"MCP resource closed during {self.name} execution: {str(e)}")
            raise RuntimeError(f"MCP connection closed while executing {self.name}") from e
        except Exception as e:
            logger.error(f"Error executing MCP tool {self.name}: {str(e)}")
            raise RuntimeError(f"Error executing tool {self.name}: {str(e)}") from e

    def _process_result(self, result: Any) -> Any:
        """Process the result from tool execution into a format suitable for agents.

        Args:
            result: Raw result from the MCP tool

        Returns:
            A placeholder message for ``None``; otherwise the first match of:
            a non-empty ``content`` attribute (list items joined by newlines,
            using each item's ``text`` when present), a non-None ``text``
            attribute, a dict (first of the "message"/"output"/"result"/"text"
            keys, else its JSON dump), and finally ``str(result)``.
        """
        if result is None:
            return "Tool executed successfully but returned no content."

        # Handle content attribute
        content = getattr(result, 'content', None)
        if content:
            if isinstance(content, list):
                try:
                    return "\n".join(str(getattr(item, 'text', item)) for item in content if item)
                except Exception:
                    # Best effort: fall through to the plain string form, but leave a trace.
                    logger.debug("Could not join content items for MCP tool %s", self.name, exc_info=True)
            return str(content)

        # Handle text attribute
        if hasattr(result, 'text') and result.text is not None:
            return result.text

        # Handle dictionary
        if isinstance(result, dict):
            for key in ("message", "output", "result", "text"):
                if key in result:
                    return str(result[key])
            try:
                return json.dumps(result, indent=2)
            except (TypeError, ValueError):
                # Values that are not JSON-serialisable: fall back to the plain repr.
                return str(result)

        return str(result)

class MCPTools:
    """Collection of tools from an MCP server, usable with DSPy agents.

    Each MCP tool description is wrapped in an ``MCPTool`` and stored in
    ``self.tools``, a dict keyed by tool name.
    """

    def __init__(self, session: Any, tools_list: List[Any]):
        """Initialize the MCPTools collection.

        Args:
            session: MCP client session for making tool calls
            tools_list: List of tool descriptions from MCP
        """
        self.session = session
        # Wrap each description exactly once; when two descriptions share a
        # name, the later one replaces the earlier one.
        wrapped = [MCPTool(tool_info, session) for tool_info in tools_list]
        self.tools = {tool.name: tool for tool in wrapped}
        logger.info("Initialized MCPTools with %s tools: %s", len(self.tools), ', '.join(self.tools))

    def __getitem__(self, tool_name: str) -> MCPTool:
        """Get a tool by name.

        Args:
            tool_name: Name of the tool to retrieve

        Returns:
            The requested MCPTool instance

        Raises:
            KeyError: If the tool is not found
        """
        try:
            return self.tools[tool_name]
        except KeyError:
            raise KeyError(f"Tool '{tool_name}' not found in available MCP tools") from None

    def get_tools(self) -> List[Tool]:
        """Get all tools as a list.

        Returns:
            List of all available MCPTool instances
        """
        return list(self.tools.values())

    def get_tool_names(self) -> List[str]:
        """Get names of all available tools.

        Returns:
            List of tool names
        """
        return list(self.tools)

    def __str__(self) -> str:
        """String representation showing available tools."""
        return f"MCPTools with {len(self.tools)} tools: {', '.join(self.tools)}"

    def __repr__(self) -> str:
        """Detailed representation of the tools collection."""
        return f"MCPTools({len(self.tools)} tools: {list(self.tools)})"

2 changes: 2 additions & 0 deletions dspy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dspy.utils.callback import BaseCallback, with_callbacks
from dspy.utils.dummies import DummyLM, DummyVectorizer, dummy_rm
from dspy.streaming.messages import StatusMessageProvider, StatusMessage
from dspy.utils.MCPTools import MCPTools

import os
import requests
Expand All @@ -24,6 +25,7 @@ def download(url):
"with_callbacks",
"DummyLM",
"DummyVectorizer",
"MCPTools",
"dummy_rm",
"StatusMessage",
"StatusMessageProvider",
Expand Down
56 changes: 56 additions & 0 deletions mcp_docs/basic_mcp_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

import dspy


async def main():
    """Run the example: start an MCP server over stdio and query it via ReAct.

    Configures DSPy with a Gemini model (API key read from ``GOOGLE_API_KEY``),
    launches the Airbnb MCP server through ``npx``, wraps the tools it lists in
    ``dspy.MCPTools``, and prints the ReAct agent's answer. Any exception is
    printed and then re-raised.
    """
    try:
        print("Starting MCP client initialization...")
        # Configure DSPy with the language model.
        lm = dspy.LM("gemini/gemini-2.0-flash", api_key=os.getenv("GOOGLE_API_KEY"))
        dspy.configure(lm=lm)

        # Describe how to launch the MCP server as a subprocess.
        server_params = StdioServerParameters(
            command="npx",
            args=["-y", "@openbnb/mcp-server-airbnb", "--ignore-robots-txt"],
        )

        async with stdio_client(server_params) as (read, write), ClientSession(read, write) as session:
            await session.initialize()
            listing = await session.list_tools()
            print("\nCreating MCPTools instance...")
            mcp_tools = dspy.MCPTools(session=session, tools_list=listing.tools)

            # The agent is created inside the session context so the tools can use it.
            react_agent = dspy.ReAct("input->output", mcp_tools.get_tools())

            # The call is awaited here because this code runs inside an event loop.
            print("\nRunning ReAct agent...")
            react_result = await react_agent(
                input="Find a place to stay in New York for 2 adults from 2025-05-01 to 2025-05-05."
            )

            print("\nReAct Result:")
            print(react_result)

    except Exception as e:
        print(f"Error in main: {str(e)}")
        raise


if __name__ == "__main__":
    asyncio.run(main())


Loading