From 4432b1581b9bb3ced176c36767ecfa2b97932f8b Mon Sep 17 00:00:00 2001 From: Thuraabtech <97426541+Thuraabtech@users.noreply.github.com> Date: Wed, 29 Oct 2025 11:55:15 -0500 Subject: [PATCH 1/4] updated socket protocol implementation to be compatible with 1.0v (#72) * socket protocol updated to be compatible with 1.0v utcp * cubic fixes done * pinned mcp-use to use langchain 0.3.27 * removed mcp denpendency on langchain * adding the langchain dependency for testing (temporary) * remove langchain-core pin to resolve dependency conflict --------- Co-authored-by: Razvan Radulescu <43811028+h3xxit@users.noreply.github.com> Co-authored-by: Salman Mohammed --- .github/workflows/test.yml | 3 +- .../mcp/pyproject.toml | 3 +- .../communication_protocols/socket/README.md | 45 ++- .../socket/pyproject.toml | 5 +- .../socket/src/utcp_socket/__init__.py | 18 ++ .../src/utcp_socket/tcp_call_template.py | 18 +- .../utcp_socket/tcp_communication_protocol.py | 192 +++++++------ .../src/utcp_socket/udp_call_template.py | 16 ++ .../utcp_socket/udp_communication_protocol.py | 191 +++++++------ .../tests/test_tcp_communication_protocol.py | 178 ++++++++++++ .../tests/test_udp_communication_protocol.py | 176 ++++++++++++ scripts/socket_sanity.py | 265 ++++++++++++++++++ socket_plugin_test.py | 40 +++ 13 files changed, 952 insertions(+), 198 deletions(-) create mode 100644 plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py create mode 100644 plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py create mode 100644 scripts/socket_sanity.py create mode 100644 socket_plugin_test.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9394330..bf677f8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,10 +31,11 @@ jobs: pip install -e plugins/communication_protocols/http[dev] pip install -e plugins/communication_protocols/mcp[dev] pip install -e plugins/communication_protocols/text[dev] + pip install -e plugins/communication_protocols/socket[dev] - name: Run tests with pytest run: | - pytest core/tests/ plugins/communication_protocols/cli/tests/ plugins/communication_protocols/http/tests/ plugins/communication_protocols/mcp/tests/ plugins/communication_protocols/text/tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=core/src/utcp --cov-report=xml --cov-report=html + pytest core/tests/ plugins/communication_protocols/cli/tests/ plugins/communication_protocols/http/tests/ plugins/communication_protocols/mcp/tests/ plugins/communication_protocols/text/tests/ plugins/communication_protocols/socket/tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=core/src/utcp --cov-report=xml --cov-report=html - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/plugins/communication_protocols/mcp/pyproject.toml b/plugins/communication_protocols/mcp/pyproject.toml index 36bb48e..9232cd5 100644 --- a/plugins/communication_protocols/mcp/pyproject.toml +++ b/plugins/communication_protocols/mcp/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "pydantic>=2.0", "mcp>=1.12", "utcp>=1.0", - "mcp-use>=1.3" + "mcp-use>=1.3", + "langchain==0.3.27", ] classifiers = [ "Development Status :: 4 - Beta", diff --git a/plugins/communication_protocols/socket/README.md b/plugins/communication_protocols/socket/README.md index 8febb5a..3e695c9 100644 --- a/plugins/communication_protocols/socket/README.md +++ b/plugins/communication_protocols/socket/README.md @@ -1 +1,44 @@ -Find the UTCP readme at https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file +# UTCP Socket Plugin (UDP/TCP) + +This plugin adds UDP and TCP communication protocols to UTCP 1.0. + +## Running Tests + +Prerequisites: +- Python 3.10+ +- `pip` +- (Optional) a virtual environment + +1) Install core and the socket plugin in editable mode with dev extras: + +```bash +pip install -e "core[dev]" +pip install -e plugins/communication_protocols/socket[dev] +``` + +2) Run the socket plugin tests: + +```bash +python -m pytest plugins/communication_protocols/socket/tests -v +``` + +3) Run a single test or filter by keyword: + +```bash +# One file +python -m pytest plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py -v + +# Filter by keyword (e.g., delimiter framing) +python -m pytest plugins/communication_protocols/socket/tests -k delimiter -q +``` + +4) Optional end-to-end sanity check (mock UDP/TCP servers): + +```bash +python scripts/socket_sanity.py +``` + +Notes: +- On Windows, your firewall may prompt the first time tests open UDP/TCP sockets; allow access or run as admin if needed. +- Tests use `pytest-asyncio`. The dev extras installed above provide required dependencies. +- Streaming is single-chunk by design, consistent with HTTP/Text transports. Multi-chunk streaming can be added later behind provider configuration. \ No newline at end of file diff --git a/plugins/communication_protocols/socket/pyproject.toml b/plugins/communication_protocols/socket/pyproject.toml index 06f845e..2f232ad 100644 --- a/plugins/communication_protocols/socket/pyproject.toml +++ b/plugins/communication_protocols/socket/pyproject.toml @@ -36,4 +36,7 @@ dev = [ [project.urls] Homepage = "https://utcp.io" Source = "https://github.com/universal-tool-calling-protocol/python-utcp" -Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" \ No newline at end of file +Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" + +[project.entry-points."utcp.plugins"] +socket = "utcp_socket:register" \ No newline at end of file diff --git a/plugins/communication_protocols/socket/src/utcp_socket/__init__.py b/plugins/communication_protocols/socket/src/utcp_socket/__init__.py index e69de29..a0b7f3b 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/__init__.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/__init__.py @@ -0,0 +1,18 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.tcp_call_template import TCPProviderSerializer +from utcp_socket.udp_call_template import UDPProviderSerializer + + +def register() -> None: + # Register communication protocols + register_communication_protocol("tcp", TCPTransport()) + register_communication_protocol("udp", UDPTransport()) + + # Register call templates and their serializers + register_call_template("tcp", TCPProviderSerializer()) + register_call_template("udp", UDPProviderSerializer()) + + +__all__ = ["register"] \ No newline at end of file diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py index 157e43c..10fc1d6 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py @@ -1,6 +1,9 @@ from utcp.data.call_template import CallTemplate from typing import Optional, Literal from pydantic import Field +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback class TCPProvider(CallTemplate): """Provider configuration for raw TCP socket tools. @@ -63,7 +66,7 @@ class TCPProvider(CallTemplate): # Delimiter-based framing options message_delimiter: str = Field( default='\x00', - description="Delimiter to detect end of TCP response (e.g., '\\n', '\\r\\n', '\\x00'). Used with 'delimiter' framing." + description="Delimiter to detect end of TCP response (e.g., '\n', '\r\n', '\x00'). Used with 'delimiter' framing." ) # Fixed-length framing options fixed_message_length: Optional[int] = Field( @@ -77,3 +80,16 @@ class TCPProvider(CallTemplate): ) timeout: int = 30000 auth: None = None + + +class TCPProviderSerializer(Serializer[TCPProvider]): + def to_dict(self, obj: TCPProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> TCPProvider: + try: + return TCPProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid TCPProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py index 1b360a8..d5d64ac 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py @@ -10,9 +10,12 @@ import sys from typing import Dict, Any, List, Optional, Callable, Union -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, TCPProvider -from utcp.shared.tool import Tool +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp_socket.tcp_call_template import TCPProvider, TCPProviderSerializer +from utcp.data.tool import Tool +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.utcp_manual import UtcpManual import logging logging.basicConfig( @@ -22,7 +25,7 @@ logger = logging.getLogger(__name__) -class TCPTransport(ClientTransportInterface): +class TCPTransport(CommunicationProtocol): """Transport implementation for TCP-based tool providers. This transport communicates with tools over TCP sockets. It supports: @@ -85,6 +88,35 @@ def _format_tool_call_message( else: # Default to JSON format return json.dumps(tool_args) + + def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_template: TCPProvider) -> Dict[str, Any]: + """Normalize tool definition to include a valid 'tool_call_template'. + + - If 'tool_call_template' exists, validate it. + - Else if legacy 'tool_provider' exists, convert using TCPProviderSerializer. + - Else default to the provided manual_call_template. + """ + normalized = dict(tool_data) + try: + if "tool_call_template" in normalized and normalized["tool_call_template"] is not None: + try: + ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore + normalized["tool_call_template"] = ctpl + except Exception: + normalized["tool_call_template"] = manual_call_template + elif "tool_provider" in normalized and normalized["tool_provider"] is not None: + try: + ctpl = TCPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = ctpl + except Exception: + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = manual_call_template + else: + normalized["tool_call_template"] = manual_call_template + except Exception: + normalized["tool_call_template"] = manual_call_template + return normalized def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> bytes: """Encode message with appropriate TCP framing. @@ -115,7 +147,7 @@ def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> b elif provider.framing_strategy == "delimiter": # Add delimiter after the message - delimiter = provider.message_delimiter or "\\x00" + delimiter = provider.message_delimiter or "\x00" # Handle escape sequences delimiter = delimiter.encode('utf-8').decode('unicode_escape') return message_bytes + delimiter.encode('utf-8') @@ -170,7 +202,7 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid elif provider.framing_strategy == "delimiter": # Read until delimiter is found - delimiter = provider.message_delimiter or "\\x00" + delimiter = provider.message_delimiter or "\x00" delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') response_data = b"" @@ -215,9 +247,6 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid return response_data - else: - raise ValueError(f"Unknown framing strategy: {provider.framing_strategy}") - async def _send_tcp_message( self, host: str, @@ -289,122 +318,91 @@ def _send_and_receive(): self._log_error(f"Error in TCP communication: {e}") raise - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - """Register a TCP provider and discover its tools. - - Sends a discovery message to the TCP provider and parses the response. - - Args: - manual_provider: The TCPProvider to register - - Returns: - List of tools discovered from the TCP provider - - Raises: - ValueError: If provider is not a TCPProvider - """ - if not isinstance(manual_provider, TCPProvider): + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + """Register a TCP manual and discover its tools.""" + if not isinstance(manual_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - self._log_info(f"Registering TCP provider '{manual_provider.name}'") + self._log_info(f"Registering TCP provider '{manual_call_template.name}'") try: - # Send discovery message - discovery_message = json.dumps({ - "type": "utcp" - }) - + discovery_message = json.dumps({"type": "utcp"}) response = await self._send_tcp_message( - manual_provider.host, - manual_provider.port, + manual_call_template.host, + manual_call_template.port, discovery_message, - manual_provider, - manual_provider.timeout / 1000.0, # Convert ms to seconds - manual_provider.response_byte_format + manual_call_template, + manual_call_template.timeout / 1000.0, + manual_call_template.response_byte_format ) - - # Parse response try: - # Handle bytes response by trying to decode as UTF-8 for JSON parsing - if isinstance(response, bytes): - response_str = response.decode('utf-8') - else: - response_str = response - + response_str = response.decode('utf-8') if isinstance(response, bytes) else response response_data = json.loads(response_str) - - # Check if response contains tools + tools: List[Tool] = [] if isinstance(response_data, dict) and 'tools' in response_data: tools_data = response_data['tools'] - - # Parse tools - tools = [] for tool_data in tools_data: try: - tool = Tool(**tool_data) - tools.append(tool) + normalized = self._ensure_tool_call_template(tool_data, manual_call_template) + tools.append(Tool(**normalized)) except Exception as e: - self._log_error(f"Invalid tool definition in TCP provider '{manual_provider.name}': {e}") + self._log_error(f"Invalid tool definition in TCP provider '{manual_call_template.name}': {e}") continue - - self._log_info(f"Discovered {len(tools)} tools from TCP provider '{manual_provider.name}'") - return tools + self._log_info(f"Discovered {len(tools)} tools from TCP provider '{manual_call_template.name}'") else: - self._log_info(f"No tools found in TCP provider '{manual_provider.name}' response") - return [] - + self._log_info(f"No tools found in TCP provider '{manual_call_template.name}' response") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[] + ) except json.JSONDecodeError as e: - self._log_error(f"Invalid JSON response from TCP provider '{manual_provider.name}': {e}") - return [] - + self._log_error(f"Invalid JSON response from TCP provider '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]), + success=False, + errors=[str(e)] + ) except Exception as e: - self._log_error(f"Error registering TCP provider '{manual_provider.name}': {e}") - return [] + self._log_error(f"Error registering TCP provider '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]), + success=False, + errors=[str(e)] + ) - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - """Deregister a TCP provider. - - This is a no-op for TCP providers since connections are created per request. - - Args: - manual_provider: The provider to deregister - """ - if not isinstance(manual_provider, TCPProvider): + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + """Deregister a TCP provider (no-op).""" + if not isinstance(manual_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - - self._log_info(f"Deregistering TCP provider '{manual_provider.name}' (no-op)") + self._log_info(f"Deregistering TCP provider '{manual_call_template.name}' (no-op)") - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider) -> Any: - """Call a TCP tool. - - Sends a tool call message to the TCP provider and returns the response. - - Args: - tool_name: Name of the tool to call - tool_args: Arguments for the tool call - tool_provider: The TCPProvider containing the tool - - Returns: - The response from the TCP tool - - Raises: - ValueError: If provider is not a TCPProvider - """ - if not isinstance(tool_provider, TCPProvider): + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + async def _generator(): + yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) + return _generator() + + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + """Call a TCP tool.""" + if not isinstance(tool_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - self._log_info(f"Calling TCP tool '{tool_name}' on provider '{tool_provider.name}'") + self._log_info(f"Calling TCP tool '{tool_name}' on provider '{tool_call_template.name}'") try: - tool_call_message = self._format_tool_call_message(tool_args, tool_provider) + tool_call_message = self._format_tool_call_message(tool_args, tool_call_template) response = await self._send_tcp_message( - tool_provider.host, - tool_provider.port, + tool_call_template.host, + tool_call_template.port, tool_call_message, - tool_provider, - tool_provider.timeout / 1000.0, # Convert ms to seconds - tool_provider.response_byte_format + tool_call_template, + tool_call_template.timeout / 1000.0, + tool_call_template.response_byte_format ) return response diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py index 4c704da..8c30c86 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py @@ -1,6 +1,9 @@ from utcp.data.call_template import CallTemplate from typing import Optional, Literal from pydantic import Field +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback class UDPProvider(CallTemplate): """Provider configuration for UDP (User Datagram Protocol) socket tools. @@ -38,3 +41,16 @@ class UDPProvider(CallTemplate): response_byte_format: Optional[str] = Field(default="utf-8", description="Encoding to decode response bytes. If None, returns raw bytes.") timeout: int = 30000 auth: None = None + + +class UDPProviderSerializer(Serializer[UDPProvider]): + def to_dict(self, obj: UDPProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> UDPProvider: + try: + return UDPProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid UDPProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py index 8d4d404..b59ef37 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py @@ -9,14 +9,17 @@ import traceback from typing import Dict, Any, List, Optional, Callable, Union -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, UDPProvider -from utcp.shared.tool import Tool +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp_socket.udp_call_template import UDPProvider, UDPProviderSerializer +from utcp.data.tool import Tool +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.utcp_manual import UtcpManual import logging logger = logging.getLogger(__name__) -class UDPTransport(ClientTransportInterface): +class UDPTransport(CommunicationProtocol): """Transport implementation for UDP-based tool providers. This transport communicates with tools over UDP sockets. It supports: @@ -80,6 +83,38 @@ def _format_tool_call_message( else: # Default to JSON format return json.dumps(tool_args) + + def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_template: UDPProvider) -> Dict[str, Any]: + """Normalize tool definition to include a valid 'tool_call_template'. + + - If 'tool_call_template' exists, validate it. + - Else if legacy 'tool_provider' exists, convert using UDPProviderSerializer. + - Else default to the provided manual_call_template. + """ + normalized = dict(tool_data) + try: + if "tool_call_template" in normalized and normalized["tool_call_template"] is not None: + # Validate via generic CallTemplate serializer (type-dispatched) + try: + ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore + normalized["tool_call_template"] = ctpl + except Exception: + # Fallback to manual template if validation fails + normalized["tool_call_template"] = manual_call_template + elif "tool_provider" in normalized and normalized["tool_provider"] is not None: + # Convert legacy provider -> call template + try: + ctpl = UDPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = ctpl + except Exception: + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = manual_call_template + else: + normalized["tool_call_template"] = manual_call_template + except Exception: + normalized["tool_call_template"] = manual_call_template + return normalized async def _send_udp_message( self, @@ -202,125 +237,89 @@ def _send_only(): self._log_error(f"Error sending UDP message (no response): {traceback.format_exc()}") raise - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - """Register a UDP provider and discover its tools. - - Sends a discovery message to the UDP provider and parses the response. - - Args: - manual_provider: The UDPProvider to register - - Returns: - List of tools discovered from the UDP provider - - Raises: - ValueError: If provider is not a UDPProvider - """ - if not isinstance(manual_provider, UDPProvider): + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + """Register a UDP manual and discover its tools.""" + if not isinstance(manual_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - self._log_info(f"Registering UDP provider '{manual_provider.name}' at {manual_provider.host}:{manual_provider.port}") + self._log_info(f"Registering UDP provider '{manual_call_template.name}' at {manual_call_template.host}:{manual_call_template.port}") try: - # Send discovery message - discovery_message = json.dumps({ - "type": "utcp" - }) - + discovery_message = json.dumps({"type": "utcp"}) response = await self._send_udp_message( - manual_provider.host, - manual_provider.port, + manual_call_template.host, + manual_call_template.port, discovery_message, - manual_provider.timeout / 1000.0, # Convert ms to seconds - manual_provider.number_of_response_datagrams, - manual_provider.response_byte_format + manual_call_template.timeout / 1000.0, + manual_call_template.number_of_response_datagrams, + manual_call_template.response_byte_format ) - - # Parse response try: - # Handle bytes response by trying to decode as UTF-8 for JSON parsing - if isinstance(response, bytes): - response_str = response.decode('utf-8') - else: - response_str = response - + response_str = response.decode('utf-8') if isinstance(response, bytes) else response response_data = json.loads(response_str) - - # Check if response contains tools + tools: List[Tool] = [] if isinstance(response_data, dict) and 'tools' in response_data: tools_data = response_data['tools'] - - # Parse tools - tools = [] for tool_data in tools_data: try: - tool = Tool(**tool_data) + normalized = self._ensure_tool_call_template(tool_data, manual_call_template) + tool = Tool(**normalized) tools.append(tool) - except Exception as e: - self._log_error(f"Invalid tool definition in UDP provider '{manual_provider.name}': {traceback.format_exc()}") + except Exception: + self._log_error(f"Invalid tool definition in UDP provider '{manual_call_template.name}': {traceback.format_exc()}") continue - - self._log_info(f"Discovered {len(tools)} tools from UDP provider '{manual_provider.name}'") - return tools + self._log_info(f"Discovered {len(tools)} tools from UDP provider '{manual_call_template.name}'") else: - self._log_info(f"No tools found in UDP provider '{manual_provider.name}' response") - return [] - + self._log_info(f"No tools found in UDP provider '{manual_call_template.name}' response") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[] + ) except json.JSONDecodeError as e: - self._log_error(f"Invalid JSON response from UDP provider '{manual_provider.name}': {traceback.format_exc()}") - return [] - + self._log_error(f"Invalid JSON response from UDP provider '{manual_call_template.name}': {traceback.format_exc()}") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=False, + errors=[str(e)] + ) except Exception as e: - self._log_error(f"Error registering UDP provider '{manual_provider.name}': {traceback.format_exc()}") - return [] + self._log_error(f"Error registering UDP provider '{manual_call_template.name}': {traceback.format_exc()}") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=False, + errors=[str(e)] + ) - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - """Deregister a UDP provider. - - This is a no-op for UDP providers since they are stateless. - - Args: - manual_provider: The provider to deregister - """ - if not isinstance(manual_provider, UDPProvider): + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + if not isinstance(manual_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - - self._log_info(f"Deregistering UDP provider '{manual_provider.name}' (no-op)") + self._log_info(f"Deregistering UDP provider '{manual_call_template.name}' (no-op)") - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider) -> Any: - """Call a UDP tool. - - Sends a tool call message to the UDP provider and returns the response. - - Args: - tool_name: Name of the tool to call - arguments: Arguments for the tool call - tool_provider: The UDPProvider containing the tool - - Returns: - The response from the UDP tool - - Raises: - ValueError: If provider is not a UDPProvider - """ - if not isinstance(tool_provider, UDPProvider): + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + if not isinstance(tool_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - - self._log_info(f"Calling UDP tool '{tool_name}' on provider '{tool_provider.name}'") - + self._log_info(f"Calling UDP tool '{tool_name}' on provider '{tool_call_template.name}'") try: - tool_call_message = self._format_tool_call_message(tool_args, tool_provider) - + tool_call_message = self._format_tool_call_message(tool_args, tool_call_template) response = await self._send_udp_message( - tool_provider.host, - tool_provider.port, + tool_call_template.host, + tool_call_template.port, tool_call_message, - tool_provider.timeout / 1000.0, # Convert ms to seconds - tool_provider.number_of_response_datagrams, - tool_provider.response_byte_format + tool_call_template.timeout / 1000.0, + tool_call_template.number_of_response_datagrams, + tool_call_template.response_byte_format ) return response - except Exception as e: self._log_error(f"Error calling UDP tool '{tool_name}': {traceback.format_exc()}") raise + + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) diff --git a/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py new file mode 100644 index 0000000..1b6ffb2 --- /dev/null +++ b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py @@ -0,0 +1,178 @@ +import asyncio +import json +import pytest + +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.tcp_call_template import TCPProvider + + +async def start_tcp_server(): + """Start a simple TCP server that sends a mutable JSON object then closes.""" + response_container = {"bytes": b""} + + async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + try: + # Read any incoming data to simulate request handling + await reader.read(1024) + except Exception: + pass + # Send response and close connection + writer.write(response_container["bytes"]) + await writer.drain() + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + server = await asyncio.start_server(handle, host="127.0.0.1", port=0) + port = server.sockets[0].getsockname()[1] + + def set_response(obj): + response_container["bytes"] = json.dumps(obj).encode("utf-8") + + return server, port, set_response + + +@pytest.mark.asyncio +async def test_register_manual_converts_legacy_tool_provider_tcp(): + """When manual returns legacy tool_provider, it is converted to tool_call_template.""" + # Start server and configure response after obtaining port + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {}, + "tool_provider": { + "call_template_type": "tcp", + "name": "tcp-executor", + "host": "127.0.0.1", + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "stream", + "timeout": 2000 + } + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert result.manual is not None + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_register_manual_validates_provided_tool_call_template_tcp(): + """When manual provides tool_call_template, it is validated and preserved.""" + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {}, + "tool_call_template": { + "call_template_type": "tcp", + "name": "tcp-executor", + "host": "127.0.0.1", + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "stream", + "timeout": 2000 + } + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_register_manual_fallbacks_to_manual_template_tcp(): + """When neither tool_provider nor tool_call_template is provided, fall back to manual template.""" + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {} + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + # Should match manual (discovery) provider values + assert tool.tool_call_template.host == provider.host + assert tool.tool_call_template.port == provider.port + assert tool.tool_call_template.name == provider.name + finally: + server.close() + await server.wait_closed() \ No newline at end of file diff --git a/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py new file mode 100644 index 0000000..d6a770c --- /dev/null +++ b/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py @@ -0,0 +1,176 @@ +import asyncio +import json +import pytest + +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.udp_call_template import UDPProvider + + +async def start_udp_server(): + """Start a simple UDP server that replies with a mutable JSON payload.""" + loop = asyncio.get_running_loop() + response_container = {"bytes": b""} + + class _Protocol(asyncio.DatagramProtocol): + def __init__(self, container): + self.container = container + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + # Always respond with the prepared payload + if self.transport: + self.transport.sendto(self.container["bytes"], addr) + + transport, protocol = await loop.create_datagram_endpoint( + lambda: _Protocol(response_container), local_addr=("127.0.0.1", 0) + ) + port = transport.get_extra_info("socket").getsockname()[1] + + def set_response(obj): + response_container["bytes"] = json.dumps(obj).encode("utf-8") + + return transport, port, set_response + + +@pytest.mark.asyncio +async def test_register_manual_converts_legacy_tool_provider_udp(): + """When manual returns legacy tool_provider, it is converted to tool_call_template.""" + # Start server and configure response after obtaining port + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {}, + "tool_provider": { + "call_template_type": "udp", + "name": "udp-executor", + "host": "127.0.0.1", + "port": port, + "number_of_response_datagrams": 1, + "request_data_format": "json", + "response_byte_format": "utf-8", + "timeout": 2000 + } + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert result.manual is not None + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + transport.close() + + +@pytest.mark.asyncio +async def test_register_manual_validates_provided_tool_call_template_udp(): + """When manual provides tool_call_template, it is validated and preserved.""" + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {}, + "tool_call_template": { + "call_template_type": "udp", + "name": "udp-executor", + "host": "127.0.0.1", + "port": port, + "number_of_response_datagrams": 1, + "request_data_format": "json", + "response_byte_format": "utf-8", + "timeout": 2000 + } + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + transport.close() + + +@pytest.mark.asyncio +async def test_register_manual_fallbacks_to_manual_template_udp(): + """When neither tool_provider nor tool_call_template is provided, fall back to manual template.""" + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {} + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + # Should match manual (discovery) provider values + assert tool.tool_call_template.host == provider.host + assert tool.tool_call_template.port == provider.port + assert tool.tool_call_template.name == provider.name + finally: + transport.close() \ No newline at end of file diff --git a/scripts/socket_sanity.py b/scripts/socket_sanity.py new file mode 100644 index 0000000..40b2c16 --- /dev/null +++ b/scripts/socket_sanity.py @@ -0,0 +1,265 @@ +import sys +import os +import json +import time +import socket +import threading +import asyncio +from pathlib import Path + +# Ensure core and socket plugin sources are on sys.path +ROOT = Path(__file__).resolve().parent.parent +CORE_SRC = ROOT / "core" / "src" +SOCKET_SRC = ROOT / "plugins" / "communication_protocols" / "socket" / "src" +for p in [str(CORE_SRC), str(SOCKET_SRC)]: + if p not in sys.path: + sys.path.insert(0, p) + +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.udp_call_template import UDPProvider +from utcp_socket.tcp_call_template import TCPProvider + +# ------------------------------- +# Mock UDP Server +# ------------------------------- + +def start_udp_server(host: str, port: int): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind((host, port)) + + def run(): + while True: + data, addr = sock.recvfrom(65535) + try: + msg = data.decode("utf-8") + except Exception: + msg = "" + # Handle discovery + try: + parsed = json.loads(msg) + except Exception: + parsed = None + if isinstance(parsed, dict) and parsed.get("type") == "utcp": + manual = { + "utcp_version": "1.0", + "manual_version": "1.0", + "tools": [ + { + "name": "udp.echo", + "description": "Echo UDP args as JSON", + "inputs": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "extra": {"type": "number"} + }, + "required": ["text"] + }, + "outputs": { + "type": "object", + "properties": { + "ok": {"type": "boolean"}, + "echo": {"type": "string"}, + "args": {"type": "object"} + } + }, + "tags": ["socket", "udp"], + "average_response_size": 64, + # Return legacy provider to exercise conversion path + "tool_provider": { + "call_template_type": "udp", + "name": "udp", + "host": host, + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "number_of_response_datagrams": 1, + "timeout": 3000 + } + } + ] + } + payload = json.dumps(manual).encode("utf-8") + sock.sendto(payload, addr) + else: + # Tool call: echo JSON payload + try: + args = json.loads(msg) + except Exception: + args = {"raw": msg} + resp = { + "ok": True, + "echo": args.get("text", ""), + "args": args + } + sock.sendto(json.dumps(resp).encode("utf-8"), addr) + t = threading.Thread(target=run, daemon=True) + t.start() + return t + +# ------------------------------- +# Mock TCP Server (delimiter-based) +# ------------------------------- + +def start_tcp_server(host: str, port: int, delimiter: str = "\n"): + delim_bytes = delimiter.encode("utf-8") + + def handle_client(conn: socket.socket, addr): + try: + # Read until delimiter + buf = b"" + while True: + chunk = conn.recv(1) + if not chunk: + break + buf += chunk + if buf.endswith(delim_bytes): + break + msg = buf[:-len(delim_bytes)].decode("utf-8") if buf.endswith(delim_bytes) else buf.decode("utf-8") + # Discovery + parsed = None + try: + parsed = json.loads(msg) + except Exception: + pass + if isinstance(parsed, dict) and parsed.get("type") == "utcp": + manual = { + "utcp_version": "1.0", + "manual_version": "1.0", + "tools": [ + { + "name": "tcp.echo", + "description": "Echo TCP args as JSON", + "inputs": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "extra": {"type": "number"} + }, + "required": ["text"] + }, + "outputs": { + "type": "object", + "properties": { + "ok": {"type": "boolean"}, + "echo": {"type": "string"}, + "args": {"type": "object"} + } + }, + "tags": ["socket", "tcp"], + "average_response_size": 64, + # Legacy provider to exercise conversion + "tool_provider": { + "call_template_type": "tcp", + "name": "tcp", + "host": host, + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "delimiter", + "message_delimiter": "\\n", + "timeout": 3000 + } + } + ] + } + payload = json.dumps(manual).encode("utf-8") + delim_bytes + conn.sendall(payload) + else: + # Tool call: echo JSON payload + try: + args = json.loads(msg) + except Exception: + args = {"raw": msg} + resp = { + "ok": True, + "echo": args.get("text", ""), + "args": args + } + conn.sendall(json.dumps(resp).encode("utf-8") + delim_bytes) + finally: + try: + conn.shutdown(socket.SHUT_RDWR) + except Exception: + pass + conn.close() + + def run(): + srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + srv.bind((host, port)) + srv.listen(5) + while True: + conn, addr = srv.accept() + threading.Thread(target=handle_client, args=(conn, addr), daemon=True).start() + + t = threading.Thread(target=run, daemon=True) + t.start() + return t + +# ------------------------------- +# Sanity test runner +# ------------------------------- + +async def run_sanity(): + udp_host, udp_port = "127.0.0.1", 23456 + tcp_host, tcp_port = "127.0.0.1", 23457 + + # Start servers + start_udp_server(udp_host, udp_port) + start_tcp_server(tcp_host, tcp_port, delimiter="\n") + await asyncio.sleep(0.2) # small delay to ensure servers are listening + + # Transports + udp_transport = UDPTransport() + tcp_transport = TCPTransport() + + # Register manuals + udp_manual_template = UDPProvider(name="udp", host=udp_host, port=udp_port, request_data_format="json", response_byte_format="utf-8", number_of_response_datagrams=1, timeout=3000) + tcp_manual_template = TCPProvider(name="tcp", host=tcp_host, port=tcp_port, request_data_format="json", response_byte_format="utf-8", framing_strategy="delimiter", message_delimiter="\n", timeout=3000) + + udp_reg = await udp_transport.register_manual(None, udp_manual_template) + tcp_reg = await tcp_transport.register_manual(None, tcp_manual_template) + + print("UDP register success:", udp_reg.success, "tools:", len(udp_reg.manual.tools)) + print("TCP register success:", tcp_reg.success, "tools:", len(tcp_reg.manual.tools)) + + assert udp_reg.success and len(udp_reg.manual.tools) == 1 + assert tcp_reg.success and len(tcp_reg.manual.tools) == 1 + + # Verify tool_call_template present + assert udp_reg.manual.tools[0].tool_call_template.call_template_type == "udp" + assert tcp_reg.manual.tools[0].tool_call_template.call_template_type == "tcp" + + # Call tools + udp_result = await udp_transport.call_tool(None, "udp.echo", {"text": "hello", "extra": 42}, udp_reg.manual.tools[0].tool_call_template) + tcp_result = await tcp_transport.call_tool(None, "tcp.echo", {"text": "world", "extra": 99}, tcp_reg.manual.tools[0].tool_call_template) + + print("UDP call result:", udp_result) + print("TCP call result:", tcp_result) + + # Basic assertions on response shape + def ensure_dict(s): + if isinstance(s, (bytes, bytearray)): + try: + s = s.decode("utf-8") + except Exception: + return {} + if isinstance(s, str): + try: + return json.loads(s) + except Exception: + return {"raw": s} + return s if isinstance(s, dict) else {} + + udp_resp = ensure_dict(udp_result) + tcp_resp = ensure_dict(tcp_result) + + assert udp_resp.get("ok") is True and udp_resp.get("echo") == "hello" + assert tcp_resp.get("ok") is True and tcp_resp.get("echo") == "world" + + print("Sanity passed: UDP/TCP discovery and calls work with tool_call_template normalization.") + +if __name__ == "__main__": + asyncio.run(run_sanity()) \ No newline at end of file diff --git a/socket_plugin_test.py b/socket_plugin_test.py new file mode 100644 index 0000000..c03d2a9 --- /dev/null +++ b/socket_plugin_test.py @@ -0,0 +1,40 @@ +import asyncio +import sys +from pathlib import Path + +# Add core and plugin src paths so imports work without installing packages +core_src = Path(__file__).parent / "core" / "src" +socket_src = Path(__file__).parent / "plugins" / "communication_protocols" / "socket" / "src" +sys.path.insert(0, str(core_src.resolve())) +sys.path.insert(0, str(socket_src.resolve())) + +from utcp.plugins.plugin_loader import ensure_plugins_initialized +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplateSerializer +from utcp_socket import register as register_socket + +async def main(): + # Manually register the socket plugin + register_socket() + + # Load core plugins (auth, repo, search, post-processors) + ensure_plugins_initialized() + + # 1. Check if communication protocols are registered + registered_protocols = CommunicationProtocol.communication_protocols + print(f"Registered communication protocols: {list(registered_protocols.keys())}") + assert "tcp" in registered_protocols, "TCP communication protocol not registered" + assert "udp" in registered_protocols, "UDP communication protocol not registered" + print("āœ… TCP and UDP communication protocols are registered.") + + # 2. Check if call templates are registered + registered_serializers = CallTemplateSerializer.call_template_serializers + print(f"Registered call template serializers: {list(registered_serializers.keys())}") + assert "tcp" in registered_serializers, "TCP call template serializer not registered" + assert "udp" in registered_serializers, "UDP call template serializer not registered" + print("āœ… TCP and UDP call template serializers are registered.") + + print("\nšŸŽ‰ Socket plugin sanity check passed! šŸŽ‰") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file From dd2eb3c1085f4f1ad885f94583305e5988d0bd15 Mon Sep 17 00:00:00 2001 From: Thuraabtech <97426541+Thuraabtech@users.noreply.github.com> Date: Sat, 29 Nov 2025 08:25:43 -0600 Subject: [PATCH 2/4] GraphQL Plugin: UTCP 1.0 Migration (#75) * socket protocol updated to be compatible with 1.0v utcp * cubic fixes done * pinned mcp-use to use langchain 0.3.27 * removed mcp denpendency on langchain * adding the langchain dependency for testing (temporary) * remove langchain-core pin to resolve dependency conflict * feat: Updated Graphql implementation to be compatible with UTCP 1.0v * Added gql 'how to use' guide in the README.md * updated cubic comments for GraphQl * Update comment on delimeter handling Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Razvan Radulescu <43811028+h3xxit@users.noreply.github.com> Co-authored-by: Salman Mohammed Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- plugins/communication_protocols/gql/README.md | 48 +++- .../gql/src/utcp_gql/__init__.py | 9 + .../gql/src/utcp_gql/gql_call_template.py | 35 ++- .../utcp_gql/gql_communication_protocol.py | 257 ++++++++++++------ .../gql/tests/test_graphql_protocol.py | 110 ++++++++ .../mcp/pyproject.toml | 2 +- .../src/utcp_socket/tcp_call_template.py | 4 + .../utcp_socket/tcp_communication_protocol.py | 35 ++- .../utcp_socket/udp_communication_protocol.py | 18 +- .../tests/test_tcp_communication_protocol.py | 2 + scripts/socket_sanity.py | 6 +- 11 files changed, 421 insertions(+), 105 deletions(-) create mode 100644 plugins/communication_protocols/gql/tests/test_graphql_protocol.py diff --git a/plugins/communication_protocols/gql/README.md b/plugins/communication_protocols/gql/README.md index 8febb5a..34a2518 100644 --- a/plugins/communication_protocols/gql/README.md +++ b/plugins/communication_protocols/gql/README.md @@ -1 +1,47 @@ -Find the UTCP readme at https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file + +# UTCP GraphQL Communication Protocol Plugin + +This plugin integrates GraphQL as a UTCP 1.0 communication protocol and call template. It supports discovery via schema introspection, authenticated calls, and header handling. + +## Getting Started + +### Installation + +```bash +pip install gql +``` + +### Registration + +```python +import utcp_gql +utcp_gql.register() +``` + +## How To Use + +- Ensure the plugin is imported and registered: `import utcp_gql; utcp_gql.register()`. +- Add a manual in your client config: + ```json + { + "name": "my_graph", + "call_template_type": "graphql", + "url": "https://your.graphql/endpoint", + "operation_type": "query", + "headers": { "x-client": "utcp" }, + "header_fields": ["x-session-id"] + } + ``` +- Call a tool: + ```python + await client.call_tool("my_graph.someQuery", {"id": "123", "x-session-id": "abc"}) + ``` + +## Notes + +- Tool names are prefixed by the manual name (e.g., `my_graph.someQuery`). +- Headers merge static `headers` plus whitelisted dynamic fields from `header_fields`. +- Supported auth: API key, Basic auth, OAuth2 (client-credentials). +- Security: only `https://` or `http://localhost`/`http://127.0.0.1` endpoints. + +For UTCP core docs, see https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py index e69de29..7362502 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py @@ -0,0 +1,9 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template + +from .gql_communication_protocol import GraphQLCommunicationProtocol +from .gql_call_template import GraphQLProvider, GraphQLProviderSerializer + + +def register(): + register_communication_protocol("graphql", GraphQLCommunicationProtocol()) + register_call_template("graphql", GraphQLProviderSerializer()) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py index dfe5b07..3848d29 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py @@ -1,7 +1,10 @@ from utcp.data.call_template import CallTemplate -from utcp.data.auth import Auth +from utcp.data.auth import Auth, AuthSerializer +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback from typing import Dict, List, Optional, Literal -from pydantic import Field +from pydantic import Field, field_serializer, field_validator class GraphQLProvider(CallTemplate): """Provider configuration for GraphQL-based tools. @@ -27,3 +30,31 @@ class GraphQLProvider(CallTemplate): auth: Optional[Auth] = None headers: Optional[Dict[str, str]] = None header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers for the initial connection.") + + @field_serializer("auth") + def serialize_auth(self, auth: Optional[Auth]): + if auth is None: + return None + return AuthSerializer().to_dict(auth) + + @field_validator("auth", mode="before") + @classmethod + def validate_auth(cls, v: Optional[Auth | dict]): + if v is None: + return None + if isinstance(v, Auth): + return v + return AuthSerializer().validate_dict(v) + + +class GraphQLProviderSerializer(Serializer[GraphQLProvider]): + def to_dict(self, obj: GraphQLProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> GraphQLProvider: + try: + return GraphQLProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid GraphQLProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py index f27f803..9d26cab 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py @@ -1,36 +1,55 @@ -import sys -from typing import Dict, Any, List, Optional, Callable +import logging +from typing import Dict, Any, List, Optional, AsyncGenerator, TYPE_CHECKING + import aiohttp -import asyncio -import ssl from gql import Client as GqlClient, gql as gql_query from gql.transport.aiohttp import AIOHTTPTransport -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, GraphQLProvider -from utcp.shared.tool import Tool, ToolInputOutputSchema -from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth -import logging + +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplate +from utcp.data.tool import Tool, JsonSchema +from utcp.data.utcp_manual import UtcpManual +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth +from utcp.data.auth_implementations.basic_auth import BasicAuth +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth + +from utcp_gql.gql_call_template import GraphQLProvider + +if TYPE_CHECKING: + from utcp.utcp_client import UtcpClient + logging.basicConfig( level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s" + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s", ) logger = logging.getLogger(__name__) -class GraphQLClientTransport(ClientTransportInterface): - """ - Simple, robust, production-ready GraphQL transport using gql. - Stateless, per-operation. Supports all GraphQL features. + +class GraphQLCommunicationProtocol(CommunicationProtocol): + """GraphQL protocol implementation for UTCP 1.0. + + - Discovers tools via GraphQL schema introspection. + - Executes per-call sessions using `gql` over HTTP(S). + - Supports `ApiKeyAuth`, `BasicAuth`, and `OAuth2Auth`. + - Enforces HTTPS or localhost for security. """ - def __init__(self): + + def __init__(self) -> None: self._oauth_tokens: Dict[str, Dict[str, Any]] = {} - def _enforce_https_or_localhost(self, url: str): - if not (url.startswith("https://") or url.startswith("http://localhost") or url.startswith("http://127.0.0.1")): + def _enforce_https_or_localhost(self, url: str) -> None: + if not ( + url.startswith("https://") + or url.startswith("http://localhost") + or url.startswith("http://127.0.0.1") + ): raise ValueError( - f"Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. Got: {url}. " - "Non-secure URLs are vulnerable to man-in-the-middle attacks." + "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. " + "Non-secure URLs are vulnerable to man-in-the-middle attacks. " + f"Got: {url}." ) async def _handle_oauth2(self, auth: OAuth2Auth) -> str: @@ -39,10 +58,10 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: return self._oauth_tokens[client_id]["access_token"] async with aiohttp.ClientSession() as session: data = { - 'grant_type': 'client_credentials', - 'client_id': client_id, - 'client_secret': auth.client_secret, - 'scope': auth.scope + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": auth.client_secret, + "scope": auth.scope, } async with session.post(auth.token_url, data=data) as resp: resp.raise_for_status() @@ -50,87 +69,147 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: self._oauth_tokens[client_id] = token_response return token_response["access_token"] - async def _prepare_headers(self, provider: GraphQLProvider) -> Dict[str, str]: - headers = provider.headers.copy() if provider.headers else {} + async def _prepare_headers( + self, provider: GraphQLProvider, tool_args: Optional[Dict[str, Any]] = None + ) -> Dict[str, str]: + headers: Dict[str, str] = provider.headers.copy() if provider.headers else {} if provider.auth: if isinstance(provider.auth, ApiKeyAuth): - if provider.auth.api_key: - if provider.auth.location == "header": - headers[provider.auth.var_name] = provider.auth.api_key - # (query/cookie not supported for GraphQL by default) + if provider.auth.api_key and provider.auth.location == "header": + headers[provider.auth.var_name] = provider.auth.api_key elif isinstance(provider.auth, BasicAuth): import base64 + userpass = f"{provider.auth.username}:{provider.auth.password}" headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode() elif isinstance(provider.auth, OAuth2Auth): token = await self._handle_oauth2(provider.auth) headers["Authorization"] = f"Bearer {token}" + + # Map selected tool_args into headers if requested + if tool_args and provider.header_fields: + for field in provider.header_fields: + if field in tool_args and isinstance(tool_args[field], str): + headers[field] = tool_args[field] + return headers - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - if not isinstance(manual_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(manual_provider.url) - headers = await self._prepare_headers(manual_provider) - transport = AIOHTTPTransport(url=manual_provider.url, headers=headers) - async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - schema = session.client.schema - tools = [] - # Queries - if hasattr(schema, 'query_type') and schema.query_type: - for name, field in schema.query_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Mutations - if hasattr(schema, 'mutation_type') and schema.mutation_type: - for name, field in schema.mutation_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Subscriptions (listed, but not called here) - if hasattr(schema, 'subscription_type') and schema.subscription_type: - for name, field in schema.subscription_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - return tools - - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - # Stateless: nothing to do - pass - - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider, query: Optional[str] = None) -> Any: - if not isinstance(tool_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(tool_provider.url) - headers = await self._prepare_headers(tool_provider) - transport = AIOHTTPTransport(url=tool_provider.url, headers=headers) + async def register_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> RegisterManualResult: + if not isinstance(manual_call_template, GraphQLProvider): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + self._enforce_https_or_localhost(manual_call_template.url) + + try: + headers = await self._prepare_headers(manual_call_template) + transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers) + async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: + schema = session.client.schema + tools: List[Tool] = [] + + # Queries + if hasattr(schema, "query_type") and schema.query_type: + for name, field in schema.query_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Mutations + if hasattr(schema, "mutation_type") and schema.mutation_type: + for name, field in schema.mutation_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Subscriptions (listed for completeness) + if hasattr(schema, "subscription_type") and schema.subscription_type: + for name, field in schema.subscription_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + manual = UtcpManual(tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[], + ) + except Exception as e: + logger.error(f"GraphQL manual registration failed for '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(manual_version="0.0.0", tools=[]), + success=False, + errors=[str(e)], + ) + + async def deregister_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> None: + # Stateless: nothing to clean up + return None + + async def call_tool( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> Any: + if not isinstance(tool_call_template, GraphQLProvider): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + self._enforce_https_or_localhost(tool_call_template.url) + + headers = await self._prepare_headers(tool_call_template, tool_args) + transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - if query is not None: - document = gql_query(query) - result = await session.execute(document, variable_values=tool_args) - return result - # If no query provided, build a simple query - # Default to query operation - op_type = getattr(tool_provider, 'operation_type', 'query') - arg_str = ', '.join(f"${k}: String" for k in tool_args.keys()) + op_type = getattr(tool_call_template, "operation_type", "query") + # Strip manual prefix if present (client prefixes at save time) + base_tool_name = tool_name.split(".", 1)[-1] if "." in tool_name else tool_name + # Filter out header fields from GraphQL variables; these are sent via HTTP headers + header_fields = tool_call_template.header_fields or [] + filtered_args = {k: v for k, v in tool_args.items() if k not in header_fields} + + arg_str = ", ".join(f"${k}: String" for k in filtered_args.keys()) var_defs = f"({arg_str})" if arg_str else "" - arg_pass = ', '.join(f"{k}: ${k}" for k in tool_args.keys()) + arg_pass = ", ".join(f"{k}: ${k}" for k in filtered_args.keys()) arg_pass = f"({arg_pass})" if arg_pass else "" - gql_str = f"{op_type} {var_defs} {{ {tool_name}{arg_pass} }}" + + gql_str = f"{op_type} {var_defs} {{ {base_tool_name}{arg_pass} }}" document = gql_query(gql_str) - result = await session.execute(document, variable_values=tool_args) + result = await session.execute(document, variable_values=filtered_args) return result + async def call_tool_streaming( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> AsyncGenerator[Any, None]: + # Basic implementation: execute non-streaming and yield once + result = await self.call_tool(caller, tool_name, tool_args, tool_call_template) + yield result + async def close(self) -> None: - self._oauth_tokens.clear() + self._oauth_tokens.clear() \ No newline at end of file diff --git a/plugins/communication_protocols/gql/tests/test_graphql_protocol.py b/plugins/communication_protocols/gql/tests/test_graphql_protocol.py new file mode 100644 index 0000000..1b1bb74 --- /dev/null +++ b/plugins/communication_protocols/gql/tests/test_graphql_protocol.py @@ -0,0 +1,110 @@ +import os +import sys +import types +import pytest + + +# Ensure plugin src is importable +PLUGIN_SRC = os.path.join(os.path.dirname(__file__), "..", "src") +PLUGIN_SRC = os.path.abspath(PLUGIN_SRC) +if PLUGIN_SRC not in sys.path: + sys.path.append(PLUGIN_SRC) + +import utcp_gql +# Simplify imports: use the main module and assign local aliases +GraphQLProvider = utcp_gql.gql_call_template.GraphQLProvider +gql_module = utcp_gql.gql_communication_protocol + +from utcp.data.utcp_manual import UtcpManual +from utcp.utcp_client import UtcpClient +from utcp.implementations.utcp_client_implementation import UtcpClientImplementation + + +class FakeSchema: + def __init__(self): + # Minimal field objects with descriptions + self.query_type = types.SimpleNamespace( + fields={ + "hello": types.SimpleNamespace(description="Returns greeting"), + } + ) + self.mutation_type = types.SimpleNamespace( + fields={ + "add": types.SimpleNamespace(description="Adds numbers"), + } + ) + self.subscription_type = None + + +class FakeClientObj: + def __init__(self): + self.client = types.SimpleNamespace(schema=FakeSchema()) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def execute(self, document, variable_values=None): + # document is a gql query; we can base behavior on variable_values + variable_values = variable_values or {} + # Determine operation by presence of variables used + if "hello" in str(document): + name = variable_values.get("name", "") + return {"hello": f"Hello {name}"} + if "add" in str(document): + a = int(variable_values.get("a", 0)) + b = int(variable_values.get("b", 0)) + return {"add": a + b} + return {"ok": True} + + +class FakeTransport: + def __init__(self, url: str, headers: dict | None = None): + self.url = url + self.headers = headers or {} + + +@pytest.mark.asyncio +async def test_graphql_register_and_call(monkeypatch): + # Patch gql client/transport used by protocol to avoid needing a real server + monkeypatch.setattr(gql_module, "GqlClient", lambda *args, **kwargs: FakeClientObj()) + monkeypatch.setattr(gql_module, "AIOHTTPTransport", FakeTransport) + # Avoid real GraphQL parsing; pass-through document string to fake execute + monkeypatch.setattr(gql_module, "gql_query", lambda s: s) + + # Register plugin (call_template serializer + protocol) + utcp_gql.register() + + # Create protocol and manual call template + protocol = gql_module.GraphQLCommunicationProtocol() + provider = GraphQLProvider( + name="mock_graph", + call_template_type="graphql", + url="http://localhost/graphql", + operation_type="query", + headers={"x-client": "utcp"}, + header_fields=["x-session-id"], + ) + + # Minimal UTCP client implementation for caller context + client: UtcpClient = await UtcpClientImplementation.create() + client.config.variables = {} + + # Register and discover tools + reg = await protocol.register_manual(client, provider) + assert reg.success is True + assert isinstance(reg.manual, UtcpManual) + tool_names = sorted(t.name for t in reg.manual.tools) + assert "hello" in tool_names + assert "add" in tool_names + + # Call hello + res = await protocol.call_tool(client, "mock_graph.hello", {"name": "UTCP", "x-session-id": "abc"}, provider) + assert res == {"hello": "Hello UTCP"} + + # Call add (mutation) + provider.operation_type = "mutation" + res2 = await protocol.call_tool(client, "mock_graph.add", {"a": 2, "b": 3}, provider) + assert res2 == {"add": 5} \ No newline at end of file diff --git a/plugins/communication_protocols/mcp/pyproject.toml b/plugins/communication_protocols/mcp/pyproject.toml index 9232cd5..2efd4c3 100644 --- a/plugins/communication_protocols/mcp/pyproject.toml +++ b/plugins/communication_protocols/mcp/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "mcp>=1.12", "utcp>=1.0", "mcp-use>=1.3", - "langchain==0.3.27", + "langchain>=0.3.27,<0.4.0", ] classifiers = [ "Development Status :: 4 - Beta", diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py index 10fc1d6..8b27d1c 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py @@ -68,6 +68,10 @@ class TCPProvider(CallTemplate): default='\x00', description="Delimiter to detect end of TCP response (e.g., '\n', '\r\n', '\x00'). Used with 'delimiter' framing." ) + interpret_escape_sequences: bool = Field( + default=True, + description="If True, interpret Python-style escape sequences in message_delimiter (e.g., '\\n', '\\r\\n', '\\x00'). If False, use the delimiter literally as provided." + ) # Fixed-length framing options fixed_message_length: Optional[int] = Field( default=None, diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py index d5d64ac..b2f08c3 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py @@ -148,9 +148,14 @@ def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> b elif provider.framing_strategy == "delimiter": # Add delimiter after the message delimiter = provider.message_delimiter or "\x00" - # Handle escape sequences - delimiter = delimiter.encode('utf-8').decode('unicode_escape') - return message_bytes + delimiter.encode('utf-8') + if provider.interpret_escape_sequences: + # Handle escape sequences (e.g., "\n", "\r\n", "\x00") + delimiter = delimiter.encode('utf-8').decode('unicode_escape') + delimiter_bytes = delimiter.encode('utf-8') + else: + # Use delimiter literally as provided + delimiter_bytes = delimiter.encode('utf-8') + return message_bytes + delimiter_bytes elif provider.framing_strategy in ("fixed_length", "stream"): # No additional framing needed @@ -202,8 +207,19 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid elif provider.framing_strategy == "delimiter": # Read until delimiter is found + # Delimiter handling: + # The code supports both literal delimiters (e.g., "\\x00") and escape-sequence interpreted delimiters (e.g., "\x00") + # via the `interpret_escape_sequences` flag in TCPProvider. This ensures compatibility with both legacy and updated + # wire protocols. The delimiter is interpreted according to the flag, so no breaking change occurs unless the flag + # is set differently than expected by the server/client. + # Example: + # If interpret_escape_sequences is True, "\\x00" becomes a null byte; if False, it remains four literal bytes. + # delimiter = delimiter.encode('utf-8') delimiter = provider.message_delimiter or "\x00" - delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') + if provider.interpret_escape_sequences: + delimiter_bytes = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') + else: + delimiter_bytes = delimiter.encode('utf-8') response_data = b"" while True: @@ -213,9 +229,9 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid response_data += chunk # Check if we've received the delimiter - if response_data.endswith(delimiter): + if response_data.endswith(delimiter_bytes): # Remove delimiter from response - return response_data[:-len(delimiter)] + return response_data[:-len(delimiter_bytes)] elif provider.framing_strategy == "fixed_length": # Read exactly fixed_message_length bytes @@ -246,6 +262,13 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid break return response_data + + else: + # Copilot AI (5 days ago): + # The else branch for unknown framing strategies was previously removed, + # which could cause silent fallthrough and confusing behavior. Add explicit + # validation to raise a descriptive error when an unsupported strategy is provided. + raise ValueError(f"Unknown framing strategy: {provider.framing_strategy!r}") async def _send_tcp_message( self, diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py index b59ef37..89ae3e3 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py @@ -15,6 +15,7 @@ from utcp.data.call_template import CallTemplate, CallTemplateSerializer from utcp.data.register_manual_response import RegisterManualResult from utcp.data.utcp_manual import UtcpManual +from utcp.exceptions import UtcpSerializerValidationError import logging logger = logging.getLogger(__name__) @@ -98,8 +99,9 @@ def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_temp try: ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore normalized["tool_call_template"] = ctpl - except Exception: - # Fallback to manual template if validation fails + except (UtcpSerializerValidationError, ValueError) as e: + # Fallback to manual template if validation fails, but log details + logger.exception("Failed to validate existing tool_call_template; falling back to manual template") normalized["tool_call_template"] = manual_call_template elif "tool_provider" in normalized and normalized["tool_provider"] is not None: # Convert legacy provider -> call template @@ -107,12 +109,15 @@ def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_temp ctpl = UDPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore normalized.pop("tool_provider", None) normalized["tool_call_template"] = ctpl - except Exception: + except UtcpSerializerValidationError as e: + logger.exception("Failed to convert legacy tool_provider to call template; falling back to manual template") normalized.pop("tool_provider", None) normalized["tool_call_template"] = manual_call_template else: normalized["tool_call_template"] = manual_call_template except Exception: + # Any unexpected error during normalization should be logged + logger.exception("Unexpected error normalizing tool definition; falling back to manual template") normalized["tool_call_template"] = manual_call_template return normalized @@ -321,5 +326,12 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too self._log_error(f"Error calling UDP tool '{tool_name}': {traceback.format_exc()}") raise + # Copilot AI (5 days ago): + # The call_tool_streaming method wraps a generator function but doesn't use the async def syntax for the method itself. + # While this works, it's inconsistent with the other implementation in tcp_communication_protocol.py (lines 384-387) which properly uses async def with an inner generator. + # For consistency and clarity, this should also use async def directly: + # + # async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + # yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) diff --git a/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py index 1b6ffb2..d359fd9 100644 --- a/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py +++ b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py @@ -15,6 +15,7 @@ async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): # Read any incoming data to simulate request handling await reader.read(1024) except Exception: + # Ignore exceptions during read (e.g., client disconnects), as this is a test server. pass # Send response and close connection writer.write(response_container["bytes"]) @@ -23,6 +24,7 @@ async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): writer.close() await writer.wait_closed() except Exception: + # Ignore exceptions during writer close; connection may already be closed or in error state. pass server = await asyncio.start_server(handle, host="127.0.0.1", port=0) diff --git a/scripts/socket_sanity.py b/scripts/socket_sanity.py index 40b2c16..5ac6028 100644 --- a/scripts/socket_sanity.py +++ b/scripts/socket_sanity.py @@ -1,7 +1,5 @@ import sys -import os import json -import time import socket import threading import asyncio @@ -39,6 +37,7 @@ def run(): try: parsed = json.loads(msg) except Exception: + # Ignore JSON parsing errors; non-JSON input will be handled below parsed = None if isinstance(parsed, dict) and parsed.get("type") == "utcp": manual = { @@ -182,6 +181,7 @@ def handle_client(conn: socket.socket, addr): try: conn.shutdown(socket.SHUT_RDWR) except Exception: + # Ignore errors if socket is already closed or shutdown fails pass conn.close() @@ -259,7 +259,7 @@ def ensure_dict(s): assert udp_resp.get("ok") is True and udp_resp.get("echo") == "hello" assert tcp_resp.get("ok") is True and tcp_resp.get("echo") == "world" - print("Sanity passed: UDP/TCP discovery and calls work with tool_call_template normalization.") + print("Sanity check passed: UDP/TCP discovery and calls work with tool_call_template normalization.") if __name__ == "__main__": asyncio.run(run_sanity()) \ No newline at end of file From 7e1d47da1fa42b8b424ab66ceeec5d162b1d0602 Mon Sep 17 00:00:00 2001 From: Razvan Radulescu <43811028+h3xxit@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:09:22 +0100 Subject: [PATCH 3/4] update --- .../gql/old_tests/test_graphql_transport.py | 129 -------- .../gql/src/utcp_gql/__init__.py | 4 +- .../gql/src/utcp_gql/gql_call_template.py | 47 ++- .../utcp_gql/gql_communication_protocol.py | 64 ++-- .../gql/tests/test_graphql_integration.py | 275 ++++++++++++++++++ .../gql/tests/test_graphql_protocol.py | 110 ------- .../communication_protocols/socket/README.md | 4 +- .../websocket/tests/__init__.py | 1 - socket_plugin_test.py | 40 --- test_websocket_manual.py | 201 ------------- 10 files changed, 358 insertions(+), 517 deletions(-) delete mode 100644 plugins/communication_protocols/gql/old_tests/test_graphql_transport.py create mode 100644 plugins/communication_protocols/gql/tests/test_graphql_integration.py delete mode 100644 plugins/communication_protocols/gql/tests/test_graphql_protocol.py delete mode 100644 plugins/communication_protocols/websocket/tests/__init__.py delete mode 100644 socket_plugin_test.py delete mode 100644 test_websocket_manual.py diff --git a/plugins/communication_protocols/gql/old_tests/test_graphql_transport.py b/plugins/communication_protocols/gql/old_tests/test_graphql_transport.py deleted file mode 100644 index d33c323..0000000 --- a/plugins/communication_protocols/gql/old_tests/test_graphql_transport.py +++ /dev/null @@ -1,129 +0,0 @@ -# import pytest -# import pytest_asyncio -# import json -# from aiohttp import web -# from utcp.client.transport_interfaces.graphql_transport import GraphQLClientTransport -# from utcp.shared.provider import GraphQLProvider -# from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth - - -# @pytest_asyncio.fixture -# async def graphql_app(): -# async def graphql_handler(request): -# body = await request.json() -# query = body.get("query", "") -# variables = body.get("variables", {}) -# # Introspection query (minimal response) -# if "__schema" in query: -# return web.json_response({ -# "data": { -# "__schema": { -# "queryType": {"name": "Query"}, -# "mutationType": {"name": "Mutation"}, -# "subscriptionType": None, -# "types": [ -# {"kind": "OBJECT", "name": "Query", "fields": [ -# {"name": "hello", "args": [{"name": "name", "type": {"kind": "SCALAR", "name": "String"}, "defaultValue": None}], "type": {"kind": "SCALAR", "name": "String"}, "isDeprecated": False, "deprecationReason": None} -# ], "interfaces": []}, -# {"kind": "OBJECT", "name": "Mutation", "fields": [ -# {"name": "add", "args": [ -# {"name": "a", "type": {"kind": "SCALAR", "name": "Int"}, "defaultValue": None}, -# {"name": "b", "type": {"kind": "SCALAR", "name": "Int"}, "defaultValue": None} -# ], "type": {"kind": "SCALAR", "name": "Int"}, "isDeprecated": False, "deprecationReason": None} -# ], "interfaces": []}, -# {"kind": "SCALAR", "name": "String"}, -# {"kind": "SCALAR", "name": "Int"}, -# {"kind": "SCALAR", "name": "Boolean"} -# ], -# "directives": [] -# } -# } -# }) -# # hello query -# if "hello" in query: -# name = variables.get("name", "world") -# return web.json_response({"data": {"hello": f"Hello, {name}!"}}) -# # add mutation -# if "add" in query: -# a = variables.get("a", 0) -# b = variables.get("b", 0) -# return web.json_response({"data": {"add": a + b}}) -# # fallback -# return web.json_response({"data": {}}, status=200) - -# app = web.Application() -# app.router.add_post("/graphql", graphql_handler) -# return app - -# @pytest_asyncio.fixture -# async def aiohttp_graphql_client(aiohttp_client, graphql_app): -# return await aiohttp_client(graphql_app) - -# @pytest_asyncio.fixture -# def transport(): -# return GraphQLClientTransport() - -# @pytest_asyncio.fixture -# def provider(aiohttp_graphql_client): -# return GraphQLProvider( -# name="test-graphql-provider", -# url=str(aiohttp_graphql_client.make_url("/graphql")), -# headers={}, -# ) - -# @pytest.mark.asyncio -# async def test_register_tool_provider_discovers_tools(transport, provider): -# tools = await transport.register_tool_provider(provider) -# tool_names = [tool.name for tool in tools] -# assert "hello" in tool_names -# assert "add" in tool_names - -# @pytest.mark.asyncio -# async def test_call_tool_query(transport, provider): -# result = await transport.call_tool("hello", {"name": "Alice"}, provider) -# assert result["hello"] == "Hello, Alice!" - -# @pytest.mark.asyncio -# async def test_call_tool_mutation(transport, provider): -# provider.operation_type = "mutation" -# mutation = ''' -# mutation ($a: Int, $b: Int) { -# add(a: $a, b: $b) -# } -# ''' -# result = await transport.call_tool("add", {"a": 2, "b": 3}, provider, query=mutation) -# assert result["add"] == 5 - -# @pytest.mark.asyncio -# async def test_call_tool_api_key(transport, provider): -# provider.headers = {} -# provider.auth = ApiKeyAuth(var_name="X-API-Key", api_key="test-key") -# result = await transport.call_tool("hello", {"name": "Bob"}, provider) -# assert result["hello"] == "Hello, Bob!" - -# @pytest.mark.asyncio -# async def test_call_tool_basic_auth(transport, provider): -# provider.headers = {} -# provider.auth = BasicAuth(username="user", password="pass") -# result = await transport.call_tool("hello", {"name": "Eve"}, provider) -# assert result["hello"] == "Hello, Eve!" - -# @pytest.mark.asyncio -# async def test_call_tool_oauth2(monkeypatch, transport, provider): -# async def fake_oauth2(auth): -# return "fake-token" -# transport._handle_oauth2 = fake_oauth2 -# provider.headers = {} -# provider.auth = OAuth2Auth(token_url="http://fake/token", client_id="id", client_secret="secret", scope="scope") -# result = await transport.call_tool("hello", {"name": "Zoe"}, provider) -# assert result["hello"] == "Hello, Zoe!" - -# @pytest.mark.asyncio -# async def test_enforce_https_or_localhost_raises(transport, provider): -# provider.url = "http://evil.com/graphql" -# with pytest.raises(ValueError): -# await transport.call_tool("hello", {"name": "Mallory"}, provider) - -# @pytest.mark.asyncio -# async def test_deregister_tool_provider_noop(transport, provider): -# await transport.deregister_tool_provider(provider) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py index 7362502..6dd0fda 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py @@ -1,9 +1,9 @@ from utcp.plugins.discovery import register_communication_protocol, register_call_template from .gql_communication_protocol import GraphQLCommunicationProtocol -from .gql_call_template import GraphQLProvider, GraphQLProviderSerializer +from .gql_call_template import GraphQLCallTemplate, GraphQLCallTemplateSerializer def register(): register_communication_protocol("graphql", GraphQLCommunicationProtocol()) - register_call_template("graphql", GraphQLProviderSerializer()) \ No newline at end of file + register_call_template("graphql", GraphQLCallTemplateSerializer()) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py index 3848d29..579d691 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py @@ -6,13 +6,17 @@ from typing import Dict, List, Optional, Literal from pydantic import Field, field_serializer, field_validator -class GraphQLProvider(CallTemplate): +class GraphQLCallTemplate(CallTemplate): """Provider configuration for GraphQL-based tools. Enables communication with GraphQL endpoints supporting queries, mutations, and subscriptions. Provides flexible query execution with custom headers and authentication. + For maximum flexibility, use the `query` field to provide a complete GraphQL + query string with proper selection sets and variable types. This allows agents + to call any existing GraphQL endpoint without limitations. + Attributes: call_template_type: Always "graphql" for GraphQL providers. url: The GraphQL endpoint URL. @@ -21,6 +25,23 @@ class GraphQLProvider(CallTemplate): auth: Optional authentication configuration. headers: Optional static headers to include in requests. header_fields: List of tool argument names to map to HTTP request headers. + query: Custom GraphQL query string with full control over selection sets + and variable types. Example: 'query GetUser($id: ID!) { user(id: $id) { id name } }' + variable_types: Map of variable names to GraphQL types for auto-generated queries. + Example: {'id': 'ID!', 'limit': 'Int'}. Defaults to 'String' if not specified. + + Example: + # Full flexibility with custom query + template = GraphQLCallTemplate( + url="https://api.example.com/graphql", + query="query GetUser($id: ID!) { user(id: $id) { id name email } }", + ) + + # Auto-generation with proper types + template = GraphQLCallTemplate( + url="https://api.example.com/graphql", + variable_types={"limit": "Int", "active": "Boolean"}, + ) """ call_template_type: Literal["graphql"] = "graphql" @@ -30,6 +51,18 @@ class GraphQLProvider(CallTemplate): auth: Optional[Auth] = None headers: Optional[Dict[str, str]] = None header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers for the initial connection.") + query: Optional[str] = Field( + default=None, + description="Custom GraphQL query/mutation string. Use $varName syntax for variables. " + "If provided, this takes precedence over auto-generation. " + "Example: 'query GetUser($id: ID!) { user(id: $id) { id name email } }'" + ) + variable_types: Optional[Dict[str, str]] = Field( + default=None, + description="Map of variable names to GraphQL types for auto-generated queries. " + "Example: {'id': 'ID!', 'limit': 'Int', 'active': 'Boolean'}. " + "Defaults to 'String' if not specified." + ) @field_serializer("auth") def serialize_auth(self, auth: Optional[Auth]): @@ -47,14 +80,14 @@ def validate_auth(cls, v: Optional[Auth | dict]): return AuthSerializer().validate_dict(v) -class GraphQLProviderSerializer(Serializer[GraphQLProvider]): - def to_dict(self, obj: GraphQLProvider) -> dict: +class GraphQLCallTemplateSerializer(Serializer[GraphQLCallTemplate]): + def to_dict(self, obj: GraphQLCallTemplate) -> dict: return obj.model_dump() - def validate_dict(self, data: dict) -> GraphQLProvider: + def validate_dict(self, data: dict) -> GraphQLCallTemplate: try: - return GraphQLProvider.model_validate(data) + return GraphQLCallTemplate.model_validate(data) except Exception as e: raise UtcpSerializerValidationError( - f"Invalid GraphQLProvider: {e}\n{traceback.format_exc()}" - ) + f"Invalid GraphQLCallTemplate: {e}\n{traceback.format_exc()}" + ) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py index 9d26cab..16b945c 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py @@ -14,7 +14,7 @@ from utcp.data.auth_implementations.basic_auth import BasicAuth from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth -from utcp_gql.gql_call_template import GraphQLProvider +from utcp_gql.gql_call_template import GraphQLCallTemplate if TYPE_CHECKING: from utcp.utcp_client import UtcpClient @@ -70,25 +70,25 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: return token_response["access_token"] async def _prepare_headers( - self, provider: GraphQLProvider, tool_args: Optional[Dict[str, Any]] = None + self, call_template: GraphQLCallTemplate, tool_args: Optional[Dict[str, Any]] = None ) -> Dict[str, str]: - headers: Dict[str, str] = provider.headers.copy() if provider.headers else {} - if provider.auth: - if isinstance(provider.auth, ApiKeyAuth): - if provider.auth.api_key and provider.auth.location == "header": - headers[provider.auth.var_name] = provider.auth.api_key - elif isinstance(provider.auth, BasicAuth): + headers: Dict[str, str] = call_template.headers.copy() if call_template.headers else {} + if call_template.auth: + if isinstance(call_template.auth, ApiKeyAuth): + if call_template.auth.api_key and call_template.auth.location == "header": + headers[call_template.auth.var_name] = call_template.auth.api_key + elif isinstance(call_template.auth, BasicAuth): import base64 - userpass = f"{provider.auth.username}:{provider.auth.password}" + userpass = f"{call_template.auth.username}:{call_template.auth.password}" headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode() - elif isinstance(provider.auth, OAuth2Auth): - token = await self._handle_oauth2(provider.auth) + elif isinstance(call_template.auth, OAuth2Auth): + token = await self._handle_oauth2(call_template.auth) headers["Authorization"] = f"Bearer {token}" # Map selected tool_args into headers if requested - if tool_args and provider.header_fields: - for field in provider.header_fields: + if tool_args and call_template.header_fields: + for field in call_template.header_fields: if field in tool_args and isinstance(tool_args[field], str): headers[field] = tool_args[field] @@ -97,8 +97,8 @@ async def _prepare_headers( async def register_manual( self, caller: "UtcpClient", manual_call_template: CallTemplate ) -> RegisterManualResult: - if not isinstance(manual_call_template, GraphQLProvider): - raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + if not isinstance(manual_call_template, GraphQLCallTemplate): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") self._enforce_https_or_localhost(manual_call_template.url) try: @@ -176,26 +176,40 @@ async def call_tool( tool_args: Dict[str, Any], tool_call_template: CallTemplate, ) -> Any: - if not isinstance(tool_call_template, GraphQLProvider): - raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + if not isinstance(tool_call_template, GraphQLCallTemplate): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") self._enforce_https_or_localhost(tool_call_template.url) headers = await self._prepare_headers(tool_call_template, tool_args) transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - op_type = getattr(tool_call_template, "operation_type", "query") - # Strip manual prefix if present (client prefixes at save time) - base_tool_name = tool_name.split(".", 1)[-1] if "." in tool_name else tool_name # Filter out header fields from GraphQL variables; these are sent via HTTP headers header_fields = tool_call_template.header_fields or [] filtered_args = {k: v for k, v in tool_args.items() if k not in header_fields} - arg_str = ", ".join(f"${k}: String" for k in filtered_args.keys()) - var_defs = f"({arg_str})" if arg_str else "" - arg_pass = ", ".join(f"{k}: ${k}" for k in filtered_args.keys()) - arg_pass = f"({arg_pass})" if arg_pass else "" + # Use custom query if provided (highest flexibility for agents) + if tool_call_template.query: + gql_str = tool_call_template.query + else: + # Auto-generate query - use variable_types for proper typing + op_type = getattr(tool_call_template, "operation_type", "query") + base_tool_name = tool_name.split(".", 1)[-1] if "." in tool_name else tool_name + variable_types = tool_call_template.variable_types or {} + + # Build variable definitions with proper types (default to String) + arg_str = ", ".join( + f"${k}: {variable_types.get(k, 'String')}" + for k in filtered_args.keys() + ) + var_defs = f"({arg_str})" if arg_str else "" + arg_pass = ", ".join(f"{k}: ${k}" for k in filtered_args.keys()) + arg_pass = f"({arg_pass})" if arg_pass else "" + + # Note: Auto-generated queries for object-returning fields will still fail + # without a selection set. Use the `query` field for full control. + gql_str = f"{op_type} {var_defs} {{ {base_tool_name}{arg_pass} }}" + logger.debug(f"Auto-generated GraphQL: {gql_str}") - gql_str = f"{op_type} {var_defs} {{ {base_tool_name}{arg_pass} }}" document = gql_query(gql_str) result = await session.execute(document, variable_values=filtered_args) return result diff --git a/plugins/communication_protocols/gql/tests/test_graphql_integration.py b/plugins/communication_protocols/gql/tests/test_graphql_integration.py new file mode 100644 index 0000000..fdc4fcb --- /dev/null +++ b/plugins/communication_protocols/gql/tests/test_graphql_integration.py @@ -0,0 +1,275 @@ +"""Integration tests for GraphQL communication protocol using real GraphQL servers. + +Uses the public Countries API (https://countries.trevorblades.com/graphql) which +requires no authentication and has a stable schema. +""" +import os +import sys +import warnings +import pytest +import pytest_asyncio + +# Ensure plugin src is importable +PLUGIN_SRC = os.path.join(os.path.dirname(__file__), "..", "src") +PLUGIN_SRC = os.path.abspath(PLUGIN_SRC) +if PLUGIN_SRC not in sys.path: + sys.path.append(PLUGIN_SRC) + +import utcp_gql +from utcp_gql.gql_call_template import GraphQLCallTemplate +from utcp_gql.gql_communication_protocol import GraphQLCommunicationProtocol + +from utcp.implementations.utcp_client_implementation import UtcpClientImplementation + +# Public GraphQL API for testing (no auth required) +COUNTRIES_API_URL = "https://countries.trevorblades.com/graphql" + +# Suppress gql SSL warning (we're using HTTPS which is secure) +warnings.filterwarnings("ignore", message=".*AIOHTTPTransport does not verify ssl.*") + + +@pytest.fixture +def protocol(): + """Create a fresh GraphQL protocol instance.""" + utcp_gql.register() + return GraphQLCommunicationProtocol() + + +@pytest_asyncio.fixture +async def client(): + """Create a minimal UTCP client.""" + return await UtcpClientImplementation.create() + + +@pytest.mark.asyncio +async def test_register_manual_discovers_tools(protocol, client): + """Test that register_manual discovers tools from a real GraphQL schema.""" + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + ) + + result = await protocol.register_manual(client, template) + + assert result.success is True + assert len(result.manual.tools) > 0 + + # The Countries API should have these common queries + tool_names = [t.name for t in result.manual.tools] + assert "countries" in tool_names or "country" in tool_names + + +@pytest.mark.asyncio +async def test_call_tool_with_custom_query(protocol, client): + """Test calling a tool with a custom query string (fixes selection set issue).""" + # Custom query with proper selection set - this is the UTCP-flexible approach + custom_query = """ + query GetCountry($code: ID!) { + country(code: $code) { + name + capital + currency + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + ) + + result = await protocol.call_tool( + client, + "country", + {"code": "US"}, + template, + ) + + assert result is not None + assert "country" in result + assert result["country"]["name"] == "United States" + assert result["country"]["capital"] == "Washington D.C." + + +@pytest.mark.asyncio +async def test_call_tool_with_variable_types(protocol, client): + """Test that variable_types properly maps GraphQL types (fixes String-only issue).""" + # The country query expects code: ID!, not String + # Using variable_types to specify the correct type + custom_query = """ + query GetCountry($code: ID!) { + country(code: $code) { + name + emoji + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + variable_types={"code": "ID!"}, + ) + + result = await protocol.call_tool( + client, + "country", + {"code": "FR"}, + template, + ) + + assert result is not None + assert result["country"]["name"] == "France" + assert result["country"]["emoji"] == "šŸ‡«šŸ‡·" + + +@pytest.mark.asyncio +async def test_call_tool_list_query(protocol, client): + """Test querying a list of items with proper selection set.""" + custom_query = """ + query GetContinents { + continents { + code + name + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + ) + + result = await protocol.call_tool( + client, + "continents", + {}, + template, + ) + + assert result is not None + assert "continents" in result + assert len(result["continents"]) == 7 # 7 continents + + continent_names = [c["name"] for c in result["continents"]] + assert "Europe" in continent_names + assert "Asia" in continent_names + + +@pytest.mark.asyncio +async def test_call_tool_nested_query(protocol, client): + """Test querying nested objects with proper selection sets.""" + custom_query = """ + query GetCountryWithLanguages($code: ID!) { + country(code: $code) { + name + languages { + code + name + } + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + ) + + result = await protocol.call_tool( + client, + "country", + {"code": "CH"}, # Switzerland - has multiple languages + template, + ) + + assert result is not None + assert result["country"]["name"] == "Switzerland" + assert len(result["country"]["languages"]) >= 3 # German, French, Italian, Romansh + + +@pytest.mark.asyncio +async def test_call_tool_with_filter_arguments(protocol, client): + """Test queries with filter arguments using proper types.""" + custom_query = """ + query GetCountriesByContinent($filter: CountryFilterInput) { + countries(filter: $filter) { + code + name + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + variable_types={"filter": "CountryFilterInput"}, + ) + + result = await protocol.call_tool( + client, + "countries", + {"filter": {"continent": {"eq": "EU"}}}, + template, + ) + + assert result is not None + assert "countries" in result + # Should return European countries + country_codes = [c["code"] for c in result["countries"]] + assert "DE" in country_codes # Germany + assert "FR" in country_codes # France + + +@pytest.mark.asyncio +async def test_error_handling_invalid_query(protocol, client): + """Test that invalid queries return proper errors.""" + # Invalid query syntax + invalid_query = "this is not valid graphql" + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=invalid_query, + ) + + with pytest.raises(Exception): + await protocol.call_tool( + client, + "invalid", + {}, + template, + ) + + +@pytest.mark.asyncio +async def test_error_handling_missing_selection_set_auto_generated(protocol, client): + """ + Demonstrate that auto-generated queries fail for object-returning fields. + + This test documents the limitation: without a custom query, object fields fail. + The fix is to always use the `query` field for object-returning operations. + """ + # No custom query - will auto-generate without selection set + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + operation_type="query", + variable_types={"code": "ID!"}, + ) + + # This should fail because auto-generated query lacks selection set + # The query becomes: query ($code: ID!) { country(code: $code) } + # But country returns an object that needs: { name capital ... } + with pytest.raises(Exception): + await protocol.call_tool( + client, + "country", + {"code": "US"}, + template, + ) diff --git a/plugins/communication_protocols/gql/tests/test_graphql_protocol.py b/plugins/communication_protocols/gql/tests/test_graphql_protocol.py deleted file mode 100644 index 1b1bb74..0000000 --- a/plugins/communication_protocols/gql/tests/test_graphql_protocol.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import sys -import types -import pytest - - -# Ensure plugin src is importable -PLUGIN_SRC = os.path.join(os.path.dirname(__file__), "..", "src") -PLUGIN_SRC = os.path.abspath(PLUGIN_SRC) -if PLUGIN_SRC not in sys.path: - sys.path.append(PLUGIN_SRC) - -import utcp_gql -# Simplify imports: use the main module and assign local aliases -GraphQLProvider = utcp_gql.gql_call_template.GraphQLProvider -gql_module = utcp_gql.gql_communication_protocol - -from utcp.data.utcp_manual import UtcpManual -from utcp.utcp_client import UtcpClient -from utcp.implementations.utcp_client_implementation import UtcpClientImplementation - - -class FakeSchema: - def __init__(self): - # Minimal field objects with descriptions - self.query_type = types.SimpleNamespace( - fields={ - "hello": types.SimpleNamespace(description="Returns greeting"), - } - ) - self.mutation_type = types.SimpleNamespace( - fields={ - "add": types.SimpleNamespace(description="Adds numbers"), - } - ) - self.subscription_type = None - - -class FakeClientObj: - def __init__(self): - self.client = types.SimpleNamespace(schema=FakeSchema()) - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - async def execute(self, document, variable_values=None): - # document is a gql query; we can base behavior on variable_values - variable_values = variable_values or {} - # Determine operation by presence of variables used - if "hello" in str(document): - name = variable_values.get("name", "") - return {"hello": f"Hello {name}"} - if "add" in str(document): - a = int(variable_values.get("a", 0)) - b = int(variable_values.get("b", 0)) - return {"add": a + b} - return {"ok": True} - - -class FakeTransport: - def __init__(self, url: str, headers: dict | None = None): - self.url = url - self.headers = headers or {} - - -@pytest.mark.asyncio -async def test_graphql_register_and_call(monkeypatch): - # Patch gql client/transport used by protocol to avoid needing a real server - monkeypatch.setattr(gql_module, "GqlClient", lambda *args, **kwargs: FakeClientObj()) - monkeypatch.setattr(gql_module, "AIOHTTPTransport", FakeTransport) - # Avoid real GraphQL parsing; pass-through document string to fake execute - monkeypatch.setattr(gql_module, "gql_query", lambda s: s) - - # Register plugin (call_template serializer + protocol) - utcp_gql.register() - - # Create protocol and manual call template - protocol = gql_module.GraphQLCommunicationProtocol() - provider = GraphQLProvider( - name="mock_graph", - call_template_type="graphql", - url="http://localhost/graphql", - operation_type="query", - headers={"x-client": "utcp"}, - header_fields=["x-session-id"], - ) - - # Minimal UTCP client implementation for caller context - client: UtcpClient = await UtcpClientImplementation.create() - client.config.variables = {} - - # Register and discover tools - reg = await protocol.register_manual(client, provider) - assert reg.success is True - assert isinstance(reg.manual, UtcpManual) - tool_names = sorted(t.name for t in reg.manual.tools) - assert "hello" in tool_names - assert "add" in tool_names - - # Call hello - res = await protocol.call_tool(client, "mock_graph.hello", {"name": "UTCP", "x-session-id": "abc"}, provider) - assert res == {"hello": "Hello UTCP"} - - # Call add (mutation) - provider.operation_type = "mutation" - res2 = await protocol.call_tool(client, "mock_graph.add", {"a": 2, "b": 3}, provider) - assert res2 == {"add": 5} \ No newline at end of file diff --git a/plugins/communication_protocols/socket/README.md b/plugins/communication_protocols/socket/README.md index 3e695c9..04c1737 100644 --- a/plugins/communication_protocols/socket/README.md +++ b/plugins/communication_protocols/socket/README.md @@ -12,8 +12,8 @@ Prerequisites: 1) Install core and the socket plugin in editable mode with dev extras: ```bash -pip install -e "core[dev]" -pip install -e plugins/communication_protocols/socket[dev] +pip install -e "./core[dev]" +pip install -e ./plugins/communication_protocols/socket[dev] ``` 2) Run the socket plugin tests: diff --git a/plugins/communication_protocols/websocket/tests/__init__.py b/plugins/communication_protocols/websocket/tests/__init__.py deleted file mode 100644 index 614ce9a..0000000 --- a/plugins/communication_protocols/websocket/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for the WebSocket communication protocol plugin.""" diff --git a/socket_plugin_test.py b/socket_plugin_test.py deleted file mode 100644 index c03d2a9..0000000 --- a/socket_plugin_test.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import sys -from pathlib import Path - -# Add core and plugin src paths so imports work without installing packages -core_src = Path(__file__).parent / "core" / "src" -socket_src = Path(__file__).parent / "plugins" / "communication_protocols" / "socket" / "src" -sys.path.insert(0, str(core_src.resolve())) -sys.path.insert(0, str(socket_src.resolve())) - -from utcp.plugins.plugin_loader import ensure_plugins_initialized -from utcp.interfaces.communication_protocol import CommunicationProtocol -from utcp.data.call_template import CallTemplateSerializer -from utcp_socket import register as register_socket - -async def main(): - # Manually register the socket plugin - register_socket() - - # Load core plugins (auth, repo, search, post-processors) - ensure_plugins_initialized() - - # 1. Check if communication protocols are registered - registered_protocols = CommunicationProtocol.communication_protocols - print(f"Registered communication protocols: {list(registered_protocols.keys())}") - assert "tcp" in registered_protocols, "TCP communication protocol not registered" - assert "udp" in registered_protocols, "UDP communication protocol not registered" - print("āœ… TCP and UDP communication protocols are registered.") - - # 2. Check if call templates are registered - registered_serializers = CallTemplateSerializer.call_template_serializers - print(f"Registered call template serializers: {list(registered_serializers.keys())}") - assert "tcp" in registered_serializers, "TCP call template serializer not registered" - assert "udp" in registered_serializers, "UDP call template serializer not registered" - print("āœ… TCP and UDP call template serializers are registered.") - - print("\nšŸŽ‰ Socket plugin sanity check passed! šŸŽ‰") - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/test_websocket_manual.py b/test_websocket_manual.py deleted file mode 100644 index a1457c4..0000000 --- a/test_websocket_manual.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env python3 -""" -Manual test script for WebSocket transport implementation. -This tests the core functionality without requiring pytest setup. -""" - -import asyncio -import sys -import os - -# Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) - -from utcp.client.transport_interfaces.websocket_transport import WebSocketClientTransport -from utcp.shared.provider import WebSocketProvider -from utcp.shared.auth import ApiKeyAuth, BasicAuth - - -async def test_basic_functionality(): - """Test basic WebSocket transport functionality""" - print("Testing WebSocket Transport Implementation...") - - transport = WebSocketClientTransport() - - # Test 1: Security enforcement - print("\n1. Testing security enforcement...") - try: - insecure_provider = WebSocketProvider( - name="insecure", - url="ws://example.com/ws" # Should be rejected - ) - await transport.register_tool_provider(insecure_provider) - print("āŒ FAILED: Insecure URL was accepted") - except ValueError as e: - if "Security error" in str(e): - print("āœ… PASSED: Insecure URL properly rejected") - else: - print(f"āŒ FAILED: Wrong error: {e}") - except Exception as e: - print(f"āŒ FAILED: Unexpected error: {e}") - - # Test 2: Provider type validation - print("\n2. Testing provider type validation...") - try: - from utcp.shared.provider import HttpProvider - wrong_provider = HttpProvider(name="wrong", url="https://example.com") - await transport.register_tool_provider(wrong_provider) - print("āŒ FAILED: Wrong provider type was accepted") - except ValueError as e: - if "WebSocketClientTransport can only be used with WebSocketProvider" in str(e): - print("āœ… PASSED: Provider type validation works") - else: - print(f"āŒ FAILED: Wrong error: {e}") - except Exception as e: - print(f"āŒ FAILED: Unexpected error: {e}") - - # Test 3: Authentication header preparation - print("\n3. Testing authentication...") - try: - # Test API Key auth - api_provider = WebSocketProvider( - name="api_test", - url="wss://example.com/ws", - auth=ApiKeyAuth( - var_name="X-API-Key", - api_key="test-key-123", - location="header" - ) - ) - headers = await transport._prepare_headers(api_provider) - if headers.get("X-API-Key") == "test-key-123": - print("āœ… PASSED: API Key authentication headers prepared correctly") - else: - print(f"āŒ FAILED: API Key headers incorrect: {headers}") - - # Test Basic auth - basic_provider = WebSocketProvider( - name="basic_test", - url="wss://example.com/ws", - auth=BasicAuth(username="user", password="pass") - ) - headers = await transport._prepare_headers(basic_provider) - if "Authorization" in headers and headers["Authorization"].startswith("Basic "): - print("āœ… PASSED: Basic authentication headers prepared correctly") - else: - print(f"āŒ FAILED: Basic auth headers incorrect: {headers}") - - except Exception as e: - print(f"āŒ FAILED: Authentication test error: {e}") - - # Test 4: Connection management - print("\n4. Testing connection management...") - try: - localhost_provider = WebSocketProvider( - name="test_provider", - url="ws://localhost:8765/ws" - ) - - # This should fail to connect but not due to security - try: - await transport.register_tool_provider(localhost_provider) - print("āŒ FAILED: Connection should have failed (no server)") - except ValueError as e: - if "Security error" in str(e): - print("āŒ FAILED: Security error on localhost") - else: - print("ā“ UNEXPECTED: Different error occurred") - except Exception as e: - # Expected - connection refused or similar - print("āœ… PASSED: Connection management works (failed to connect as expected)") - - except Exception as e: - print(f"āŒ FAILED: Connection test error: {e}") - - # Test 5: Cleanup - print("\n5. Testing cleanup...") - try: - await transport.close() - if len(transport._connections) == 0 and len(transport._oauth_tokens) == 0: - print("āœ… PASSED: Cleanup successful") - else: - print("āŒ FAILED: Cleanup incomplete") - except Exception as e: - print(f"āŒ FAILED: Cleanup error: {e}") - - print("\nāœ… WebSocket transport basic functionality tests completed!") - - -async def test_with_mock_server(): - """Test with a real WebSocket connection to our mock server""" - print("\n" + "="*50) - print("Testing with Mock WebSocket Server") - print("="*50) - - # Import and start mock server - sys.path.append('tests/client/transport_interfaces') - try: - from mock_websocket_server import create_app - from aiohttp import web - - print("Starting mock WebSocket server...") - app = await create_app() - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite(runner, 'localhost', 8765) - await site.start() - - print("Mock server started on ws://localhost:8765/ws") - - # Test with our transport - transport = WebSocketClientTransport() - provider = WebSocketProvider( - name="test_provider", - url="ws://localhost:8765/ws" - ) - - try: - # Test tool discovery - print("\nTesting tool discovery...") - tools = await transport.register_tool_provider(provider) - print(f"āœ… Discovered {len(tools)} tools:") - for tool in tools: - print(f" - {tool.name}: {tool.description}") - - # Test tool execution - print("\nTesting tool execution...") - result = await transport.call_tool("echo", {"message": "Hello WebSocket!"}, provider) - print(f"āœ… Echo result: {result}") - - result = await transport.call_tool("add_numbers", {"a": 5, "b": 3}, provider) - print(f"āœ… Add result: {result}") - - # Test error handling - print("\nTesting error handling...") - try: - await transport.call_tool("simulate_error", {"error_message": "Test error"}, provider) - print("āŒ FAILED: Error tool should have failed") - except RuntimeError as e: - print(f"āœ… Error properly handled: {e}") - - except Exception as e: - print(f"āŒ Transport test failed: {e}") - finally: - await transport.close() - await runner.cleanup() - print("Mock server stopped") - - except ImportError as e: - print(f"āš ļø Mock server test skipped (missing dependencies): {e}") - except Exception as e: - print(f"āŒ Mock server test failed: {e}") - - -async def main(): - """Run all manual tests""" - await test_basic_functionality() - # await test_with_mock_server() # Uncomment if you want to test with real server - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file From d4d8ede41a5bded516313a70b6e72137e69feede Mon Sep 17 00:00:00 2001 From: Razvan Radulescu <43811028+h3xxit@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:13:11 +0100 Subject: [PATCH 4/4] remove WIP for finished plugins --- plugins/communication_protocols/gql/pyproject.toml | 2 +- plugins/communication_protocols/socket/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/communication_protocols/gql/pyproject.toml b/plugins/communication_protocols/gql/pyproject.toml index 7c752c3..d5b558d 100644 --- a/plugins/communication_protocols/gql/pyproject.toml +++ b/plugins/communication_protocols/gql/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.2" authors = [ { name = "UTCP Contributors" }, ] -description = "UTCP communication protocol plugin for GraphQL. (Work in progress)" +description = "UTCP communication protocol plugin for GraphQL." readme = "README.md" requires-python = ">=3.10" dependencies = [ diff --git a/plugins/communication_protocols/socket/pyproject.toml b/plugins/communication_protocols/socket/pyproject.toml index 2f232ad..a544648 100644 --- a/plugins/communication_protocols/socket/pyproject.toml +++ b/plugins/communication_protocols/socket/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.2" authors = [ { name = "UTCP Contributors" }, ] -description = "UTCP communication protocol plugin for TCP and UDP protocols. (Work in progress)" +description = "UTCP communication protocol plugin for TCP and UDP protocols." readme = "README.md" requires-python = ">=3.10" dependencies = [