-
Notifications
You must be signed in to change notification settings - Fork 269
/
Copy pathgen_client.py
85 lines (72 loc) · 2.95 KB
/
gen_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import AsyncGenerator, Callable
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession
from mcp_agent.logging.logger import get_logger
from mcp_agent.mcp_server_registry import ServerRegistry
from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
# Module-level logger named after this module, obtained from the project's
# logging helper. NOTE(review): no function in this module references `logger`
# as currently written — confirm it is still needed.
logger = get_logger(__name__)
@asynccontextmanager
async def gen_client(
    server_name: str,
    server_registry: ServerRegistry,
    client_session_factory: Callable[
        [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
        ClientSession,
    ] = MCPAgentClientSession,
) -> AsyncGenerator[ClientSession, None]:
    """
    Open a client session to ``server_name`` scoped to an ``async with`` block.

    The session is produced by ``server_registry.initialize_server(...)``.
    It is only meant to be used inside the caller's
    ``async with gen_client(...) as session:`` block. When that block exits,
    the ``initialize_server`` context is exited as well. For a connection
    that outlives a single block, use connect() (or MCPConnectionManager).

    Args:
        server_name: Name of the server to open a session with.
        server_registry: Registry used to initialize the server; must be truthy.
        client_session_factory: Callable forwarded to ``initialize_server`` to
            construct the session. Defaults to ``MCPAgentClientSession``.

    Yields:
        The ClientSession produced by ``initialize_server``.

    Raises:
        ValueError: If ``server_registry`` is falsy (for example ``None``).
    """
    if not server_registry:
        raise ValueError(
            "Server registry not found in the context. Please specify one either on this method, or in the context."
        )

    server_context = server_registry.initialize_server(
        server_name=server_name,
        client_session_factory=client_session_factory,
    )
    # When the caller leaves its `async with`, this generator resumes after the
    # yield, which exits `server_context` in turn.
    async with server_context as client_session:
        yield client_session
async def connect(
    server_name: str,
    server_registry: ServerRegistry,
    client_session_factory: Callable[
        [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
        ClientSession,
    ] = MCPAgentClientSession,
) -> ClientSession:
    """
    Return a client session to ``server_name`` that is not bound to a ``with`` block.

    Unlike gen_client(), the session is obtained from the registry's
    ``connection_manager`` via ``get_server``. It is not scoped to a context
    manager. disconnect() releases it through the same connection manager.

    Args:
        server_name: Name of the server to connect to.
        server_registry: Registry whose ``connection_manager`` supplies the
            connection; must be truthy.
        client_session_factory: Callable forwarded to ``get_server`` to
            construct the session. Defaults to ``MCPAgentClientSession``.

    Returns:
        The ``session`` attribute of the connection returned by ``get_server``.

    Raises:
        ValueError: If ``server_registry`` is falsy (for example ``None``).
    """
    if not server_registry:
        raise ValueError(
            "Server registry not found in the context. Please specify one either on this method, or in the context."
        )

    manager = server_registry.connection_manager
    connection = await manager.get_server(
        server_name=server_name,
        client_session_factory=client_session_factory,
    )
    return connection.session
async def disconnect(
    server_name: str | None,
    server_registry: ServerRegistry,
) -> None:
    """
    Disconnect from one server, or from all servers when ``server_name`` is None.

    Both cases are delegated to ``server_registry.connection_manager``.

    Args:
        server_name: Name of the server to disconnect. Only ``None`` selects
            "all servers". Any string, including the empty string, is passed
            to ``disconnect_server`` as a specific server name.
        server_registry: Registry whose ``connection_manager`` performs the
            disconnect; must be truthy.

    Raises:
        ValueError: If ``server_registry`` is falsy (for example ``None``).
    """
    if not server_registry:
        raise ValueError(
            "Server registry not found in the context. Please specify one either on this method, or in the context."
        )

    manager = server_registry.connection_manager
    # Test for None explicitly rather than truthiness. Otherwise an empty
    # server name (e.g. an unset config value) would silently tear down every
    # connection, contradicting the documented contract.
    if server_name is None:
        await manager.disconnect_all()
    else:
        await manager.disconnect_server(server_name=server_name)