|
7 | 7 | from contextlib import asynccontextmanager, contextmanager |
8 | 8 | from contextvars import ContextVar |
9 | 9 | from dataclasses import field |
10 | | -from typing import Any, Generic, Literal, Union, cast |
| 10 | +from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast |
11 | 11 |
|
12 | 12 | from opentelemetry.trace import Span, Tracer |
13 | 13 | from typing_extensions import TypeGuard, TypeVar, assert_never |
|
27 | 27 | from .models.instrumented import InstrumentedModel |
28 | 28 | from .result import ResultDataT |
29 | 29 | from .settings import ModelSettings, merge_model_settings |
30 | | -from .tools import ( |
31 | | - RunContext, |
32 | | - Tool, |
33 | | - ToolDefinition, |
34 | | -) |
| 30 | +from .tools import RunContext, Tool, ToolDefinition |
| 31 | + |
| 32 | +if TYPE_CHECKING: |
| 33 | + from .mcp import MCPServer |
35 | 34 |
|
36 | 35 | __all__ = ( |
37 | 36 | 'GraphAgentState', |
@@ -94,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]): |
94 | 93 | result_validators: list[_result.ResultValidator[DepsT, ResultDataT]] |
95 | 94 |
|
96 | 95 | function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) |
| 96 | + mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) |
97 | 97 |
|
98 | 98 | run_span: Span |
99 | 99 | tracer: Tracer |
@@ -219,7 +219,17 @@ async def add_tool(tool: Tool[DepsT]) -> None: |
219 | 219 | if tool_def := await tool.prepare_tool_def(ctx): |
220 | 220 | function_tool_defs.append(tool_def) |
221 | 221 |
|
222 | | - await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values())) |
| 222 | + async def add_mcp_server_tools(server: MCPServer) -> None: |
| 223 | + if not server.is_running: |
| 224 | + raise exceptions.UserError(f'MCP server is not running: {server}') |
| 225 | + tool_defs = await server.list_tools() |
| 226 | + # TODO(Marcelo): We should check if the tool names are unique. If not, we should raise an error. |
| 227 | + function_tool_defs.extend(tool_defs) |
| 228 | + |
| 229 | + await asyncio.gather( |
| 230 | + *map(add_tool, ctx.deps.function_tools.values()), |
| 231 | + *map(add_mcp_server_tools, ctx.deps.mcp_servers), |
| 232 | + ) |
223 | 233 |
|
224 | 234 | result_schema = ctx.deps.result_schema |
225 | 235 | return models.ModelRequestParameters( |
@@ -594,6 +604,21 @@ async def process_function_tools( |
594 | 604 | yield event |
595 | 605 | call_index_to_event_id[len(calls_to_run)] = event.call_id |
596 | 606 | calls_to_run.append((tool, call)) |
| 607 | + elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx): |
| 608 | + if stub_function_tools: |
| 609 | + # TODO(Marcelo): We should add coverage for this part of the code. |
| 610 | + output_parts.append( # pragma: no cover |
| 611 | + _messages.ToolReturnPart( |
| 612 | + tool_name=call.tool_name, |
| 613 | + content='Tool not executed - a final result was already processed.', |
| 614 | + tool_call_id=call.tool_call_id, |
| 615 | + ) |
| 616 | + ) |
| 617 | + else: |
| 618 | + event = _messages.FunctionToolCallEvent(call) |
| 619 | + yield event |
| 620 | + call_index_to_event_id[len(calls_to_run)] = event.call_id |
| 621 | + calls_to_run.append((mcp_tool, call)) |
597 | 622 | elif result_schema is not None and call.tool_name in result_schema.tools: |
598 | 623 | # if tool_name is in _result_schema, it means we found a result tool but an error occurred in |
599 | 624 | # validation, we don't add another part here |
@@ -641,6 +666,35 @@ async def process_function_tools( |
641 | 666 | output_parts.append(results_by_index[k]) |
642 | 667 |
|
643 | 668 |
|
| 669 | +async def _tool_from_mcp_server( |
| 670 | + tool_name: str, |
| 671 | + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], |
| 672 | +) -> Tool[DepsT] | None: |
| 673 | + """Call each MCP server to find the tool with the given name. |
| 674 | +
|
| 675 | + Args: |
| 676 | + tool_name: The name of the tool to find. |
| 677 | + ctx: The current run context. |
| 678 | +
|
| 679 | + Returns: |
| 680 | + The tool with the given name, or `None` if no tool with the given name is found. |
| 681 | + """ |
| 682 | + |
| 683 | + async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any: |
| 684 | + # There's no normal situation where the server will not be running at this point, we check just in case |
| 685 | + # some weird edge case occurs. |
| 686 | + if not server.is_running: # pragma: no cover |
| 687 | + raise exceptions.UserError(f'MCP server is not running: {server}') |
| 688 | + result = await server.call_tool(tool_name, args) |
| 689 | + return result |
| 690 | + |
| 691 | + for server in ctx.deps.mcp_servers: |
| 692 | + tools = await server.list_tools() |
| 693 | + if tool_name in {tool.name for tool in tools}: |
| 694 | + return Tool(name=tool_name, function=run_tool, takes_ctx=True) |
| 695 | + return None |
| 696 | + |
| 697 | + |
644 | 698 | def _unknown_tool( |
645 | 699 | tool_name: str, |
646 | 700 | ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], |
|
0 commit comments