diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9394330..bf677f8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,10 +31,11 @@ jobs: pip install -e plugins/communication_protocols/http[dev] pip install -e plugins/communication_protocols/mcp[dev] pip install -e plugins/communication_protocols/text[dev] + pip install -e plugins/communication_protocols/socket[dev] - name: Run tests with pytest run: | - pytest core/tests/ plugins/communication_protocols/cli/tests/ plugins/communication_protocols/http/tests/ plugins/communication_protocols/mcp/tests/ plugins/communication_protocols/text/tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=core/src/utcp --cov-report=xml --cov-report=html + pytest core/tests/ plugins/communication_protocols/cli/tests/ plugins/communication_protocols/http/tests/ plugins/communication_protocols/mcp/tests/ plugins/communication_protocols/text/tests/ plugins/communication_protocols/socket/tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=core/src/utcp --cov-report=xml --cov-report=html - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/plugins/communication_protocols/gql/README.md b/plugins/communication_protocols/gql/README.md index 8febb5a..34a2518 100644 --- a/plugins/communication_protocols/gql/README.md +++ b/plugins/communication_protocols/gql/README.md @@ -1 +1,47 @@ -Find the UTCP readme at https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file + +# UTCP GraphQL Communication Protocol Plugin + +This plugin integrates GraphQL as a UTCP 1.0 communication protocol and call template. It supports discovery via schema introspection, authenticated calls, and header handling. + +## Getting Started + +### Installation + +```bash +pip install gql +``` + +### Registration + +```python +import utcp_gql +utcp_gql.register() +``` + +## How To Use + +- Ensure the plugin is imported and registered: `import utcp_gql; utcp_gql.register()`. +- Add a manual in your client config: + ```json + { + "name": "my_graph", + "call_template_type": "graphql", + "url": "https://your.graphql/endpoint", + "operation_type": "query", + "headers": { "x-client": "utcp" }, + "header_fields": ["x-session-id"] + } + ``` +- Call a tool: + ```python + await client.call_tool("my_graph.someQuery", {"id": "123", "x-session-id": "abc"}) + ``` + +## Notes + +- Tool names are prefixed by the manual name (e.g., `my_graph.someQuery`). +- Headers merge static `headers` plus whitelisted dynamic fields from `header_fields`. +- Supported auth: API key, Basic auth, OAuth2 (client-credentials). +- Security: only `https://` or `http://localhost`/`http://127.0.0.1` endpoints. + +For UTCP core docs, see https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file diff --git a/plugins/communication_protocols/gql/old_tests/test_graphql_transport.py b/plugins/communication_protocols/gql/old_tests/test_graphql_transport.py deleted file mode 100644 index d33c323..0000000 --- a/plugins/communication_protocols/gql/old_tests/test_graphql_transport.py +++ /dev/null @@ -1,129 +0,0 @@ -# import pytest -# import pytest_asyncio -# import json -# from aiohttp import web -# from utcp.client.transport_interfaces.graphql_transport import GraphQLClientTransport -# from utcp.shared.provider import GraphQLProvider -# from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth - - -# @pytest_asyncio.fixture -# async def graphql_app(): -# async def graphql_handler(request): -# body = await request.json() -# query = body.get("query", "") -# variables = body.get("variables", {}) -# # Introspection query (minimal response) -# if "__schema" in query: -# return web.json_response({ -# "data": { -# "__schema": { -# "queryType": {"name": "Query"}, -# "mutationType": {"name": "Mutation"}, -# "subscriptionType": None, -# "types": [ -# {"kind": "OBJECT", "name": "Query", "fields": [ -# {"name": "hello", "args": [{"name": "name", "type": {"kind": "SCALAR", "name": "String"}, "defaultValue": None}], "type": {"kind": "SCALAR", "name": "String"}, "isDeprecated": False, "deprecationReason": None} -# ], "interfaces": []}, -# {"kind": "OBJECT", "name": "Mutation", "fields": [ -# {"name": "add", "args": [ -# {"name": "a", "type": {"kind": "SCALAR", "name": "Int"}, "defaultValue": None}, -# {"name": "b", "type": {"kind": "SCALAR", "name": "Int"}, "defaultValue": None} -# ], "type": {"kind": "SCALAR", "name": "Int"}, "isDeprecated": False, "deprecationReason": None} -# ], "interfaces": []}, -# {"kind": "SCALAR", "name": "String"}, -# {"kind": "SCALAR", "name": "Int"}, -# {"kind": "SCALAR", "name": "Boolean"} -# ], -# "directives": [] -# } -# } -# }) -# # hello query -# if "hello" in query: -# name = variables.get("name", "world") -# return web.json_response({"data": {"hello": f"Hello, {name}!"}}) -# # add mutation -# if "add" in query: -# a = variables.get("a", 0) -# b = variables.get("b", 0) -# return web.json_response({"data": {"add": a + b}}) -# # fallback -# return web.json_response({"data": {}}, status=200) - -# app = web.Application() -# app.router.add_post("/graphql", graphql_handler) -# return app - -# @pytest_asyncio.fixture -# async def aiohttp_graphql_client(aiohttp_client, graphql_app): -# return await aiohttp_client(graphql_app) - -# @pytest_asyncio.fixture -# def transport(): -# return GraphQLClientTransport() - -# @pytest_asyncio.fixture -# def provider(aiohttp_graphql_client): -# return GraphQLProvider( -# name="test-graphql-provider", -# url=str(aiohttp_graphql_client.make_url("/graphql")), -# headers={}, -# ) - -# @pytest.mark.asyncio -# async def test_register_tool_provider_discovers_tools(transport, provider): -# tools = await transport.register_tool_provider(provider) -# tool_names = [tool.name for tool in tools] -# assert "hello" in tool_names -# assert "add" in tool_names - -# @pytest.mark.asyncio -# async def test_call_tool_query(transport, provider): -# result = await transport.call_tool("hello", {"name": "Alice"}, provider) -# assert result["hello"] == "Hello, Alice!" - -# @pytest.mark.asyncio -# async def test_call_tool_mutation(transport, provider): -# provider.operation_type = "mutation" -# mutation = ''' -# mutation ($a: Int, $b: Int) { -# add(a: $a, b: $b) -# } -# ''' -# result = await transport.call_tool("add", {"a": 2, "b": 3}, provider, query=mutation) -# assert result["add"] == 5 - -# @pytest.mark.asyncio -# async def test_call_tool_api_key(transport, provider): -# provider.headers = {} -# provider.auth = ApiKeyAuth(var_name="X-API-Key", api_key="test-key") -# result = await transport.call_tool("hello", {"name": "Bob"}, provider) -# assert result["hello"] == "Hello, Bob!" - -# @pytest.mark.asyncio -# async def test_call_tool_basic_auth(transport, provider): -# provider.headers = {} -# provider.auth = BasicAuth(username="user", password="pass") -# result = await transport.call_tool("hello", {"name": "Eve"}, provider) -# assert result["hello"] == "Hello, Eve!" - -# @pytest.mark.asyncio -# async def test_call_tool_oauth2(monkeypatch, transport, provider): -# async def fake_oauth2(auth): -# return "fake-token" -# transport._handle_oauth2 = fake_oauth2 -# provider.headers = {} -# provider.auth = OAuth2Auth(token_url="http://fake/token", client_id="id", client_secret="secret", scope="scope") -# result = await transport.call_tool("hello", {"name": "Zoe"}, provider) -# assert result["hello"] == "Hello, Zoe!" - -# @pytest.mark.asyncio -# async def test_enforce_https_or_localhost_raises(transport, provider): -# provider.url = "http://evil.com/graphql" -# with pytest.raises(ValueError): -# await transport.call_tool("hello", {"name": "Mallory"}, provider) - -# @pytest.mark.asyncio -# async def test_deregister_tool_provider_noop(transport, provider): -# await transport.deregister_tool_provider(provider) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/pyproject.toml b/plugins/communication_protocols/gql/pyproject.toml index 7c752c3..d5b558d 100644 --- a/plugins/communication_protocols/gql/pyproject.toml +++ b/plugins/communication_protocols/gql/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.2" authors = [ { name = "UTCP Contributors" }, ] -description = "UTCP communication protocol plugin for GraphQL. (Work in progress)" +description = "UTCP communication protocol plugin for GraphQL." readme = "README.md" requires-python = ">=3.10" dependencies = [ diff --git a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py index e69de29..6dd0fda 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py @@ -0,0 +1,9 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template + +from .gql_communication_protocol import GraphQLCommunicationProtocol +from .gql_call_template import GraphQLCallTemplate, GraphQLCallTemplateSerializer + + +def register(): + register_communication_protocol("graphql", GraphQLCommunicationProtocol()) + register_call_template("graphql", GraphQLCallTemplateSerializer()) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py index dfe5b07..579d691 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py @@ -1,15 +1,22 @@ from utcp.data.call_template import CallTemplate -from utcp.data.auth import Auth +from utcp.data.auth import Auth, AuthSerializer +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback from typing import Dict, List, Optional, Literal -from pydantic import Field +from pydantic import Field, field_serializer, field_validator -class GraphQLProvider(CallTemplate): +class GraphQLCallTemplate(CallTemplate): """Provider configuration for GraphQL-based tools. Enables communication with GraphQL endpoints supporting queries, mutations, and subscriptions. Provides flexible query execution with custom headers and authentication. + For maximum flexibility, use the `query` field to provide a complete GraphQL + query string with proper selection sets and variable types. This allows agents + to call any existing GraphQL endpoint without limitations. + Attributes: call_template_type: Always "graphql" for GraphQL providers. url: The GraphQL endpoint URL. @@ -18,6 +25,23 @@ class GraphQLProvider(CallTemplate): auth: Optional authentication configuration. headers: Optional static headers to include in requests. header_fields: List of tool argument names to map to HTTP request headers. + query: Custom GraphQL query string with full control over selection sets + and variable types. Example: 'query GetUser($id: ID!) { user(id: $id) { id name } }' + variable_types: Map of variable names to GraphQL types for auto-generated queries. + Example: {'id': 'ID!', 'limit': 'Int'}. Defaults to 'String' if not specified. + + Example: + # Full flexibility with custom query + template = GraphQLCallTemplate( + url="https://api.example.com/graphql", + query="query GetUser($id: ID!) { user(id: $id) { id name email } }", + ) + + # Auto-generation with proper types + template = GraphQLCallTemplate( + url="https://api.example.com/graphql", + variable_types={"limit": "Int", "active": "Boolean"}, + ) """ call_template_type: Literal["graphql"] = "graphql" @@ -27,3 +51,43 @@ class GraphQLProvider(CallTemplate): auth: Optional[Auth] = None headers: Optional[Dict[str, str]] = None header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers for the initial connection.") + query: Optional[str] = Field( + default=None, + description="Custom GraphQL query/mutation string. Use $varName syntax for variables. " + "If provided, this takes precedence over auto-generation. " + "Example: 'query GetUser($id: ID!) { user(id: $id) { id name email } }'" + ) + variable_types: Optional[Dict[str, str]] = Field( + default=None, + description="Map of variable names to GraphQL types for auto-generated queries. " + "Example: {'id': 'ID!', 'limit': 'Int', 'active': 'Boolean'}. " + "Defaults to 'String' if not specified." + ) + + @field_serializer("auth") + def serialize_auth(self, auth: Optional[Auth]): + if auth is None: + return None + return AuthSerializer().to_dict(auth) + + @field_validator("auth", mode="before") + @classmethod + def validate_auth(cls, v: Optional[Auth | dict]): + if v is None: + return None + if isinstance(v, Auth): + return v + return AuthSerializer().validate_dict(v) + + +class GraphQLCallTemplateSerializer(Serializer[GraphQLCallTemplate]): + def to_dict(self, obj: GraphQLCallTemplate) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> GraphQLCallTemplate: + try: + return GraphQLCallTemplate.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid GraphQLCallTemplate: {e}\n{traceback.format_exc()}" + ) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py index f27f803..16b945c 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py @@ -1,36 +1,55 @@ -import sys -from typing import Dict, Any, List, Optional, Callable +import logging +from typing import Dict, Any, List, Optional, AsyncGenerator, TYPE_CHECKING + import aiohttp -import asyncio -import ssl from gql import Client as GqlClient, gql as gql_query from gql.transport.aiohttp import AIOHTTPTransport -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, GraphQLProvider -from utcp.shared.tool import Tool, ToolInputOutputSchema -from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth -import logging + +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplate +from utcp.data.tool import Tool, JsonSchema +from utcp.data.utcp_manual import UtcpManual +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth +from utcp.data.auth_implementations.basic_auth import BasicAuth +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth + +from utcp_gql.gql_call_template import GraphQLCallTemplate + +if TYPE_CHECKING: + from utcp.utcp_client import UtcpClient + logging.basicConfig( level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s" + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s", ) logger = logging.getLogger(__name__) -class GraphQLClientTransport(ClientTransportInterface): - """ - Simple, robust, production-ready GraphQL transport using gql. - Stateless, per-operation. Supports all GraphQL features. + +class GraphQLCommunicationProtocol(CommunicationProtocol): + """GraphQL protocol implementation for UTCP 1.0. + + - Discovers tools via GraphQL schema introspection. + - Executes per-call sessions using `gql` over HTTP(S). + - Supports `ApiKeyAuth`, `BasicAuth`, and `OAuth2Auth`. + - Enforces HTTPS or localhost for security. """ - def __init__(self): + + def __init__(self) -> None: self._oauth_tokens: Dict[str, Dict[str, Any]] = {} - def _enforce_https_or_localhost(self, url: str): - if not (url.startswith("https://") or url.startswith("http://localhost") or url.startswith("http://127.0.0.1")): + def _enforce_https_or_localhost(self, url: str) -> None: + if not ( + url.startswith("https://") + or url.startswith("http://localhost") + or url.startswith("http://127.0.0.1") + ): raise ValueError( - f"Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. Got: {url}. " - "Non-secure URLs are vulnerable to man-in-the-middle attacks." + "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. " + "Non-secure URLs are vulnerable to man-in-the-middle attacks. " + f"Got: {url}." ) async def _handle_oauth2(self, auth: OAuth2Auth) -> str: @@ -39,10 +58,10 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: return self._oauth_tokens[client_id]["access_token"] async with aiohttp.ClientSession() as session: data = { - 'grant_type': 'client_credentials', - 'client_id': client_id, - 'client_secret': auth.client_secret, - 'scope': auth.scope + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": auth.client_secret, + "scope": auth.scope, } async with session.post(auth.token_url, data=data) as resp: resp.raise_for_status() @@ -50,87 +69,161 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: self._oauth_tokens[client_id] = token_response return token_response["access_token"] - async def _prepare_headers(self, provider: GraphQLProvider) -> Dict[str, str]: - headers = provider.headers.copy() if provider.headers else {} - if provider.auth: - if isinstance(provider.auth, ApiKeyAuth): - if provider.auth.api_key: - if provider.auth.location == "header": - headers[provider.auth.var_name] = provider.auth.api_key - # (query/cookie not supported for GraphQL by default) - elif isinstance(provider.auth, BasicAuth): + async def _prepare_headers( + self, call_template: GraphQLCallTemplate, tool_args: Optional[Dict[str, Any]] = None + ) -> Dict[str, str]: + headers: Dict[str, str] = call_template.headers.copy() if call_template.headers else {} + if call_template.auth: + if isinstance(call_template.auth, ApiKeyAuth): + if call_template.auth.api_key and call_template.auth.location == "header": + headers[call_template.auth.var_name] = call_template.auth.api_key + elif isinstance(call_template.auth, BasicAuth): import base64 - userpass = f"{provider.auth.username}:{provider.auth.password}" + + userpass = f"{call_template.auth.username}:{call_template.auth.password}" headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode() - elif isinstance(provider.auth, OAuth2Auth): - token = await self._handle_oauth2(provider.auth) + elif isinstance(call_template.auth, OAuth2Auth): + token = await self._handle_oauth2(call_template.auth) headers["Authorization"] = f"Bearer {token}" + + # Map selected tool_args into headers if requested + if tool_args and call_template.header_fields: + for field in call_template.header_fields: + if field in tool_args and isinstance(tool_args[field], str): + headers[field] = tool_args[field] + return headers - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - if not isinstance(manual_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(manual_provider.url) - headers = await self._prepare_headers(manual_provider) - transport = AIOHTTPTransport(url=manual_provider.url, headers=headers) - async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - schema = session.client.schema - tools = [] - # Queries - if hasattr(schema, 'query_type') and schema.query_type: - for name, field in schema.query_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Mutations - if hasattr(schema, 'mutation_type') and schema.mutation_type: - for name, field in schema.mutation_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Subscriptions (listed, but not called here) - if hasattr(schema, 'subscription_type') and schema.subscription_type: - for name, field in schema.subscription_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - return tools - - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - # Stateless: nothing to do - pass - - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider, query: Optional[str] = None) -> Any: - if not isinstance(tool_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(tool_provider.url) - headers = await self._prepare_headers(tool_provider) - transport = AIOHTTPTransport(url=tool_provider.url, headers=headers) + async def register_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> RegisterManualResult: + if not isinstance(manual_call_template, GraphQLCallTemplate): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") + self._enforce_https_or_localhost(manual_call_template.url) + + try: + headers = await self._prepare_headers(manual_call_template) + transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers) + async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: + schema = session.client.schema + tools: List[Tool] = [] + + # Queries + if hasattr(schema, "query_type") and schema.query_type: + for name, field in schema.query_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Mutations + if hasattr(schema, "mutation_type") and schema.mutation_type: + for name, field in schema.mutation_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Subscriptions (listed for completeness) + if hasattr(schema, "subscription_type") and schema.subscription_type: + for name, field in schema.subscription_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + manual = UtcpManual(tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[], + ) + except Exception as e: + logger.error(f"GraphQL manual registration failed for '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(manual_version="0.0.0", tools=[]), + success=False, + errors=[str(e)], + ) + + async def deregister_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> None: + # Stateless: nothing to clean up + return None + + async def call_tool( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> Any: + if not isinstance(tool_call_template, GraphQLCallTemplate): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") + self._enforce_https_or_localhost(tool_call_template.url) + + headers = await self._prepare_headers(tool_call_template, tool_args) + transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - if query is not None: - document = gql_query(query) - result = await session.execute(document, variable_values=tool_args) - return result - # If no query provided, build a simple query - # Default to query operation - op_type = getattr(tool_provider, 'operation_type', 'query') - arg_str = ', '.join(f"${k}: String" for k in tool_args.keys()) - var_defs = f"({arg_str})" if arg_str else "" - arg_pass = ', '.join(f"{k}: ${k}" for k in tool_args.keys()) - arg_pass = f"({arg_pass})" if arg_pass else "" - gql_str = f"{op_type} {var_defs} {{ {tool_name}{arg_pass} }}" + # Filter out header fields from GraphQL variables; these are sent via HTTP headers + header_fields = tool_call_template.header_fields or [] + filtered_args = {k: v for k, v in tool_args.items() if k not in header_fields} + + # Use custom query if provided (highest flexibility for agents) + if tool_call_template.query: + gql_str = tool_call_template.query + else: + # Auto-generate query - use variable_types for proper typing + op_type = getattr(tool_call_template, "operation_type", "query") + base_tool_name = tool_name.split(".", 1)[-1] if "." in tool_name else tool_name + variable_types = tool_call_template.variable_types or {} + + # Build variable definitions with proper types (default to String) + arg_str = ", ".join( + f"${k}: {variable_types.get(k, 'String')}" + for k in filtered_args.keys() + ) + var_defs = f"({arg_str})" if arg_str else "" + arg_pass = ", ".join(f"{k}: ${k}" for k in filtered_args.keys()) + arg_pass = f"({arg_pass})" if arg_pass else "" + + # Note: Auto-generated queries for object-returning fields will still fail + # without a selection set. Use the `query` field for full control. + gql_str = f"{op_type} {var_defs} {{ {base_tool_name}{arg_pass} }}" + logger.debug(f"Auto-generated GraphQL: {gql_str}") + document = gql_query(gql_str) - result = await session.execute(document, variable_values=tool_args) + result = await session.execute(document, variable_values=filtered_args) return result + async def call_tool_streaming( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> AsyncGenerator[Any, None]: + # Basic implementation: execute non-streaming and yield once + result = await self.call_tool(caller, tool_name, tool_args, tool_call_template) + yield result + async def close(self) -> None: - self._oauth_tokens.clear() + self._oauth_tokens.clear() \ No newline at end of file diff --git a/plugins/communication_protocols/gql/tests/test_graphql_integration.py b/plugins/communication_protocols/gql/tests/test_graphql_integration.py new file mode 100644 index 0000000..fdc4fcb --- /dev/null +++ b/plugins/communication_protocols/gql/tests/test_graphql_integration.py @@ -0,0 +1,275 @@ +"""Integration tests for GraphQL communication protocol using real GraphQL servers. + +Uses the public Countries API (https://countries.trevorblades.com/graphql) which +requires no authentication and has a stable schema. +""" +import os +import sys +import warnings +import pytest +import pytest_asyncio + +# Ensure plugin src is importable +PLUGIN_SRC = os.path.join(os.path.dirname(__file__), "..", "src") +PLUGIN_SRC = os.path.abspath(PLUGIN_SRC) +if PLUGIN_SRC not in sys.path: + sys.path.append(PLUGIN_SRC) + +import utcp_gql +from utcp_gql.gql_call_template import GraphQLCallTemplate +from utcp_gql.gql_communication_protocol import GraphQLCommunicationProtocol + +from utcp.implementations.utcp_client_implementation import UtcpClientImplementation + +# Public GraphQL API for testing (no auth required) +COUNTRIES_API_URL = "https://countries.trevorblades.com/graphql" + +# Suppress gql SSL warning (we're using HTTPS which is secure) +warnings.filterwarnings("ignore", message=".*AIOHTTPTransport does not verify ssl.*") + + +@pytest.fixture +def protocol(): + """Create a fresh GraphQL protocol instance.""" + utcp_gql.register() + return GraphQLCommunicationProtocol() + + +@pytest_asyncio.fixture +async def client(): + """Create a minimal UTCP client.""" + return await UtcpClientImplementation.create() + + +@pytest.mark.asyncio +async def test_register_manual_discovers_tools(protocol, client): + """Test that register_manual discovers tools from a real GraphQL schema.""" + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + ) + + result = await protocol.register_manual(client, template) + + assert result.success is True + assert len(result.manual.tools) > 0 + + # The Countries API should have these common queries + tool_names = [t.name for t in result.manual.tools] + assert "countries" in tool_names or "country" in tool_names + + +@pytest.mark.asyncio +async def test_call_tool_with_custom_query(protocol, client): + """Test calling a tool with a custom query string (fixes selection set issue).""" + # Custom query with proper selection set - this is the UTCP-flexible approach + custom_query = """ + query GetCountry($code: ID!) { + country(code: $code) { + name + capital + currency + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + ) + + result = await protocol.call_tool( + client, + "country", + {"code": "US"}, + template, + ) + + assert result is not None + assert "country" in result + assert result["country"]["name"] == "United States" + assert result["country"]["capital"] == "Washington D.C." + + +@pytest.mark.asyncio +async def test_call_tool_with_variable_types(protocol, client): + """Test that variable_types properly maps GraphQL types (fixes String-only issue).""" + # The country query expects code: ID!, not String + # Using variable_types to specify the correct type + custom_query = """ + query GetCountry($code: ID!) { + country(code: $code) { + name + emoji + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + variable_types={"code": "ID!"}, + ) + + result = await protocol.call_tool( + client, + "country", + {"code": "FR"}, + template, + ) + + assert result is not None + assert result["country"]["name"] == "France" + assert result["country"]["emoji"] == "šŸ‡«šŸ‡·" + + +@pytest.mark.asyncio +async def test_call_tool_list_query(protocol, client): + """Test querying a list of items with proper selection set.""" + custom_query = """ + query GetContinents { + continents { + code + name + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + ) + + result = await protocol.call_tool( + client, + "continents", + {}, + template, + ) + + assert result is not None + assert "continents" in result + assert len(result["continents"]) == 7 # 7 continents + + continent_names = [c["name"] for c in result["continents"]] + assert "Europe" in continent_names + assert "Asia" in continent_names + + +@pytest.mark.asyncio +async def test_call_tool_nested_query(protocol, client): + """Test querying nested objects with proper selection sets.""" + custom_query = """ + query GetCountryWithLanguages($code: ID!) { + country(code: $code) { + name + languages { + code + name + } + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + ) + + result = await protocol.call_tool( + client, + "country", + {"code": "CH"}, # Switzerland - has multiple languages + template, + ) + + assert result is not None + assert result["country"]["name"] == "Switzerland" + assert len(result["country"]["languages"]) >= 3 # German, French, Italian, Romansh + + +@pytest.mark.asyncio +async def test_call_tool_with_filter_arguments(protocol, client): + """Test queries with filter arguments using proper types.""" + custom_query = """ + query GetCountriesByContinent($filter: CountryFilterInput) { + countries(filter: $filter) { + code + name + } + } + """ + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=custom_query, + variable_types={"filter": "CountryFilterInput"}, + ) + + result = await protocol.call_tool( + client, + "countries", + {"filter": {"continent": {"eq": "EU"}}}, + template, + ) + + assert result is not None + assert "countries" in result + # Should return European countries + country_codes = [c["code"] for c in result["countries"]] + assert "DE" in country_codes # Germany + assert "FR" in country_codes # France + + +@pytest.mark.asyncio +async def test_error_handling_invalid_query(protocol, client): + """Test that invalid queries return proper errors.""" + # Invalid query syntax + invalid_query = "this is not valid graphql" + + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + query=invalid_query, + ) + + with pytest.raises(Exception): + await protocol.call_tool( + client, + "invalid", + {}, + template, + ) + + +@pytest.mark.asyncio +async def test_error_handling_missing_selection_set_auto_generated(protocol, client): + """ + Demonstrate that auto-generated queries fail for object-returning fields. + + This test documents the limitation: without a custom query, object fields fail. + The fix is to always use the `query` field for object-returning operations. + """ + # No custom query - will auto-generate without selection set + template = GraphQLCallTemplate( + name="countries_api", + url=COUNTRIES_API_URL, + operation_type="query", + variable_types={"code": "ID!"}, + ) + + # This should fail because auto-generated query lacks selection set + # The query becomes: query ($code: ID!) { country(code: $code) } + # But country returns an object that needs: { name capital ... } + with pytest.raises(Exception): + await protocol.call_tool( + client, + "country", + {"code": "US"}, + template, + ) diff --git a/plugins/communication_protocols/mcp/pyproject.toml b/plugins/communication_protocols/mcp/pyproject.toml index 36bb48e..2efd4c3 100644 --- a/plugins/communication_protocols/mcp/pyproject.toml +++ b/plugins/communication_protocols/mcp/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "pydantic>=2.0", "mcp>=1.12", "utcp>=1.0", - "mcp-use>=1.3" + "mcp-use>=1.3", + "langchain>=0.3.27,<0.4.0", ] classifiers = [ "Development Status :: 4 - Beta", diff --git a/plugins/communication_protocols/socket/README.md b/plugins/communication_protocols/socket/README.md index 8febb5a..04c1737 100644 --- a/plugins/communication_protocols/socket/README.md +++ b/plugins/communication_protocols/socket/README.md @@ -1 +1,44 @@ -Find the UTCP readme at https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file +# UTCP Socket Plugin (UDP/TCP) + +This plugin adds UDP and TCP communication protocols to UTCP 1.0. + +## Running Tests + +Prerequisites: +- Python 3.10+ +- `pip` +- (Optional) a virtual environment + +1) Install core and the socket plugin in editable mode with dev extras: + +```bash +pip install -e "./core[dev]" +pip install -e ./plugins/communication_protocols/socket[dev] +``` + +2) Run the socket plugin tests: + +```bash +python -m pytest plugins/communication_protocols/socket/tests -v +``` + +3) Run a single test or filter by keyword: + +```bash +# One file +python -m pytest plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py -v + +# Filter by keyword (e.g., delimiter framing) +python -m pytest plugins/communication_protocols/socket/tests -k delimiter -q +``` + +4) Optional end-to-end sanity check (mock UDP/TCP servers): + +```bash +python scripts/socket_sanity.py +``` + +Notes: +- On Windows, your firewall may prompt the first time tests open UDP/TCP sockets; allow access or run as admin if needed. +- Tests use `pytest-asyncio`. The dev extras installed above provide required dependencies. +- Streaming is single-chunk by design, consistent with HTTP/Text transports. Multi-chunk streaming can be added later behind provider configuration. \ No newline at end of file diff --git a/plugins/communication_protocols/socket/pyproject.toml b/plugins/communication_protocols/socket/pyproject.toml index 06f845e..a544648 100644 --- a/plugins/communication_protocols/socket/pyproject.toml +++ b/plugins/communication_protocols/socket/pyproject.toml @@ -8,7 +8,7 @@ version = "1.0.2" authors = [ { name = "UTCP Contributors" }, ] -description = "UTCP communication protocol plugin for TCP and UDP protocols. (Work in progress)" +description = "UTCP communication protocol plugin for TCP and UDP protocols." readme = "README.md" requires-python = ">=3.10" dependencies = [ @@ -36,4 +36,7 @@ dev = [ [project.urls] Homepage = "https://utcp.io" Source = "https://github.com/universal-tool-calling-protocol/python-utcp" -Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" \ No newline at end of file +Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" + +[project.entry-points."utcp.plugins"] +socket = "utcp_socket:register" \ No newline at end of file diff --git a/plugins/communication_protocols/socket/src/utcp_socket/__init__.py b/plugins/communication_protocols/socket/src/utcp_socket/__init__.py index e69de29..a0b7f3b 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/__init__.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/__init__.py @@ -0,0 +1,18 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.tcp_call_template import TCPProviderSerializer +from utcp_socket.udp_call_template import UDPProviderSerializer + + +def register() -> None: + # Register communication protocols + register_communication_protocol("tcp", TCPTransport()) + register_communication_protocol("udp", UDPTransport()) + + # Register call templates and their serializers + register_call_template("tcp", TCPProviderSerializer()) + register_call_template("udp", UDPProviderSerializer()) + + +__all__ = ["register"] \ No newline at end of file diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py index 157e43c..8b27d1c 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py @@ -1,6 +1,9 @@ from utcp.data.call_template import CallTemplate from typing import Optional, Literal from pydantic import Field +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback class TCPProvider(CallTemplate): """Provider configuration for raw TCP socket tools. @@ -63,7 +66,11 @@ class TCPProvider(CallTemplate): # Delimiter-based framing options message_delimiter: str = Field( default='\x00', - description="Delimiter to detect end of TCP response (e.g., '\\n', '\\r\\n', '\\x00'). Used with 'delimiter' framing." + description="Delimiter to detect end of TCP response (e.g., '\n', '\r\n', '\x00'). Used with 'delimiter' framing." + ) + interpret_escape_sequences: bool = Field( + default=True, + description="If True, interpret Python-style escape sequences in message_delimiter (e.g., '\\n', '\\r\\n', '\\x00'). If False, use the delimiter literally as provided." ) # Fixed-length framing options fixed_message_length: Optional[int] = Field( @@ -77,3 +84,16 @@ class TCPProvider(CallTemplate): ) timeout: int = 30000 auth: None = None + + +class TCPProviderSerializer(Serializer[TCPProvider]): + def to_dict(self, obj: TCPProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> TCPProvider: + try: + return TCPProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid TCPProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py index 1b360a8..b2f08c3 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py @@ -10,9 +10,12 @@ import sys from typing import Dict, Any, List, Optional, Callable, Union -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, TCPProvider -from utcp.shared.tool import Tool +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp_socket.tcp_call_template import TCPProvider, TCPProviderSerializer +from utcp.data.tool import Tool +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.utcp_manual import UtcpManual import logging logging.basicConfig( @@ -22,7 +25,7 @@ logger = logging.getLogger(__name__) -class TCPTransport(ClientTransportInterface): +class TCPTransport(CommunicationProtocol): """Transport implementation for TCP-based tool providers. This transport communicates with tools over TCP sockets. It supports: @@ -85,6 +88,35 @@ def _format_tool_call_message( else: # Default to JSON format return json.dumps(tool_args) + + def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_template: TCPProvider) -> Dict[str, Any]: + """Normalize tool definition to include a valid 'tool_call_template'. + + - If 'tool_call_template' exists, validate it. + - Else if legacy 'tool_provider' exists, convert using TCPProviderSerializer. + - Else default to the provided manual_call_template. + """ + normalized = dict(tool_data) + try: + if "tool_call_template" in normalized and normalized["tool_call_template"] is not None: + try: + ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore + normalized["tool_call_template"] = ctpl + except Exception: + normalized["tool_call_template"] = manual_call_template + elif "tool_provider" in normalized and normalized["tool_provider"] is not None: + try: + ctpl = TCPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = ctpl + except Exception: + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = manual_call_template + else: + normalized["tool_call_template"] = manual_call_template + except Exception: + normalized["tool_call_template"] = manual_call_template + return normalized def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> bytes: """Encode message with appropriate TCP framing. @@ -115,10 +147,15 @@ def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> b elif provider.framing_strategy == "delimiter": # Add delimiter after the message - delimiter = provider.message_delimiter or "\\x00" - # Handle escape sequences - delimiter = delimiter.encode('utf-8').decode('unicode_escape') - return message_bytes + delimiter.encode('utf-8') + delimiter = provider.message_delimiter or "\x00" + if provider.interpret_escape_sequences: + # Handle escape sequences (e.g., "\n", "\r\n", "\x00") + delimiter = delimiter.encode('utf-8').decode('unicode_escape') + delimiter_bytes = delimiter.encode('utf-8') + else: + # Use delimiter literally as provided + delimiter_bytes = delimiter.encode('utf-8') + return message_bytes + delimiter_bytes elif provider.framing_strategy in ("fixed_length", "stream"): # No additional framing needed @@ -170,8 +207,19 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid elif provider.framing_strategy == "delimiter": # Read until delimiter is found - delimiter = provider.message_delimiter or "\\x00" - delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') + # Delimiter handling: + # The code supports both literal delimiters (e.g., "\\x00") and escape-sequence interpreted delimiters (e.g., "\x00") + # via the `interpret_escape_sequences` flag in TCPProvider. This ensures compatibility with both legacy and updated + # wire protocols. The delimiter is interpreted according to the flag, so no breaking change occurs unless the flag + # is set differently than expected by the server/client. + # Example: + # If interpret_escape_sequences is True, "\\x00" becomes a null byte; if False, it remains four literal bytes. + # delimiter = delimiter.encode('utf-8') + delimiter = provider.message_delimiter or "\x00" + if provider.interpret_escape_sequences: + delimiter_bytes = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') + else: + delimiter_bytes = delimiter.encode('utf-8') response_data = b"" while True: @@ -181,9 +229,9 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid response_data += chunk # Check if we've received the delimiter - if response_data.endswith(delimiter): + if response_data.endswith(delimiter_bytes): # Remove delimiter from response - return response_data[:-len(delimiter)] + return response_data[:-len(delimiter_bytes)] elif provider.framing_strategy == "fixed_length": # Read exactly fixed_message_length bytes @@ -214,10 +262,14 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid break return response_data - - else: - raise ValueError(f"Unknown framing strategy: {provider.framing_strategy}") + else: + # Copilot AI (5 days ago): + # The else branch for unknown framing strategies was previously removed, + # which could cause silent fallthrough and confusing behavior. Add explicit + # validation to raise a descriptive error when an unsupported strategy is provided. + raise ValueError(f"Unknown framing strategy: {provider.framing_strategy!r}") + async def _send_tcp_message( self, host: str, @@ -289,122 +341,91 @@ def _send_and_receive(): self._log_error(f"Error in TCP communication: {e}") raise - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - """Register a TCP provider and discover its tools. - - Sends a discovery message to the TCP provider and parses the response. - - Args: - manual_provider: The TCPProvider to register - - Returns: - List of tools discovered from the TCP provider - - Raises: - ValueError: If provider is not a TCPProvider - """ - if not isinstance(manual_provider, TCPProvider): + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + """Register a TCP manual and discover its tools.""" + if not isinstance(manual_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - self._log_info(f"Registering TCP provider '{manual_provider.name}'") + self._log_info(f"Registering TCP provider '{manual_call_template.name}'") try: - # Send discovery message - discovery_message = json.dumps({ - "type": "utcp" - }) - + discovery_message = json.dumps({"type": "utcp"}) response = await self._send_tcp_message( - manual_provider.host, - manual_provider.port, + manual_call_template.host, + manual_call_template.port, discovery_message, - manual_provider, - manual_provider.timeout / 1000.0, # Convert ms to seconds - manual_provider.response_byte_format + manual_call_template, + manual_call_template.timeout / 1000.0, + manual_call_template.response_byte_format ) - - # Parse response try: - # Handle bytes response by trying to decode as UTF-8 for JSON parsing - if isinstance(response, bytes): - response_str = response.decode('utf-8') - else: - response_str = response - + response_str = response.decode('utf-8') if isinstance(response, bytes) else response response_data = json.loads(response_str) - - # Check if response contains tools + tools: List[Tool] = [] if isinstance(response_data, dict) and 'tools' in response_data: tools_data = response_data['tools'] - - # Parse tools - tools = [] for tool_data in tools_data: try: - tool = Tool(**tool_data) - tools.append(tool) + normalized = self._ensure_tool_call_template(tool_data, manual_call_template) + tools.append(Tool(**normalized)) except Exception as e: - self._log_error(f"Invalid tool definition in TCP provider '{manual_provider.name}': {e}") + self._log_error(f"Invalid tool definition in TCP provider '{manual_call_template.name}': {e}") continue - - self._log_info(f"Discovered {len(tools)} tools from TCP provider '{manual_provider.name}'") - return tools + self._log_info(f"Discovered {len(tools)} tools from TCP provider '{manual_call_template.name}'") else: - self._log_info(f"No tools found in TCP provider '{manual_provider.name}' response") - return [] - + self._log_info(f"No tools found in TCP provider '{manual_call_template.name}' response") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[] + ) except json.JSONDecodeError as e: - self._log_error(f"Invalid JSON response from TCP provider '{manual_provider.name}': {e}") - return [] - + self._log_error(f"Invalid JSON response from TCP provider '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]), + success=False, + errors=[str(e)] + ) except Exception as e: - self._log_error(f"Error registering TCP provider '{manual_provider.name}': {e}") - return [] + self._log_error(f"Error registering TCP provider '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]), + success=False, + errors=[str(e)] + ) - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - """Deregister a TCP provider. - - This is a no-op for TCP providers since connections are created per request. - - Args: - manual_provider: The provider to deregister - """ - if not isinstance(manual_provider, TCPProvider): + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + """Deregister a TCP provider (no-op).""" + if not isinstance(manual_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - - self._log_info(f"Deregistering TCP provider '{manual_provider.name}' (no-op)") + self._log_info(f"Deregistering TCP provider '{manual_call_template.name}' (no-op)") - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider) -> Any: - """Call a TCP tool. - - Sends a tool call message to the TCP provider and returns the response. - - Args: - tool_name: Name of the tool to call - tool_args: Arguments for the tool call - tool_provider: The TCPProvider containing the tool - - Returns: - The response from the TCP tool - - Raises: - ValueError: If provider is not a TCPProvider - """ - if not isinstance(tool_provider, TCPProvider): + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + async def _generator(): + yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) + return _generator() + + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + """Call a TCP tool.""" + if not isinstance(tool_call_template, TCPProvider): raise ValueError("TCPTransport can only be used with TCPProvider") - self._log_info(f"Calling TCP tool '{tool_name}' on provider '{tool_provider.name}'") + self._log_info(f"Calling TCP tool '{tool_name}' on provider '{tool_call_template.name}'") try: - tool_call_message = self._format_tool_call_message(tool_args, tool_provider) + tool_call_message = self._format_tool_call_message(tool_args, tool_call_template) response = await self._send_tcp_message( - tool_provider.host, - tool_provider.port, + tool_call_template.host, + tool_call_template.port, tool_call_message, - tool_provider, - tool_provider.timeout / 1000.0, # Convert ms to seconds - tool_provider.response_byte_format + tool_call_template, + tool_call_template.timeout / 1000.0, + tool_call_template.response_byte_format ) return response diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py index 4c704da..8c30c86 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_call_template.py @@ -1,6 +1,9 @@ from utcp.data.call_template import CallTemplate from typing import Optional, Literal from pydantic import Field +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback class UDPProvider(CallTemplate): """Provider configuration for UDP (User Datagram Protocol) socket tools. @@ -38,3 +41,16 @@ class UDPProvider(CallTemplate): response_byte_format: Optional[str] = Field(default="utf-8", description="Encoding to decode response bytes. If None, returns raw bytes.") timeout: int = 30000 auth: None = None + + +class UDPProviderSerializer(Serializer[UDPProvider]): + def to_dict(self, obj: UDPProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> UDPProvider: + try: + return UDPProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid UDPProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py index 8d4d404..89ae3e3 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py @@ -9,14 +9,18 @@ import traceback from typing import Dict, Any, List, Optional, Callable, Union -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, UDPProvider -from utcp.shared.tool import Tool +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp_socket.udp_call_template import UDPProvider, UDPProviderSerializer +from utcp.data.tool import Tool +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.utcp_manual import UtcpManual +from utcp.exceptions import UtcpSerializerValidationError import logging logger = logging.getLogger(__name__) -class UDPTransport(ClientTransportInterface): +class UDPTransport(CommunicationProtocol): """Transport implementation for UDP-based tool providers. This transport communicates with tools over UDP sockets. It supports: @@ -80,6 +84,42 @@ def _format_tool_call_message( else: # Default to JSON format return json.dumps(tool_args) + + def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_template: UDPProvider) -> Dict[str, Any]: + """Normalize tool definition to include a valid 'tool_call_template'. + + - If 'tool_call_template' exists, validate it. + - Else if legacy 'tool_provider' exists, convert using UDPProviderSerializer. + - Else default to the provided manual_call_template. + """ + normalized = dict(tool_data) + try: + if "tool_call_template" in normalized and normalized["tool_call_template"] is not None: + # Validate via generic CallTemplate serializer (type-dispatched) + try: + ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore + normalized["tool_call_template"] = ctpl + except (UtcpSerializerValidationError, ValueError) as e: + # Fallback to manual template if validation fails, but log details + logger.exception("Failed to validate existing tool_call_template; falling back to manual template") + normalized["tool_call_template"] = manual_call_template + elif "tool_provider" in normalized and normalized["tool_provider"] is not None: + # Convert legacy provider -> call template + try: + ctpl = UDPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = ctpl + except UtcpSerializerValidationError as e: + logger.exception("Failed to convert legacy tool_provider to call template; falling back to manual template") + normalized.pop("tool_provider", None) + normalized["tool_call_template"] = manual_call_template + else: + normalized["tool_call_template"] = manual_call_template + except Exception: + # Any unexpected error during normalization should be logged + logger.exception("Unexpected error normalizing tool definition; falling back to manual template") + normalized["tool_call_template"] = manual_call_template + return normalized async def _send_udp_message( self, @@ -202,125 +242,96 @@ def _send_only(): self._log_error(f"Error sending UDP message (no response): {traceback.format_exc()}") raise - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - """Register a UDP provider and discover its tools. - - Sends a discovery message to the UDP provider and parses the response. - - Args: - manual_provider: The UDPProvider to register - - Returns: - List of tools discovered from the UDP provider - - Raises: - ValueError: If provider is not a UDPProvider - """ - if not isinstance(manual_provider, UDPProvider): + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + """Register a UDP manual and discover its tools.""" + if not isinstance(manual_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - self._log_info(f"Registering UDP provider '{manual_provider.name}' at {manual_provider.host}:{manual_provider.port}") + self._log_info(f"Registering UDP provider '{manual_call_template.name}' at {manual_call_template.host}:{manual_call_template.port}") try: - # Send discovery message - discovery_message = json.dumps({ - "type": "utcp" - }) - + discovery_message = json.dumps({"type": "utcp"}) response = await self._send_udp_message( - manual_provider.host, - manual_provider.port, + manual_call_template.host, + manual_call_template.port, discovery_message, - manual_provider.timeout / 1000.0, # Convert ms to seconds - manual_provider.number_of_response_datagrams, - manual_provider.response_byte_format + manual_call_template.timeout / 1000.0, + manual_call_template.number_of_response_datagrams, + manual_call_template.response_byte_format ) - - # Parse response try: - # Handle bytes response by trying to decode as UTF-8 for JSON parsing - if isinstance(response, bytes): - response_str = response.decode('utf-8') - else: - response_str = response - + response_str = response.decode('utf-8') if isinstance(response, bytes) else response response_data = json.loads(response_str) - - # Check if response contains tools + tools: List[Tool] = [] if isinstance(response_data, dict) and 'tools' in response_data: tools_data = response_data['tools'] - - # Parse tools - tools = [] for tool_data in tools_data: try: - tool = Tool(**tool_data) + normalized = self._ensure_tool_call_template(tool_data, manual_call_template) + tool = Tool(**normalized) tools.append(tool) - except Exception as e: - self._log_error(f"Invalid tool definition in UDP provider '{manual_provider.name}': {traceback.format_exc()}") + except Exception: + self._log_error(f"Invalid tool definition in UDP provider '{manual_call_template.name}': {traceback.format_exc()}") continue - - self._log_info(f"Discovered {len(tools)} tools from UDP provider '{manual_provider.name}'") - return tools + self._log_info(f"Discovered {len(tools)} tools from UDP provider '{manual_call_template.name}'") else: - self._log_info(f"No tools found in UDP provider '{manual_provider.name}' response") - return [] - + self._log_info(f"No tools found in UDP provider '{manual_call_template.name}' response") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[] + ) except json.JSONDecodeError as e: - self._log_error(f"Invalid JSON response from UDP provider '{manual_provider.name}': {traceback.format_exc()}") - return [] - + self._log_error(f"Invalid JSON response from UDP provider '{manual_call_template.name}': {traceback.format_exc()}") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=False, + errors=[str(e)] + ) except Exception as e: - self._log_error(f"Error registering UDP provider '{manual_provider.name}': {traceback.format_exc()}") - return [] + self._log_error(f"Error registering UDP provider '{manual_call_template.name}': {traceback.format_exc()}") + manual = UtcpManual(utcp_version="1.0", manual_version="1.0", tools=[]) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=False, + errors=[str(e)] + ) - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - """Deregister a UDP provider. - - This is a no-op for UDP providers since they are stateless. - - Args: - manual_provider: The provider to deregister - """ - if not isinstance(manual_provider, UDPProvider): + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + if not isinstance(manual_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - - self._log_info(f"Deregistering UDP provider '{manual_provider.name}' (no-op)") + self._log_info(f"Deregistering UDP provider '{manual_call_template.name}' (no-op)") - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider) -> Any: - """Call a UDP tool. - - Sends a tool call message to the UDP provider and returns the response. - - Args: - tool_name: Name of the tool to call - arguments: Arguments for the tool call - tool_provider: The UDPProvider containing the tool - - Returns: - The response from the UDP tool - - Raises: - ValueError: If provider is not a UDPProvider - """ - if not isinstance(tool_provider, UDPProvider): + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + if not isinstance(tool_call_template, UDPProvider): raise ValueError("UDPTransport can only be used with UDPProvider") - - self._log_info(f"Calling UDP tool '{tool_name}' on provider '{tool_provider.name}'") - + self._log_info(f"Calling UDP tool '{tool_name}' on provider '{tool_call_template.name}'") try: - tool_call_message = self._format_tool_call_message(tool_args, tool_provider) - + tool_call_message = self._format_tool_call_message(tool_args, tool_call_template) response = await self._send_udp_message( - tool_provider.host, - tool_provider.port, + tool_call_template.host, + tool_call_template.port, tool_call_message, - tool_provider.timeout / 1000.0, # Convert ms to seconds - tool_provider.number_of_response_datagrams, - tool_provider.response_byte_format + tool_call_template.timeout / 1000.0, + tool_call_template.number_of_response_datagrams, + tool_call_template.response_byte_format ) return response - except Exception as e: self._log_error(f"Error calling UDP tool '{tool_name}': {traceback.format_exc()}") raise + + # Copilot AI (5 days ago): + # The call_tool_streaming method wraps a generator function but doesn't use the async def syntax for the method itself. + # While this works, it's inconsistent with the other implementation in tcp_communication_protocol.py (lines 384-387) which properly uses async def with an inner generator. + # For consistency and clarity, this should also use async def directly: + # + # async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + # yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) diff --git a/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py new file mode 100644 index 0000000..d359fd9 --- /dev/null +++ b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py @@ -0,0 +1,180 @@ +import asyncio +import json +import pytest + +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.tcp_call_template import TCPProvider + + +async def start_tcp_server(): + """Start a simple TCP server that sends a mutable JSON object then closes.""" + response_container = {"bytes": b""} + + async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + try: + # Read any incoming data to simulate request handling + await reader.read(1024) + except Exception: + # Ignore exceptions during read (e.g., client disconnects), as this is a test server. + pass + # Send response and close connection + writer.write(response_container["bytes"]) + await writer.drain() + try: + writer.close() + await writer.wait_closed() + except Exception: + # Ignore exceptions during writer close; connection may already be closed or in error state. + pass + + server = await asyncio.start_server(handle, host="127.0.0.1", port=0) + port = server.sockets[0].getsockname()[1] + + def set_response(obj): + response_container["bytes"] = json.dumps(obj).encode("utf-8") + + return server, port, set_response + + +@pytest.mark.asyncio +async def test_register_manual_converts_legacy_tool_provider_tcp(): + """When manual returns legacy tool_provider, it is converted to tool_call_template.""" + # Start server and configure response after obtaining port + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {}, + "tool_provider": { + "call_template_type": "tcp", + "name": "tcp-executor", + "host": "127.0.0.1", + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "stream", + "timeout": 2000 + } + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert result.manual is not None + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_register_manual_validates_provided_tool_call_template_tcp(): + """When manual provides tool_call_template, it is validated and preserved.""" + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {}, + "tool_call_template": { + "call_template_type": "tcp", + "name": "tcp-executor", + "host": "127.0.0.1", + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "stream", + "timeout": 2000 + } + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_register_manual_fallbacks_to_manual_template_tcp(): + """When neither tool_provider nor tool_call_template is provided, fall back to manual template.""" + server, port, set_response = await start_tcp_server() + set_response({ + "tools": [ + { + "name": "tcp_tool", + "description": "Echo over TCP", + "inputs": {}, + "outputs": {} + } + ] + }) + + try: + provider = TCPProvider( + name="tcp-provider", + host="127.0.0.1", + port=port, + request_data_format="json", + response_byte_format="utf-8", + framing_strategy="stream", + timeout=2000 + ) + transport_client = TCPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "tcp" + assert isinstance(tool.tool_call_template, TCPProvider) + # Should match manual (discovery) provider values + assert tool.tool_call_template.host == provider.host + assert tool.tool_call_template.port == provider.port + assert tool.tool_call_template.name == provider.name + finally: + server.close() + await server.wait_closed() \ No newline at end of file diff --git a/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py new file mode 100644 index 0000000..d6a770c --- /dev/null +++ b/plugins/communication_protocols/socket/tests/test_udp_communication_protocol.py @@ -0,0 +1,176 @@ +import asyncio +import json +import pytest + +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.udp_call_template import UDPProvider + + +async def start_udp_server(): + """Start a simple UDP server that replies with a mutable JSON payload.""" + loop = asyncio.get_running_loop() + response_container = {"bytes": b""} + + class _Protocol(asyncio.DatagramProtocol): + def __init__(self, container): + self.container = container + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + # Always respond with the prepared payload + if self.transport: + self.transport.sendto(self.container["bytes"], addr) + + transport, protocol = await loop.create_datagram_endpoint( + lambda: _Protocol(response_container), local_addr=("127.0.0.1", 0) + ) + port = transport.get_extra_info("socket").getsockname()[1] + + def set_response(obj): + response_container["bytes"] = json.dumps(obj).encode("utf-8") + + return transport, port, set_response + + +@pytest.mark.asyncio +async def test_register_manual_converts_legacy_tool_provider_udp(): + """When manual returns legacy tool_provider, it is converted to tool_call_template.""" + # Start server and configure response after obtaining port + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {}, + "tool_provider": { + "call_template_type": "udp", + "name": "udp-executor", + "host": "127.0.0.1", + "port": port, + "number_of_response_datagrams": 1, + "request_data_format": "json", + "response_byte_format": "utf-8", + "timeout": 2000 + } + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert result.manual is not None + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + transport.close() + + +@pytest.mark.asyncio +async def test_register_manual_validates_provided_tool_call_template_udp(): + """When manual provides tool_call_template, it is validated and preserved.""" + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {}, + "tool_call_template": { + "call_template_type": "udp", + "name": "udp-executor", + "host": "127.0.0.1", + "port": port, + "number_of_response_datagrams": 1, + "request_data_format": "json", + "response_byte_format": "utf-8", + "timeout": 2000 + } + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + assert tool.tool_call_template.host == "127.0.0.1" + assert tool.tool_call_template.port == port + finally: + transport.close() + + +@pytest.mark.asyncio +async def test_register_manual_fallbacks_to_manual_template_udp(): + """When neither tool_provider nor tool_call_template is provided, fall back to manual template.""" + transport, port, set_response = await start_udp_server() + set_response({ + "tools": [ + { + "name": "udp_tool", + "description": "Echo over UDP", + "inputs": {}, + "outputs": {} + } + ] + }) + + try: + provider = UDPProvider( + name="udp-provider", + host="127.0.0.1", + port=port, + number_of_response_datagrams=1, + request_data_format="json", + response_byte_format="utf-8", + timeout=2000 + ) + transport_client = UDPTransport() + result = await transport_client.register_manual(None, provider) + + assert result.success + assert len(result.manual.tools) == 1 + tool = result.manual.tools[0] + assert tool.tool_call_template.call_template_type == "udp" + assert isinstance(tool.tool_call_template, UDPProvider) + # Should match manual (discovery) provider values + assert tool.tool_call_template.host == provider.host + assert tool.tool_call_template.port == provider.port + assert tool.tool_call_template.name == provider.name + finally: + transport.close() \ No newline at end of file diff --git a/plugins/communication_protocols/websocket/tests/__init__.py b/plugins/communication_protocols/websocket/tests/__init__.py deleted file mode 100644 index 614ce9a..0000000 --- a/plugins/communication_protocols/websocket/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for the WebSocket communication protocol plugin.""" diff --git a/scripts/socket_sanity.py b/scripts/socket_sanity.py new file mode 100644 index 0000000..5ac6028 --- /dev/null +++ b/scripts/socket_sanity.py @@ -0,0 +1,265 @@ +import sys +import json +import socket +import threading +import asyncio +from pathlib import Path + +# Ensure core and socket plugin sources are on sys.path +ROOT = Path(__file__).resolve().parent.parent +CORE_SRC = ROOT / "core" / "src" +SOCKET_SRC = ROOT / "plugins" / "communication_protocols" / "socket" / "src" +for p in [str(CORE_SRC), str(SOCKET_SRC)]: + if p not in sys.path: + sys.path.insert(0, p) + +from utcp_socket.udp_communication_protocol import UDPTransport +from utcp_socket.tcp_communication_protocol import TCPTransport +from utcp_socket.udp_call_template import UDPProvider +from utcp_socket.tcp_call_template import TCPProvider + +# ------------------------------- +# Mock UDP Server +# ------------------------------- + +def start_udp_server(host: str, port: int): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind((host, port)) + + def run(): + while True: + data, addr = sock.recvfrom(65535) + try: + msg = data.decode("utf-8") + except Exception: + msg = "" + # Handle discovery + try: + parsed = json.loads(msg) + except Exception: + # Ignore JSON parsing errors; non-JSON input will be handled below + parsed = None + if isinstance(parsed, dict) and parsed.get("type") == "utcp": + manual = { + "utcp_version": "1.0", + "manual_version": "1.0", + "tools": [ + { + "name": "udp.echo", + "description": "Echo UDP args as JSON", + "inputs": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "extra": {"type": "number"} + }, + "required": ["text"] + }, + "outputs": { + "type": "object", + "properties": { + "ok": {"type": "boolean"}, + "echo": {"type": "string"}, + "args": {"type": "object"} + } + }, + "tags": ["socket", "udp"], + "average_response_size": 64, + # Return legacy provider to exercise conversion path + "tool_provider": { + "call_template_type": "udp", + "name": "udp", + "host": host, + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "number_of_response_datagrams": 1, + "timeout": 3000 + } + } + ] + } + payload = json.dumps(manual).encode("utf-8") + sock.sendto(payload, addr) + else: + # Tool call: echo JSON payload + try: + args = json.loads(msg) + except Exception: + args = {"raw": msg} + resp = { + "ok": True, + "echo": args.get("text", ""), + "args": args + } + sock.sendto(json.dumps(resp).encode("utf-8"), addr) + t = threading.Thread(target=run, daemon=True) + t.start() + return t + +# ------------------------------- +# Mock TCP Server (delimiter-based) +# ------------------------------- + +def start_tcp_server(host: str, port: int, delimiter: str = "\n"): + delim_bytes = delimiter.encode("utf-8") + + def handle_client(conn: socket.socket, addr): + try: + # Read until delimiter + buf = b"" + while True: + chunk = conn.recv(1) + if not chunk: + break + buf += chunk + if buf.endswith(delim_bytes): + break + msg = buf[:-len(delim_bytes)].decode("utf-8") if buf.endswith(delim_bytes) else buf.decode("utf-8") + # Discovery + parsed = None + try: + parsed = json.loads(msg) + except Exception: + pass + if isinstance(parsed, dict) and parsed.get("type") == "utcp": + manual = { + "utcp_version": "1.0", + "manual_version": "1.0", + "tools": [ + { + "name": "tcp.echo", + "description": "Echo TCP args as JSON", + "inputs": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "extra": {"type": "number"} + }, + "required": ["text"] + }, + "outputs": { + "type": "object", + "properties": { + "ok": {"type": "boolean"}, + "echo": {"type": "string"}, + "args": {"type": "object"} + } + }, + "tags": ["socket", "tcp"], + "average_response_size": 64, + # Legacy provider to exercise conversion + "tool_provider": { + "call_template_type": "tcp", + "name": "tcp", + "host": host, + "port": port, + "request_data_format": "json", + "response_byte_format": "utf-8", + "framing_strategy": "delimiter", + "message_delimiter": "\\n", + "timeout": 3000 + } + } + ] + } + payload = json.dumps(manual).encode("utf-8") + delim_bytes + conn.sendall(payload) + else: + # Tool call: echo JSON payload + try: + args = json.loads(msg) + except Exception: + args = {"raw": msg} + resp = { + "ok": True, + "echo": args.get("text", ""), + "args": args + } + conn.sendall(json.dumps(resp).encode("utf-8") + delim_bytes) + finally: + try: + conn.shutdown(socket.SHUT_RDWR) + except Exception: + # Ignore errors if socket is already closed or shutdown fails + pass + conn.close() + + def run(): + srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + srv.bind((host, port)) + srv.listen(5) + while True: + conn, addr = srv.accept() + threading.Thread(target=handle_client, args=(conn, addr), daemon=True).start() + + t = threading.Thread(target=run, daemon=True) + t.start() + return t + +# ------------------------------- +# Sanity test runner +# ------------------------------- + +async def run_sanity(): + udp_host, udp_port = "127.0.0.1", 23456 + tcp_host, tcp_port = "127.0.0.1", 23457 + + # Start servers + start_udp_server(udp_host, udp_port) + start_tcp_server(tcp_host, tcp_port, delimiter="\n") + await asyncio.sleep(0.2) # small delay to ensure servers are listening + + # Transports + udp_transport = UDPTransport() + tcp_transport = TCPTransport() + + # Register manuals + udp_manual_template = UDPProvider(name="udp", host=udp_host, port=udp_port, request_data_format="json", response_byte_format="utf-8", number_of_response_datagrams=1, timeout=3000) + tcp_manual_template = TCPProvider(name="tcp", host=tcp_host, port=tcp_port, request_data_format="json", response_byte_format="utf-8", framing_strategy="delimiter", message_delimiter="\n", timeout=3000) + + udp_reg = await udp_transport.register_manual(None, udp_manual_template) + tcp_reg = await tcp_transport.register_manual(None, tcp_manual_template) + + print("UDP register success:", udp_reg.success, "tools:", len(udp_reg.manual.tools)) + print("TCP register success:", tcp_reg.success, "tools:", len(tcp_reg.manual.tools)) + + assert udp_reg.success and len(udp_reg.manual.tools) == 1 + assert tcp_reg.success and len(tcp_reg.manual.tools) == 1 + + # Verify tool_call_template present + assert udp_reg.manual.tools[0].tool_call_template.call_template_type == "udp" + assert tcp_reg.manual.tools[0].tool_call_template.call_template_type == "tcp" + + # Call tools + udp_result = await udp_transport.call_tool(None, "udp.echo", {"text": "hello", "extra": 42}, udp_reg.manual.tools[0].tool_call_template) + tcp_result = await tcp_transport.call_tool(None, "tcp.echo", {"text": "world", "extra": 99}, tcp_reg.manual.tools[0].tool_call_template) + + print("UDP call result:", udp_result) + print("TCP call result:", tcp_result) + + # Basic assertions on response shape + def ensure_dict(s): + if isinstance(s, (bytes, bytearray)): + try: + s = s.decode("utf-8") + except Exception: + return {} + if isinstance(s, str): + try: + return json.loads(s) + except Exception: + return {"raw": s} + return s if isinstance(s, dict) else {} + + udp_resp = ensure_dict(udp_result) + tcp_resp = ensure_dict(tcp_result) + + assert udp_resp.get("ok") is True and udp_resp.get("echo") == "hello" + assert tcp_resp.get("ok") is True and tcp_resp.get("echo") == "world" + + print("Sanity check passed: UDP/TCP discovery and calls work with tool_call_template normalization.") + +if __name__ == "__main__": + asyncio.run(run_sanity()) \ No newline at end of file diff --git a/test_websocket_manual.py b/test_websocket_manual.py deleted file mode 100644 index a1457c4..0000000 --- a/test_websocket_manual.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env python3 -""" -Manual test script for WebSocket transport implementation. -This tests the core functionality without requiring pytest setup. -""" - -import asyncio -import sys -import os - -# Add src to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) - -from utcp.client.transport_interfaces.websocket_transport import WebSocketClientTransport -from utcp.shared.provider import WebSocketProvider -from utcp.shared.auth import ApiKeyAuth, BasicAuth - - -async def test_basic_functionality(): - """Test basic WebSocket transport functionality""" - print("Testing WebSocket Transport Implementation...") - - transport = WebSocketClientTransport() - - # Test 1: Security enforcement - print("\n1. Testing security enforcement...") - try: - insecure_provider = WebSocketProvider( - name="insecure", - url="ws://example.com/ws" # Should be rejected - ) - await transport.register_tool_provider(insecure_provider) - print("āŒ FAILED: Insecure URL was accepted") - except ValueError as e: - if "Security error" in str(e): - print("āœ… PASSED: Insecure URL properly rejected") - else: - print(f"āŒ FAILED: Wrong error: {e}") - except Exception as e: - print(f"āŒ FAILED: Unexpected error: {e}") - - # Test 2: Provider type validation - print("\n2. Testing provider type validation...") - try: - from utcp.shared.provider import HttpProvider - wrong_provider = HttpProvider(name="wrong", url="https://example.com") - await transport.register_tool_provider(wrong_provider) - print("āŒ FAILED: Wrong provider type was accepted") - except ValueError as e: - if "WebSocketClientTransport can only be used with WebSocketProvider" in str(e): - print("āœ… PASSED: Provider type validation works") - else: - print(f"āŒ FAILED: Wrong error: {e}") - except Exception as e: - print(f"āŒ FAILED: Unexpected error: {e}") - - # Test 3: Authentication header preparation - print("\n3. Testing authentication...") - try: - # Test API Key auth - api_provider = WebSocketProvider( - name="api_test", - url="wss://example.com/ws", - auth=ApiKeyAuth( - var_name="X-API-Key", - api_key="test-key-123", - location="header" - ) - ) - headers = await transport._prepare_headers(api_provider) - if headers.get("X-API-Key") == "test-key-123": - print("āœ… PASSED: API Key authentication headers prepared correctly") - else: - print(f"āŒ FAILED: API Key headers incorrect: {headers}") - - # Test Basic auth - basic_provider = WebSocketProvider( - name="basic_test", - url="wss://example.com/ws", - auth=BasicAuth(username="user", password="pass") - ) - headers = await transport._prepare_headers(basic_provider) - if "Authorization" in headers and headers["Authorization"].startswith("Basic "): - print("āœ… PASSED: Basic authentication headers prepared correctly") - else: - print(f"āŒ FAILED: Basic auth headers incorrect: {headers}") - - except Exception as e: - print(f"āŒ FAILED: Authentication test error: {e}") - - # Test 4: Connection management - print("\n4. Testing connection management...") - try: - localhost_provider = WebSocketProvider( - name="test_provider", - url="ws://localhost:8765/ws" - ) - - # This should fail to connect but not due to security - try: - await transport.register_tool_provider(localhost_provider) - print("āŒ FAILED: Connection should have failed (no server)") - except ValueError as e: - if "Security error" in str(e): - print("āŒ FAILED: Security error on localhost") - else: - print("ā“ UNEXPECTED: Different error occurred") - except Exception as e: - # Expected - connection refused or similar - print("āœ… PASSED: Connection management works (failed to connect as expected)") - - except Exception as e: - print(f"āŒ FAILED: Connection test error: {e}") - - # Test 5: Cleanup - print("\n5. Testing cleanup...") - try: - await transport.close() - if len(transport._connections) == 0 and len(transport._oauth_tokens) == 0: - print("āœ… PASSED: Cleanup successful") - else: - print("āŒ FAILED: Cleanup incomplete") - except Exception as e: - print(f"āŒ FAILED: Cleanup error: {e}") - - print("\nāœ… WebSocket transport basic functionality tests completed!") - - -async def test_with_mock_server(): - """Test with a real WebSocket connection to our mock server""" - print("\n" + "="*50) - print("Testing with Mock WebSocket Server") - print("="*50) - - # Import and start mock server - sys.path.append('tests/client/transport_interfaces') - try: - from mock_websocket_server import create_app - from aiohttp import web - - print("Starting mock WebSocket server...") - app = await create_app() - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite(runner, 'localhost', 8765) - await site.start() - - print("Mock server started on ws://localhost:8765/ws") - - # Test with our transport - transport = WebSocketClientTransport() - provider = WebSocketProvider( - name="test_provider", - url="ws://localhost:8765/ws" - ) - - try: - # Test tool discovery - print("\nTesting tool discovery...") - tools = await transport.register_tool_provider(provider) - print(f"āœ… Discovered {len(tools)} tools:") - for tool in tools: - print(f" - {tool.name}: {tool.description}") - - # Test tool execution - print("\nTesting tool execution...") - result = await transport.call_tool("echo", {"message": "Hello WebSocket!"}, provider) - print(f"āœ… Echo result: {result}") - - result = await transport.call_tool("add_numbers", {"a": 5, "b": 3}, provider) - print(f"āœ… Add result: {result}") - - # Test error handling - print("\nTesting error handling...") - try: - await transport.call_tool("simulate_error", {"error_message": "Test error"}, provider) - print("āŒ FAILED: Error tool should have failed") - except RuntimeError as e: - print(f"āœ… Error properly handled: {e}") - - except Exception as e: - print(f"āŒ Transport test failed: {e}") - finally: - await transport.close() - await runner.cleanup() - print("Mock server stopped") - - except ImportError as e: - print(f"āš ļø Mock server test skipped (missing dependencies): {e}") - except Exception as e: - print(f"āŒ Mock server test failed: {e}") - - -async def main(): - """Run all manual tests""" - await test_basic_functionality() - # await test_with_mock_server() # Uncomment if you want to test with real server - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file