
WebSockets MCP Client Support #3016

@sergey-png

Description

Hello! Thank you for your product; it's very powerful!
In our company, we use WebSockets to connect to MCP servers in 90% of cases. We do not use the Streamable HTTP and SSE transports because they lead to session leaks on our pods: a session ID is pinned to the pod that issued it, which causes problems with cleaning up sessions on the pod side, with load balancing, and so on.
That's why we would really like WebSocket client support. Below is an example of how this could theoretically be implemented, though I suspect there is a smarter way to do it.

from contextlib import asynccontextmanager
from typing import AsyncIterator

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.websocket import websocket_client
from mcp.shared.message import SessionMessage
from pydantic_ai.mcp import MCPServer

# Import paths may differ between versions.
# mcp.client.websocket needs the optional `websockets` package installed.

class MCPServerWS(MCPServer):
    """WebSocket-based MCP Server implementation."""

    def __init__(
        self,
        url: str,
        *,
        headers: dict[str, str] | None = None,
        tool_prefix: str | None = None,
        log_level=None,
        log_handler=None,
        timeout: float = 5,
        read_timeout: float = 5 * 60,
        process_tool_call=None,
        allow_sampling: bool = True,
        sampling_model=None,
        max_retries: int = 1,
        elicitation_callback=None,
        id: str | None = None,
    ):
        self.url = url
        self.headers = headers

        super().__init__(
            tool_prefix=tool_prefix,
            log_level=log_level,
            log_handler=log_handler,
            timeout=timeout,
            read_timeout=read_timeout,
            process_tool_call=process_tool_call,
            allow_sampling=allow_sampling,
            sampling_model=sampling_model,
            max_retries=max_retries,
            elicitation_callback=elicitation_callback,
            id=id,
        )

    @asynccontextmanager
    async def client_streams(
        self,
    ) -> AsyncIterator[
        tuple[
            MemoryObjectReceiveStream[SessionMessage | Exception],
            MemoryObjectSendStream[SessionMessage],
        ]
    ]:
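        # Note: as far as I can tell, mcp.client.websocket.websocket_client only
        # accepts the URL, so self.headers is not forwarded here; sending custom
        # headers would need support in the SDK transport itself.
        async with websocket_client(self.url) as (read_stream, write_stream):
            yield read_stream, write_stream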
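For reference, here is a stdlib-only sketch of the work a WebSocket transport like `websocket_client` does behind `client_streams`, as I understand it: one JSON-RPC message per text frame in each direction, with frames that fail to parse handed to the session as exceptions instead of crashing the loop, which is what the `SessionMessage | Exception` type on the read stream reflects. `FakeWebSocket`, `ws_reader`, `ws_writer` and `demo` are hypothetical names, `asyncio.Queue` stands in for the anyio memory streams, and plain dicts stand in for `SessionMessage`; this is not the SDK's actual code.

```python
# Stdlib-only sketch (hypothetical names) of the two loops a WebSocket MCP
# transport runs: one JSON-RPC message per text frame in each direction.
import asyncio
import json


class FakeWebSocket:
    """Stand-in for a real connection: replays canned frames, records sent ones."""

    def __init__(self, incoming: list[str]) -> None:
        self.incoming = list(incoming)
        self.sent: list[str] = []

    async def recv(self) -> str:
        if not self.incoming:
            raise ConnectionError("closed")
        return self.incoming.pop(0)

    async def send(self, frame: str) -> None:
        self.sent.append(frame)


async def ws_reader(ws: FakeWebSocket, read_q: asyncio.Queue) -> None:
    # Parse each incoming text frame as exactly one JSON-RPC message.
    while True:
        try:
            frame = await ws.recv()
        except ConnectionError:
            return
        try:
            message = json.loads(frame)
            if not isinstance(message, dict) or message.get("jsonrpc") != "2.0":
                raise ValueError(f"not a JSON-RPC 2.0 message: {frame!r}")
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            # Surface the bad frame to the session instead of killing the loop.
            await read_q.put(exc)
            continue
        await read_q.put(message)


async def ws_writer(ws: FakeWebSocket, write_q: asyncio.Queue) -> None:
    # Serialize each outgoing message into its own text frame; None means "done".
    while (message := await write_q.get()) is not None:
        await ws.send(json.dumps(message))


async def demo() -> tuple[list, list[str]]:
    ws = FakeWebSocket(['{"jsonrpc": "2.0", "id": 1, "result": {}}', "not json"])
    read_q: asyncio.Queue = asyncio.Queue()
    write_q: asyncio.Queue = asyncio.Queue()
    await write_q.put({"jsonrpc": "2.0", "id": 1, "method": "ping"})
    await write_q.put(None)
    await asyncio.gather(ws_reader(ws, read_q), ws_writer(ws, write_q))
    received = [read_q.get_nowait() for _ in range(read_q.qsize())]
    return received, ws.sent


received, sent = asyncio.run(demo())
print(received[0])                 # {'jsonrpc': '2.0', 'id': 1, 'result': {}}
print(type(received[1]).__name__)  # JSONDecodeError
print(sent)                        # ['{"jsonrpc": "2.0", "id": 1, "method": "ping"}']
```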
