diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..87de8e5 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,113 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +This is the Python implementation of the Universal Tool Calling Protocol (UTCP), a flexible and scalable standard for defining and interacting with tools across various communication protocols. UTCP emphasizes scalability, interoperability, and ease of use compared to other protocols like MCP. + +## Development Commands + +### Building and Installation +```bash +# Create virtual environment and install dependencies +conda create --name utcp python=3.10 +conda activate utcp +pip install -r requirements.txt +python -m pip install --upgrade pip + +# Build the package +python -m build + +# Install locally +pip install dist/utcp-.tar.gz +``` + +### Testing +```bash +# Run all tests +pytest + +# Run tests with coverage +pytest --cov=src/utcp + +# Run specific plugin tests +pytest plugins/communication_protocols/http/tests/ +pytest plugins/communication_protocols/websocket/tests/ +``` + +### Development Dependencies +- Install dev dependencies: `pip install -e .[dev]` +- Key dev tools: pytest, pytest-asyncio, pytest-aiohttp, pytest-cov, coverage, fastapi, uvicorn + +## Architecture Overview + +### Core Components + +**Client Architecture (`src/utcp/client/`)**: +- `UtcpClient`: Main entry point for UTCP ecosystem interaction +- `UtcpClientConfig`: Pydantic model for client configuration +- `ClientTransportInterface`: Abstract base for transport implementations +- `ToolRepository`: Interface for storing/retrieving tools (default: `InMemToolRepository`) +- `ToolSearchStrategy`: Interface for tool search algorithms (default: `TagSearchStrategy`) + +**Shared Models (`src/utcp/shared/`)**: +- `Tool`: Core tool definition with inputs/outputs schemas +- `Provider`: Defines communication protocols for tools +- `UtcpManual`: Contains discovery information for tool collections +- `Auth`: Authentication models (API key, Basic, OAuth2) + +**Transport Layer (`src/utcp/client/transport_interfaces/`)**: +Each transport handles protocol-specific communication: +- `HttpClientTransport`: RESTful HTTP/HTTPS APIs +- `CliTransport`: Command Line Interface tools +- `SSEClientTransport`: Server-Sent Events +- `StreamableHttpClientTransport`: HTTP chunked transfer +- `MCPTransport`: Model Context Protocol interoperability +- `TextTransport`: Local file-based tool definitions +- `GraphQLClientTransport`: GraphQL APIs + +### Key Design Patterns + +**Provider Registration**: Tools are discovered via `UtcpManual` objects from providers, then registered in the client's `ToolRepository`. + +**Namespaced Tool Calling**: Tools are called using format `provider_name.tool_name` to avoid naming conflicts. + +**OpenAPI Auto-conversion**: HTTP providers can point to OpenAPI v3 specs for automatic tool generation. + +**Extensible Authentication**: Support for API keys, Basic auth, and OAuth2 with per-provider configuration. + +## Configuration + +### Provider Configuration +Tools are configured via `providers.json` files that specify: +- Provider name and type +- Connection details (URL, method, etc.) +- Authentication configuration +- Tool discovery endpoints + +### Client Initialization +```python +client = await UtcpClient.create( + config={ + "providers_file_path": "./providers.json", + "load_variables_from": [{"type": "dotenv", "env_file_path": ".env"}] + } +) +``` + +## File Structure + +- `src/utcp/client/`: Client implementation and transport interfaces +- `src/utcp/shared/`: Shared models and utilities +- `tests/`: Comprehensive test suite with transport-specific tests +- `example/`: Complete usage examples including LLM integration +- `scripts/`: Utility scripts for OpenAPI conversion and API fetching + +## Important Implementation Notes + +- All async operations use `asyncio` +- Pydantic models throughout for validation and serialization +- Transport interfaces are protocol-agnostic and swappable +- Tool search supports tag-based ranking and keyword matching +- Variable substitution in configuration supports environment variables and .env files \ No newline at end of file diff --git a/README.md b/README.md index 6b899ac..6b520f5 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ UTCP supports multiple communication protocols through dedicated plugins: | [`utcp-cli`](plugins/communication_protocols/cli/) | Command-line tools | āœ… Stable | [CLI Plugin README](plugins/communication_protocols/cli/README.md) | | [`utcp-mcp`](plugins/communication_protocols/mcp/) | Model Context Protocol | āœ… Stable | [MCP Plugin README](plugins/communication_protocols/mcp/README.md) | | [`utcp-text`](plugins/communication_protocols/text/) | Local file-based tools | āœ… Stable | [Text Plugin README](plugins/communication_protocols/text/README.md) | +| [`utcp-websocket`](plugins/communication_protocols/websocket/) | WebSocket real-time bidirectional communication | āœ… Stable | [WebSocket Plugin README](plugins/communication_protocols/websocket/README.md) | | [`utcp-socket`](plugins/communication_protocols/socket/) | TCP/UDP protocols | 🚧 In Progress | [Socket Plugin README](plugins/communication_protocols/socket/README.md) | | [`utcp-gql`](plugins/communication_protocols/gql/) | GraphQL APIs | 🚧 In Progress | [GraphQL Plugin README](plugins/communication_protocols/gql/README.md) | diff --git a/example/src/websocket_example/README.md b/example/src/websocket_example/README.md new file mode 100644 index 0000000..22c236c --- /dev/null +++ b/example/src/websocket_example/README.md @@ -0,0 +1,87 @@ +# WebSocket Transport Example + +This example demonstrates how to use the UTCP WebSocket transport for real-time communication. + +## Overview + +The WebSocket transport provides: +- Real-time bidirectional communication +- Tool discovery via WebSocket handshake +- Streaming tool execution +- Authentication support (API Key, Basic Auth, OAuth2) +- Automatic reconnection and keep-alive + +## Files + +- `websocket_server.py` - Mock WebSocket server implementing UTCP protocol +- `websocket_client.py` - Client example using WebSocket transport +- `providers.json` - WebSocket provider configuration + +## Protocol + +The UTCP WebSocket protocol uses JSON messages: + +### Tool Discovery +```json +// Client sends: +{"type": "discover", "request_id": "unique_id"} + +// Server responds: +{ + "type": "discovery_response", + "request_id": "unique_id", + "tools": [...] +} +``` + +### Tool Execution +```json +// Client sends: +{ + "type": "call_tool", + "request_id": "unique_id", + "tool_name": "tool_name", + "arguments": {...} +} + +// Server responds: +{ + "type": "tool_response", + "request_id": "unique_id", + "result": {...} +} +``` + +## Running the Example + +1. Start the mock WebSocket server: +```bash +python websocket_server.py +``` + +2. In another terminal, run the client: +```bash +python websocket_client.py +``` + +## Configuration + +The `providers.json` shows how to configure WebSocket providers with authentication: + +```json +[ + { + "name": "websocket_tools", + "provider_type": "websocket", + "url": "ws://localhost:8765/ws", + "auth": { + "auth_type": "api_key", + "api_key": "your-api-key", + "var_name": "X-API-Key", + "location": "header" + }, + "keep_alive": true, + "protocol": "utcp-v1" + } +] +``` \ No newline at end of file diff --git a/example/src/websocket_example/providers.json b/example/src/websocket_example/providers.json new file mode 100644 index 0000000..101be96 --- /dev/null +++ b/example/src/websocket_example/providers.json @@ -0,0 +1,11 @@ +[ + { + "name": "websocket_tools", + "provider_type": "websocket", + "url": "ws://localhost:8765/ws", + "keep_alive": true, + "headers": { + "User-Agent": "UTCP-WebSocket-Client/1.0" + } + } +] \ No newline at end of file diff --git a/example/src/websocket_example/websocket_client.py b/example/src/websocket_example/websocket_client.py new file mode 100644 index 0000000..df0b444 --- /dev/null +++ b/example/src/websocket_example/websocket_client.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +""" +WebSocket client example demonstrating UTCP WebSocket transport. + +This example shows how to: +1. Create a UTCP client with WebSocket transport +2. Discover tools from a WebSocket provider +3. Execute tools via WebSocket +4. Handle real-time responses + +Make sure to run websocket_server.py first! +""" + +import asyncio +import json +import logging +from utcp.client import UtcpClient + + +async def demonstrate_websocket_tools(): + """Demonstrate WebSocket transport capabilities""" + print("šŸš€ UTCP WebSocket Client Example") + print("=" * 50) + + # Create UTCP client with WebSocket provider + print("šŸ“” Connecting to WebSocket provider...") + client = await UtcpClient.create( + config={"providers_file_path": "./providers.json"} + ) + + try: + # Discover available tools + print("\nšŸ” Discovering available tools...") + all_tools = await client.get_all_tools() + websocket_tools = [tool for tool in all_tools if tool.tool_provider.provider_type == "websocket"] + + print(f"Found {len(websocket_tools)} WebSocket tools:") + for tool in websocket_tools: + print(f" • {tool.name}: {tool.description}") + if tool.tags: + print(f" Tags: {', '.join(tool.tags)}") + + if not websocket_tools: + print("āŒ No WebSocket tools found. Make sure websocket_server.py is running!") + return + + print("\n" + "=" * 50) + print("šŸ› ļø Testing WebSocket tools...") + + # Test echo tool + print("\n1ļøāƒ£ Testing echo tool:") + result = await client.call_tool( + "websocket_tools.echo", + {"message": "Hello from UTCP WebSocket client! šŸ‘‹"} + ) + print(f" Echo result: {result}") + + # Test calculator + print("\n2ļøāƒ£ Testing calculator tool:") + calculations = [ + {"operation": "add", "a": 15, "b": 25}, + {"operation": "multiply", "a": 7, "b": 8}, + {"operation": "divide", "a": 100, "b": 4} + ] + + for calc in calculations: + result = await client.call_tool("websocket_tools.calculate", calc) + op = calc["operation"] + a, b = calc["a"], calc["b"] + print(f" {a} {op} {b} = {result['result']}") + + # Test time tool + print("\n3ļøāƒ£ Testing time tool:") + formats = ["timestamp", "iso", "human"] + for fmt in formats: + result = await client.call_tool("websocket_tools.get_time", {"format": fmt}) + print(f" {fmt} format: {result['time']}") + + # Test error handling + print("\n4ļøāƒ£ Testing error handling:") + try: + await client.call_tool( + "websocket_tools.simulate_error", + {"error_type": "validation", "message": "This is a test error"} + ) + except Exception as e: + print(f" āœ… Error properly caught: {e}") + + # Test tool search + print("\nšŸ”Ž Testing tool search...") + math_tools = await client.search_tools("math calculation") + print(f"Found {len(math_tools)} tools for 'math calculation':") + for tool in math_tools: + print(f" • {tool.name} (score: {getattr(tool, 'score', 'N/A')})") + + print("\nāœ… All WebSocket transport tests completed successfully!") + + except Exception as e: + print(f"āŒ Error during demonstration: {e}") + import traceback + traceback.print_exc() + + finally: + # Clean up + await client.close() + print("\nšŸ”Œ WebSocket connection closed") + + +async def interactive_mode(): + """Interactive mode for manual testing""" + print("\n" + "=" * 50) + print("šŸŽ® Interactive Mode") + print("Type 'help' for commands, 'exit' to quit") + + client = await UtcpClient.create( + config={"providers_file_path": "./providers.json"} + ) + + try: + while True: + try: + command = input("\n> ").strip() + + if command.lower() in ['exit', 'quit', 'q']: + break + elif command.lower() == 'help': + print(""" +Available commands: + list - List all available tools + call - Call a tool with JSON arguments + search - Search for tools + help - Show this help + exit - Exit interactive mode + +Examples: + call websocket_tools.echo {"message": "Hello!"} + call websocket_tools.calculate {"operation": "add", "a": 5, "b": 3} + search math + """) + elif command.startswith('list'): + tools = await client.get_all_tools() + ws_tools = [t for t in tools if t.tool_provider.provider_type == "websocket"] + for tool in ws_tools: + print(f" {tool.name}: {tool.description}") + + elif command.startswith('call '): + parts = command[5:].split(' ', 1) + if len(parts) != 2: + print("Usage: call ") + continue + + tool_name, args_str = parts + try: + args = json.loads(args_str) + result = await client.call_tool(tool_name, args) + print(f"Result: {json.dumps(result, indent=2)}") + except json.JSONDecodeError: + print("Error: Invalid JSON arguments") + except Exception as e: + print(f"Error: {e}") + + elif command.startswith('search '): + query = command[7:] + tools = await client.search_tools(query) + print(f"Found {len(tools)} tools:") + for tool in tools: + print(f" {tool.name}: {tool.description}") + + else: + print("Unknown command. Type 'help' for available commands.") + + except KeyboardInterrupt: + break + except Exception as e: + print(f"Error: {e}") + + finally: + await client.close() + + +async def main(): + """Main entry point""" + # Setup logging + logging.basicConfig(level=logging.INFO) + + try: + # Run demonstration + await demonstrate_websocket_tools() + + # Ask if user wants interactive mode + if input("\nšŸŽ® Enter interactive mode? (y/N): ").lower().startswith('y'): + await interactive_mode() + + except KeyboardInterrupt: + print("\nšŸ‘‹ Goodbye!") + except Exception as e: + print(f"āŒ Fatal error: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/example/src/websocket_example/websocket_server.py b/example/src/websocket_example/websocket_server.py new file mode 100644 index 0000000..eae2700 --- /dev/null +++ b/example/src/websocket_example/websocket_server.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +""" +Mock WebSocket server implementing UTCP protocol for demonstration. + +This server provides several example tools accessible via WebSocket: +- echo: Echo back messages +- calculate: Perform basic math operations +- get_time: Return current timestamp +- simulate_error: Demonstrate error handling + +Run this server and then use websocket_client.py to interact with it. +""" + +import asyncio +import json +import logging +import time +from aiohttp import web, WSMsgType +from aiohttp.web import Application, WebSocketResponse + + +class UTCPWebSocketServer: + """WebSocket server implementing UTCP protocol""" + + def __init__(self): + self.logger = logging.getLogger(__name__) + self.tools = self._define_tools() + + def _define_tools(self): + """Define the tools available on this server""" + return [ + { + "name": "echo", + "description": "Echo back the input message", + "inputs": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The message to echo back" + } + }, + "required": ["message"] + }, + "outputs": { + "type": "object", + "properties": { + "echo": {"type": "string"} + } + }, + "tags": ["utility", "test"] + }, + { + "name": "calculate", + "description": "Perform basic mathematical operations", + "inputs": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The operation to perform" + }, + "a": { + "type": "number", + "description": "First operand" + }, + "b": { + "type": "number", + "description": "Second operand" + } + }, + "required": ["operation", "a", "b"] + }, + "outputs": { + "type": "object", + "properties": { + "result": {"type": "number"} + } + }, + "tags": ["math", "calculation"] + }, + { + "name": "get_time", + "description": "Get the current server time", + "inputs": { + "type": "object", + "properties": { + "format": { + "type": "string", + "enum": ["timestamp", "iso", "human"], + "description": "Time format to return" + } + } + }, + "outputs": { + "type": "object", + "properties": { + "time": {"type": "string"}, + "timestamp": {"type": "number"} + } + }, + "tags": ["time", "utility"] + }, + { + "name": "simulate_error", + "description": "Simulate an error for testing error handling", + "inputs": { + "type": "object", + "properties": { + "error_type": { + "type": "string", + "enum": ["validation", "runtime", "custom"], + "description": "Type of error to simulate" + }, + "message": { + "type": "string", + "description": "Custom error message" + } + } + }, + "outputs": { + "type": "object", + "properties": {} + }, + "tags": ["test", "error"] + } + ] + + async def websocket_handler(self, request): + """Handle WebSocket connections""" + ws = WebSocketResponse() + await ws.prepare(request) + + # Get client info safely + peername = request.transport.get_extra_info('peername') if request.transport else None + if peername and len(peername) > 1: + client_info = f"{request.remote}:{peername[1]}" + else: + client_info = str(request.remote) if request.remote else 'unknown' + self.logger.info(f"WebSocket connection from {client_info}") + + # Log any authentication headers + auth_header = request.headers.get('Authorization') + if auth_header: + self.logger.info("Authentication header provided") + + api_key = request.headers.get('X-API-Key') + if api_key: + self.logger.info("API Key header provided") + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + await self._handle_message(ws, msg.data, client_info) + elif msg.type == WSMsgType.ERROR: + self.logger.error(f"WebSocket error: {ws.exception()}") + break + except Exception as e: + self.logger.error(f"Error in WebSocket handler: {e}") + finally: + self.logger.info(f"WebSocket connection closed: {client_info}") + + return ws + + async def _handle_message(self, ws, data, client_info): + """Handle incoming WebSocket messages""" + try: + message = json.loads(data) + message_type = message.get("type") + request_id = message.get("request_id") + + self.logger.info(f"[{client_info}] Received {message_type} (ID: {request_id})") + + if message_type == "discover": + await self._handle_discovery(ws, request_id) + elif message_type == "call_tool": + await self._handle_tool_call(ws, message, client_info) + else: + await self._send_error(ws, request_id, f"Unknown message type: {message_type}") + + except json.JSONDecodeError as e: + self.logger.error(f"[{client_info}] Invalid JSON: {e}") + await self._send_error(ws, None, "Invalid JSON message") + except Exception as e: + self.logger.error(f"[{client_info}] Error handling message: {e}") + await self._send_error(ws, None, f"Internal server error: {str(e)}") + + async def _handle_discovery(self, ws, request_id): + """Handle tool discovery requests""" + response = { + "type": "discovery_response", + "request_id": request_id, + "tools": self.tools + } + await ws.send_str(json.dumps(response)) + self.logger.info(f"Sent discovery response with {len(self.tools)} tools") + + async def _handle_tool_call(self, ws, message, client_info): + """Handle tool execution requests""" + tool_name = message.get("tool_name") + arguments = message.get("arguments", {}) + request_id = message.get("request_id") + + self.logger.info(f"[{client_info}] Executing {tool_name}: {arguments}") + + try: + result = await self._execute_tool(tool_name, arguments) + response = { + "type": "tool_response", + "request_id": request_id, + "result": result + } + await ws.send_str(json.dumps(response)) + self.logger.info(f"[{client_info}] Tool {tool_name} completed successfully") + + except Exception as e: + self.logger.error(f"[{client_info}] Tool {tool_name} failed: {e}") + await self._send_tool_error(ws, request_id, str(e)) + + async def _execute_tool(self, tool_name, arguments): + """Execute a specific tool""" + if tool_name == "echo": + message = arguments.get("message", "") + return {"echo": message} + + elif tool_name == "calculate": + operation = arguments.get("operation") + a = arguments.get("a", 0) + b = arguments.get("b", 0) + + if operation == "add": + result = a + b + elif operation == "subtract": + result = a - b + elif operation == "multiply": + result = a * b + elif operation == "divide": + if b == 0: + raise ValueError("Division by zero") + result = a / b + else: + raise ValueError(f"Unknown operation: {operation}") + + return {"result": result} + + elif tool_name == "get_time": + format_type = arguments.get("format", "timestamp") + current_time = time.time() + + if format_type == "timestamp": + return {"time": str(current_time), "timestamp": current_time} + elif format_type == "iso": + from datetime import datetime + iso_time = datetime.fromtimestamp(current_time).isoformat() + return {"time": iso_time, "timestamp": current_time} + elif format_type == "human": + from datetime import datetime + human_time = datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S") + return {"time": human_time, "timestamp": current_time} + else: + raise ValueError(f"Unknown format: {format_type}") + + elif tool_name == "simulate_error": + error_type = arguments.get("error_type", "runtime") + custom_message = arguments.get("message", "Simulated error") + + if error_type == "validation": + raise ValueError(f"Validation error: {custom_message}") + elif error_type == "runtime": + raise RuntimeError(f"Runtime error: {custom_message}") + elif error_type == "custom": + raise Exception(custom_message) + else: + raise ValueError(f"Unknown error type: {error_type}") + else: + raise ValueError(f"Unknown tool: {tool_name}") + + async def _send_error(self, ws, request_id, error_message): + """Send a general error response""" + response = { + "type": "error", + "request_id": request_id, + "error": error_message + } + await ws.send_str(json.dumps(response)) + + async def _send_tool_error(self, ws, request_id, error_message): + """Send a tool-specific error response""" + response = { + "type": "tool_error", + "request_id": request_id, + "error": error_message + } + await ws.send_str(json.dumps(response)) + + +async def create_app(): + """Create the aiohttp application""" + app = Application() + server = UTCPWebSocketServer() + + # WebSocket endpoint + app.router.add_get('/ws', server.websocket_handler) + + # Health check endpoint + async def health_check(request): + return web.json_response({ + "status": "ok", + "service": "utcp-websocket-server", + "tools_available": len(server.tools) + }) + + app.router.add_get('/health', health_check) + + return app + + +async def main(): + """Run the WebSocket server""" + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + app = await create_app() + runner = web.AppRunner(app) + await runner.setup() + + site = web.TCPSite(runner, 'localhost', 8765) + await site.start() + + print("šŸš€ UTCP WebSocket Server running!") + print("šŸ“” WebSocket: ws://localhost:8765/ws") + print("šŸ” Health check: http://localhost:8765/health") + print("šŸ“š Available tools: echo, calculate, get_time, simulate_error") + print("ā¹ļø Press Ctrl+C to stop") + + try: + await asyncio.Future() # Run forever + except KeyboardInterrupt: + print("\nā¹ļø Shutting down server...") + finally: + await runner.cleanup() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/plugins/communication_protocols/websocket/README.md b/plugins/communication_protocols/websocket/README.md new file mode 100644 index 0000000..8daa32a --- /dev/null +++ b/plugins/communication_protocols/websocket/README.md @@ -0,0 +1,408 @@ +# UTCP WebSocket Plugin + +WebSocket communication protocol plugin for UTCP, enabling real-time bidirectional communication with **maximum flexibility** to support ANY WebSocket endpoint format. + +## Key Feature: Maximum Flexibility + +**The WebSocket plugin is designed to work with ANY existing WebSocket endpoint without modification.** + +Unlike other implementations that enforce specific message structures, this plugin: +- āœ… **No enforced request format**: Use `message` templates with `UTCP_ARG_arg_name_UTCP_ARG` placeholders +- āœ… **No enforced response format**: Returns raw responses by default +- āœ… **Works with existing endpoints**: No need to modify your WebSocket servers +- āœ… **Flexible templating**: Support dict or string message templates + +This addresses the UTCP principle: "Talk to as many WebSocket endpoints as possible." + +## Features + +- āœ… **Maximum Flexibility**: Works with ANY WebSocket endpoint without modification +- āœ… **Flexible Message Templates**: Dict or string templates with `UTCP_ARG_arg_name_UTCP_ARG` placeholders +- āœ… **No Enforced Structure**: Send/receive messages in any format +- āœ… **Real-time Communication**: Bidirectional WebSocket connections +- āœ… **Multiple Authentication**: API Key, Basic Auth, and OAuth2 support +- āœ… **Connection Management**: Keep-alive, reconnection, and connection pooling +- āœ… **Streaming Support**: Both single-response and streaming execution +- āœ… **Security Enforced**: WSS required (or ws://localhost for development) + +## Installation + +```bash +pip install utcp-websocket +``` + +For development: + +```bash +pip install -e plugins/communication_protocols/websocket +``` + +## Quick Start + +### Basic Usage (No Template - Maximum Flexibility) + +```python +from utcp.utcp_client import UtcpClient + +# Works with ANY WebSocket endpoint - just sends arguments as JSON +client = await UtcpClient.create(config={ + "manual_call_templates": [{ + "name": "my_websocket", + "call_template_type": "websocket", + "url": "wss://api.example.com/ws" + }] +}) + +# Sends: {"user_id": "123", "action": "getData"} +result = await client.call_tool("my_websocket.get_data", { + "user_id": "123", + "action": "getData" +}) +``` + +### With Message Template (Dict) + +```python +{ + "name": "formatted_ws", + "call_template_type": "websocket", + "url": "wss://api.example.com/ws", + "message": { + "type": "request", + "action": "UTCP_ARG_action_UTCP_ARG", + "params": { + "user_id": "UTCP_ARG_user_id_UTCP_ARG", + "query": "UTCP_ARG_query_UTCP_ARG" + } + } +} +``` + +Calling with `{"action": "search", "user_id": "123", "query": "test"}` sends: +```json +{ + "type": "request", + "action": "search", + "params": { + "user_id": "123", + "query": "test" + } +} +``` + +### With Message Template (String) + +```python +{ + "name": "text_ws", + "call_template_type": "websocket", + "url": "wss://iot.example.com/ws", + "message": "CMD:UTCP_ARG_command_UTCP_ARG;DEVICE:UTCP_ARG_device_id_UTCP_ARG;VALUE:UTCP_ARG_value_UTCP_ARG" +} +``` + +Calling with `{"command": "SET_TEMP", "device_id": "dev123", "value": "25"}` sends: +``` +CMD:SET_TEMP;DEVICE:dev123;VALUE:25 +``` + +## Configuration Options + +### WebSocketCallTemplate Fields + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `call_template_type` | string | Yes | `"websocket"` | Must be "websocket" | +| `url` | string | Yes | - | WebSocket URL (wss:// or ws://localhost) | +| `message` | string\|dict | No | `null` | Message template with UTCP_ARG_arg_name_UTCP_ARG placeholders | +| `response_format` | string | No | `null` | Expected response format ("json", "text", "raw") | +| `protocol` | string | No | `null` | WebSocket subprotocol | +| `keep_alive` | boolean | No | `true` | Enable persistent connection with heartbeat | +| `timeout` | integer | No | `30` | Timeout in seconds | +| `headers` | object | No | `null` | Static headers for handshake | +| `header_fields` | array | No | `null` | Tool arguments to map to headers | +| `auth` | object | No | `null` | Authentication configuration | + +## Message Templating + +### No Template (Default - Maximum Flexibility) + +If `message` is not specified, arguments are sent as-is in JSON format: + +```python +# Config +{"call_template_type": "websocket", "url": "wss://api.example.com/ws"} + +# Call +await client.call_tool("ws.tool", {"foo": "bar", "baz": 123}) + +# Sends exactly: +{"foo": "bar", "baz": 123} +``` + +This works with **any** WebSocket endpoint that accepts JSON. + +### Dict Template + +Use dict templates for structured messages: + +```python +{ + "message": { + "jsonrpc": "2.0", + "method": "UTCP_ARG_method_UTCP_ARG", + "params": "UTCP_ARG_params_UTCP_ARG", + "id": 1 + } +} +``` + +### String Template + +Use string templates for text-based protocols: + +```python +{ + "message": "GET UTCP_ARG_resource_UTCP_ARG HTTP/1.1\r\nHost: UTCP_ARG_host_UTCP_ARG\r\n\r\n" +} +``` + +### Nested Templates + +Templates work recursively in dicts and lists: + +```python +{ + "message": { + "type": "command", + "data": { + "commands": ["UTCP_ARG_cmd1_UTCP_ARG", "UTCP_ARG_cmd2_UTCP_ARG"], + "metadata": { + "user": "UTCP_ARG_user_UTCP_ARG", + "timestamp": "2025-01-01" + } + } + } +} +``` + +## Response Handling + +### No Format Specification (Default) + +By default, responses are returned as-is (maximum flexibility): + +```python +# Returns whatever the WebSocket sends - could be JSON string, text, or binary +result = await client.call_tool("ws.tool", {...}) +``` + +### JSON Format + +Parse responses as JSON: + +```python +{ + "call_template_type": "websocket", + "url": "wss://api.example.com/ws", + "response_format": "json" +} +``` + +### Text Format + +Return responses as text strings: + +```python +{ + "response_format": "text" +} +``` + +### Raw Format + +Return responses without any processing: + +```python +{ + "response_format": "raw" +} +``` + +## Real-World Examples + +### Example 1: Stock Price WebSocket (No Template) + +Works with existing stock APIs without modification: + +```python +{ + "name": "stocks", + "call_template_type": "websocket", + "url": "wss://stream.example.com/stocks", + "auth": { + "auth_type": "api_key", + "api_key": "${STOCK_API_KEY}", + "var_name": "Authorization", + "location": "header" + } +} + +# Sends: {"symbol": "AAPL", "action": "subscribe"} +await client.call_tool("stocks.subscribe", { + "symbol": "AAPL", + "action": "subscribe" +}) +``` + +### Example 2: IoT Device Control (String Template) + +```python +{ + "name": "iot", + "call_template_type": "websocket", + "url": "wss://iot.example.com/devices", + "message": "DEVICE:UTCP_ARG_device_id_UTCP_ARG CMD:UTCP_ARG_command_UTCP_ARG VAL:UTCP_ARG_value_UTCP_ARG" +} + +# Sends: "DEVICE:light_01 CMD:SET_BRIGHTNESS VAL:75" +await client.call_tool("iot.control", { + "device_id": "light_01", + "command": "SET_BRIGHTNESS", + "value": "75" +}) +``` + +### Example 3: JSON-RPC WebSocket (Dict Template) + +```python +{ + "name": "jsonrpc", + "call_template_type": "websocket", + "url": "wss://rpc.example.com/ws", + "message": { + "jsonrpc": "2.0", + "method": "UTCP_ARG_method_UTCP_ARG", + "params": "UTCP_ARG_params_UTCP_ARG", + "id": 1 + }, + "response_format": "json" +} + +# Sends: {"jsonrpc": "2.0", "method": "getUser", "params": "{\"id\": 123}", "id": 1} +# Note: params is stringified since it's a non-string value in the template +result = await client.call_tool("jsonrpc.call", { + "method": "getUser", + "params": {"id": 123} +}) +``` + +### Example 4: Chat Application (Dict Template) + +```python +{ + "name": "chat", + "call_template_type": "websocket", + "url": "wss://chat.example.com/ws", + "message": { + "type": "message", + "channel": "UTCP_ARG_channel_UTCP_ARG", + "user": "UTCP_ARG_user_UTCP_ARG", + "text": "UTCP_ARG_text_UTCP_ARG", + "timestamp": "{{now}}" + } +} +``` + +## Authentication + +### API Key Authentication + +```python +{ + "auth": { + "auth_type": "api_key", + "api_key": "${API_KEY}", + "var_name": "Authorization", + "location": "header" + } +} +``` + +### Basic Authentication + +```python +{ + "auth": { + "auth_type": "basic", + "username": "${USERNAME}", + "password": "${PASSWORD}" + } +} +``` + +### OAuth2 Authentication + +```python +{ + "auth": { + "auth_type": "oauth2", + "client_id": "${CLIENT_ID}", + "client_secret": "${CLIENT_SECRET}", + "token_url": "https://auth.example.com/token", + "scope": "read write" + } +} +``` + +## Streaming Responses + +```python +async for chunk in client.call_tool_streaming("ws.stream", {"query": "data"}): + print(chunk) +``` + +## Security + +- **WSS Required**: Production URLs must use `wss://` for encrypted communication +- **Localhost Exception**: `ws://localhost` and `ws://127.0.0.1` allowed for development +- **Authentication**: Full support for API Key, Basic Auth, and OAuth2 +- **Token Caching**: OAuth2 tokens are cached for reuse; refresh must be handled by the service or manual re-auth. + +## Best Practices + +1. **Start Simple**: Don't use `message` template unless your endpoint requires specific format +2. **Use WSS in Production**: Always use `wss://` for secure connections +3. **Set Appropriate Timeouts**: Configure timeouts based on expected response times +4. **Test Without Template First**: Try without `message` template to see if it works +5. **Add Template Only When Needed**: Only add `message` template if endpoint requires specific structure + +## Comparison with Enforced Formats + +| Approach | Flexibility | Works with Existing Endpoints | +|----------|-------------|------------------------------| +| **UTCP WebSocket (This Plugin)** | āœ… Maximum | āœ… Yes - works with any endpoint | +| Enforced request/response structure | āŒ Limited | āŒ No - requires endpoint modification | +| UTCP-specific message format | āŒ Limited | āŒ No - only works with UTCP servers | + +## Testing + +Run tests: + +```bash +pytest plugins/communication_protocols/websocket/tests/ -v +``` + +With coverage: + +```bash +pytest plugins/communication_protocols/websocket/tests/ --cov=utcp_websocket --cov-report=term-missing +``` + +## Contributing + +Contributions are welcome! Please see the [main repository](https://github.com/universal-tool-calling-protocol/python-utcp) for contribution guidelines. + +## License + +Mozilla Public License 2.0 (MPL-2.0) diff --git a/plugins/communication_protocols/websocket/pyproject.toml b/plugins/communication_protocols/websocket/pyproject.toml new file mode 100644 index 0000000..5391418 --- /dev/null +++ b/plugins/communication_protocols/websocket/pyproject.toml @@ -0,0 +1,44 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "utcp-websocket" +version = "1.0.0" +authors = [ + { name = "UTCP Contributors" }, +] +description = "UTCP communication protocol plugin for WebSocket real-time bidirectional communication." +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "pydantic>=2.0", + "aiohttp>=3.8", + "utcp>=1.0" +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", +] +license = "MPL-2.0" + +[project.optional-dependencies] +dev = [ + "build", + "pytest", + "pytest-asyncio", + "pytest-aiohttp", + "pytest-cov", + "coverage", + "twine", +] + +[project.urls] +Homepage = "https://utcp.io" +Source = "https://github.com/universal-tool-calling-protocol/python-utcp" +Issues = "https://github.com/universal-tool-calling-protocol/python-utcp/issues" + +[project.entry-points."utcp.plugins"] +websocket = "utcp_websocket:register" diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/__init__.py b/plugins/communication_protocols/websocket/src/utcp_websocket/__init__.py new file mode 100644 index 0000000..21c5879 --- /dev/null +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/__init__.py @@ -0,0 +1,23 @@ +"""WebSocket Communication Protocol plugin for UTCP. + +This plugin provides WebSocket-based real-time bidirectional communication protocol. +""" + +from utcp.plugins.discovery import register_communication_protocol, register_call_template +from utcp_websocket.websocket_communication_protocol import WebSocketCommunicationProtocol +from utcp_websocket.websocket_call_template import WebSocketCallTemplate, WebSocketCallTemplateSerializer + +def register(): + """Register the WebSocket communication protocol and call template serializer.""" + # Register WebSocket communication protocol + register_communication_protocol("websocket", WebSocketCommunicationProtocol()) + + # Register call template serializer + register_call_template("websocket", WebSocketCallTemplateSerializer()) + +# Export public API +__all__ = [ + "WebSocketCommunicationProtocol", + "WebSocketCallTemplate", + "WebSocketCallTemplateSerializer", +] diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py new file mode 100644 index 0000000..81dbb2c --- /dev/null +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py @@ -0,0 +1,165 @@ +from utcp.data.call_template import CallTemplate, CallTemplateSerializer +from utcp.data.auth import Auth, AuthSerializer +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback +from typing import Optional, Dict, List, Literal, Union, Any +from pydantic import Field, field_serializer, field_validator + +class WebSocketCallTemplate(CallTemplate): + """REQUIRED + Call template configuration for WebSocket-based tools. + + Supports real-time bidirectional communication via WebSocket protocol with + various message formats, authentication methods, and connection management features. + + Configuration Examples: + Basic WebSocket connection: + ```json + { + "name": "realtime_service", + "call_template_type": "websocket", + "url": "wss://api.example.com/ws" + } + ``` + + With authentication: + ```json + { + "name": "secure_websocket", + "call_template_type": "websocket", + "url": "wss://api.example.com/ws", + "auth": { + "auth_type": "api_key", + "api_key": "${WS_API_KEY}", + "var_name": "Authorization", + "location": "header" + }, + "keep_alive": true, + "protocol": "utcp-v1" + } + ``` + + Custom message format: + ```json + { + "name": "custom_format_ws", + "call_template_type": "websocket", + "url": "wss://api.example.com/ws", + "request_data_format": "text", + "request_data_template": "CMD:UTCP_ARG_command_UTCP_ARG;DATA:UTCP_ARG_data_UTCP_ARG", + "timeout": 60 + } + ``` + + Attributes: + call_template_type: Always "websocket" for WebSocket providers. + url: WebSocket URL (must be wss:// or ws://localhost). + message: Message template with UTCP_ARG_arg_name_UTCP_ARG placeholders for flexible formatting. + protocol: Optional WebSocket subprotocol to use. + keep_alive: Whether to maintain persistent connection with heartbeat. + response_format: Expected response format ("json", "text", or "raw"). If None, returns raw response. + timeout: Timeout in seconds for WebSocket operations. + headers: Optional static headers to include in WebSocket handshake. + header_fields: List of tool argument names to map to WebSocket handshake headers. + auth: Optional authentication configuration for WebSocket connection. + """ + call_template_type: Literal["websocket"] = Field(default="websocket") + url: str = Field(..., description="WebSocket URL (wss:// or ws://localhost)") + message: Optional[Union[str, Dict[str, Any]]] = Field( + default=None, + description="Message template. Can be a string or dict with UTCP_ARG_arg_name_UTCP_ARG placeholders" + ) + protocol: Optional[str] = Field(default=None, description="WebSocket subprotocol") + keep_alive: bool = Field(default=True, description="Enable persistent connection with heartbeat") + response_format: Optional[Literal["json", "text", "raw"]] = Field( + default=None, + description="Expected response format. If None, returns raw response" + ) + timeout: int = Field(default=30, description="Timeout in seconds for WebSocket operations") + headers: Optional[Dict[str, str]] = Field(default=None, description="Static headers for WebSocket handshake") + header_fields: Optional[List[str]] = Field(default=None, description="Tool arguments to map to headers") + + @field_validator("url") + @classmethod + def validate_url(cls, v: str) -> str: + """Validate WebSocket URL format.""" + if not (v.startswith("wss://") or v.startswith("ws://localhost") or v.startswith("ws://127.0.0.1")): + raise ValueError( + f"WebSocket URL must use wss:// or start with ws://localhost or ws://127.0.0.1. Got: {v}" + ) + return v + + @field_serializer("headers", when_used="unless-none") + def serialize_headers(self, headers: Optional[Dict[str, str]], _info): + return headers if headers else None + + @field_serializer("header_fields", when_used="unless-none") + def serialize_header_fields(self, header_fields: Optional[List[str]], _info): + return header_fields if header_fields else None + + +class WebSocketCallTemplateSerializer(Serializer[WebSocketCallTemplate]): + """REQUIRED + Serializer for WebSocket call templates. + + Handles conversion between WebSocketCallTemplate objects and dictionaries + for storage, transmission, and configuration parsing. + """ + + def to_dict(self, obj: WebSocketCallTemplate) -> dict: + """Convert WebSocketCallTemplate to dictionary. + + Args: + obj: The WebSocketCallTemplate object to convert. + + Returns: + Dictionary representation of the call template. + """ + result = { + "name": obj.name, + "call_template_type": obj.call_template_type, + "url": obj.url, + } + + if obj.message is not None: + result["message"] = obj.message + if obj.protocol is not None: + result["protocol"] = obj.protocol + if obj.keep_alive is not True: + result["keep_alive"] = obj.keep_alive + if obj.response_format is not None: + result["response_format"] = obj.response_format + if obj.timeout != 30: + result["timeout"] = obj.timeout + if obj.headers: + result["headers"] = obj.headers + if obj.header_fields: + result["header_fields"] = obj.header_fields + if obj.auth: + result["auth"] = AuthSerializer().to_dict(obj.auth) + + return result + + def validate_dict(self, obj: dict) -> WebSocketCallTemplate: + """Validate dictionary and convert to WebSocketCallTemplate. + + Args: + obj: Dictionary to validate and convert. + + Returns: + WebSocketCallTemplate object. + + Raises: + UtcpSerializerValidationError: If validation fails. + """ + try: + # Parse auth if present + if "auth" in obj and obj["auth"] is not None: + obj["auth"] = AuthSerializer().validate_dict(obj["auth"]) + + return WebSocketCallTemplate(**obj) + except Exception as e: + raise UtcpSerializerValidationError( + f"Failed to validate WebSocketCallTemplate: {str(e)}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py new file mode 100644 index 0000000..48a1d21 --- /dev/null +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py @@ -0,0 +1,447 @@ +"""WebSocket communication protocol implementation for UTCP client. + +This module provides the WebSocket communication protocol implementation that handles +real-time bidirectional communication with WebSocket-based tool providers. + +Key Features: + - Real-time bidirectional communication + - Multiple authentication methods (API key, Basic, OAuth2) + - Tool discovery via WebSocket handshake + - Connection pooling and keep-alive + - Security enforcement (WSS or localhost only) + - Custom message formats and templates +""" + +from typing import Dict, Any, Optional, Callable, AsyncGenerator +import asyncio +import json +import base64 +import aiohttp +from aiohttp import ClientWebSocketResponse, ClientSession +import logging + +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplate +from utcp.data.tool import Tool +from utcp.data.utcp_manual import UtcpManual, UtcpManualSerializer +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth +from utcp.data.auth_implementations.basic_auth import BasicAuth +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth +from utcp_websocket.websocket_call_template import WebSocketCallTemplate + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s" +) + +logger = logging.getLogger(__name__) + + +class WebSocketCommunicationProtocol(CommunicationProtocol): + """REQUIRED + WebSocket communication protocol implementation for UTCP client. + + Handles real-time bidirectional communication with WebSocket-based tool providers, + supporting various authentication methods and message formats. Enforces security + by requiring WSS or localhost connections. + + Features: + - Real-time WebSocket communication with persistent connections + - Multiple authentication: API key (header), Basic, OAuth2 + - Tool discovery via WebSocket handshake using UTCP messages + - Flexible message formats (JSON or text-based with templates) + - Connection pooling and automatic keep-alive + - OAuth2 token caching and automatic refresh + - Security validation of connection URLs + + Attributes: + _connections: Active WebSocket connections by provider key. + _sessions: aiohttp ClientSessions for connection management. + _oauth_tokens: Cache of OAuth2 tokens by client_id. + """ + + def __init__(self, logger_func: Optional[Callable[[str], None]] = None): + """Initialize the WebSocket communication protocol. + + Args: + logger_func: Optional logging function that accepts log messages. + """ + self._connections: Dict[str, ClientWebSocketResponse] = {} + self._sessions: Dict[str, ClientSession] = {} + self._oauth_tokens: Dict[str, Dict[str, Any]] = {} + + def _substitute_placeholders(self, template: Any, arguments: Dict[str, Any]) -> Any: + """Recursively substitute UTCP_ARG_arg_name_UTCP_ARG placeholders in template. + + Args: + template: Template (string, dict, or list) with UTCP_ARG_arg_name_UTCP_ARG placeholders + arguments: Arguments to substitute + + Returns: + Template with placeholders replaced + """ + if isinstance(template, str): + # Replace UTCP_ARG_arg_name_UTCP_ARG placeholders + result = template + for arg_name, arg_value in arguments.items(): + placeholder = f"UTCP_ARG_{arg_name}_UTCP_ARG" + if placeholder in result: + if isinstance(arg_value, str): + result = result.replace(placeholder, arg_value) + else: + result = result.replace(placeholder, json.dumps(arg_value)) + return result + elif isinstance(template, dict): + return {k: self._substitute_placeholders(v, arguments) for k, v in template.items()} + elif isinstance(template, list): + return [self._substitute_placeholders(item, arguments) for item in template] + else: + return template + + def _format_tool_call_message( + self, + tool_name: str, + arguments: Dict[str, Any], + call_template: WebSocketCallTemplate, + request_id: str + ) -> str: + """Format a tool call message based on call template configuration. + + Provides maximum flexibility to support ANY WebSocket endpoint format: + - If message template is provided, uses it with UTCP_ARG_arg_name_UTCP_ARG substitution + - Otherwise, sends arguments directly as JSON (no enforced structure) + + Args: + tool_name: Name of the tool to call + arguments: Arguments for the tool call + call_template: The WebSocketCallTemplate with formatting configuration + request_id: Unique request identifier + + Returns: + Formatted message string + """ + # Priority 1: Use message template if provided (most flexible - supports any format) + if call_template.message is not None: + substituted = self._substitute_placeholders(call_template.message, arguments) + # If it's a dict, convert to JSON string + if isinstance(substituted, dict): + return json.dumps(substituted) + else: + return str(substituted) + + # Priority 2: Default to just sending arguments as JSON (maximum flexibility) + # This allows ANY WebSocket endpoint to work without modification + # No enforced structure - just the raw arguments + return json.dumps(arguments) + + async def _handle_oauth2(self, auth: OAuth2Auth) -> str: + """Handle OAuth2 authentication and token management.""" + client_id = auth.client_id + if client_id in self._oauth_tokens: + return self._oauth_tokens[client_id]["access_token"] + + async with aiohttp.ClientSession() as session: + data = { + 'grant_type': 'client_credentials', + 'client_id': client_id, + 'client_secret': auth.client_secret, + 'scope': auth.scope + } + async with session.post(auth.token_url, data=data) as resp: + resp.raise_for_status() + token_response = await resp.json() + self._oauth_tokens[client_id] = token_response + return token_response["access_token"] + + async def _prepare_headers(self, call_template: WebSocketCallTemplate) -> Dict[str, str]: + """Prepare headers for WebSocket connection including authentication.""" + headers = call_template.headers.copy() if call_template.headers else {} + + if call_template.auth: + if isinstance(call_template.auth, ApiKeyAuth): + if call_template.auth.api_key: + if call_template.auth.location == "header": + headers[call_template.auth.var_name] = call_template.auth.api_key + + elif isinstance(call_template.auth, BasicAuth): + userpass = f"{call_template.auth.username}:{call_template.auth.password}" + headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode() + + elif isinstance(call_template.auth, OAuth2Auth): + token = await self._handle_oauth2(call_template.auth) + headers["Authorization"] = f"Bearer {token}" + + return headers + + async def _get_connection(self, call_template: WebSocketCallTemplate) -> ClientWebSocketResponse: + """Get or create a WebSocket connection for the call template.""" + provider_key = f"{call_template.name}_{call_template.url}" + + # Check if we have an active connection + if provider_key in self._connections: + ws = self._connections[provider_key] + if not ws.closed: + return ws + else: + # Clean up closed connection + await self._cleanup_connection(provider_key) + + # Create new connection + headers = await self._prepare_headers(call_template) + + session = ClientSession() + self._sessions[provider_key] = session + + try: + ws = await session.ws_connect( + call_template.url, + headers=headers, + protocols=[call_template.protocol] if call_template.protocol else None, + heartbeat=30 if call_template.keep_alive else None + ) + self._connections[provider_key] = ws + logger.info(f"WebSocket connected to {call_template.url}") + return ws + + except Exception as e: + await session.close() + if provider_key in self._sessions: + del self._sessions[provider_key] + logger.error(f"Failed to connect to WebSocket {call_template.url}: {e}") + raise + + async def _cleanup_connection(self, provider_key: str): + """Clean up a specific connection.""" + if provider_key in self._connections: + ws = self._connections[provider_key] + if not ws.closed: + await ws.close() + del self._connections[provider_key] + + if provider_key in self._sessions: + session = self._sessions[provider_key] + await session.close() + del self._sessions[provider_key] + + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + """REQUIRED + Register a manual and its tools via WebSocket discovery. + + Sends a discovery message: {"type": "utcp"} + Expects a UtcpManual response with tools. + + Args: + caller: The UTCP client that is calling this method. + manual_call_template: The call template of the manual to register. + + Returns: + RegisterManualResult object containing the call template and manual. + """ + if not isinstance(manual_call_template, WebSocketCallTemplate): + raise ValueError("WebSocketCommunicationProtocol can only be used with WebSocketCallTemplate") + + ws = await self._get_connection(manual_call_template) + + try: + # Send discovery request (matching UDP pattern) + discovery_message = json.dumps({"type": "utcp"}) + await ws.send_str(discovery_message) + logger.info(f"Registering WebSocket manual '{manual_call_template.name}' at {manual_call_template.url}") + + # Wait for discovery response + timeout = manual_call_template.timeout + try: + async with asyncio.timeout(timeout): + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + try: + response_data = json.loads(msg.data) + + # Response data for a /utcp endpoint NEEDS to be a UtcpManual + if isinstance(response_data, dict) and 'tools' in response_data: + try: + # Parse as UtcpManual + utcp_manual = UtcpManualSerializer().validate_dict(response_data) + logger.info(f"Discovered {len(utcp_manual.tools)} tools from WebSocket manual '{manual_call_template.name}'") + return RegisterManualResult( + call_template=manual_call_template, + manual=utcp_manual + ) + except Exception as e: + logger.error(f"Invalid UtcpManual response from WebSocket manual '{manual_call_template.name}': {e}") + raise ValueError(f"Invalid UtcpManual format: {e}") + + except json.JSONDecodeError as e: + logger.error(f"Invalid JSON response from WebSocket manual '{manual_call_template.name}': {e}") + + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error(f"WebSocket error during discovery: {ws.exception()}") + break + + except asyncio.TimeoutError: + logger.error(f"Discovery timeout for {manual_call_template.url}") + raise ValueError(f"Tool discovery timeout for WebSocket manual {manual_call_template.url}") + + except Exception as e: + logger.error(f"Error registering WebSocket manual '{manual_call_template.name}': {e}") + raise + + # Should not reach here, but just in case + raise ValueError(f"Failed to discover tools from {manual_call_template.url}") + + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + """REQUIRED + Deregister a manual by closing its WebSocket connection. + + Args: + caller: The UTCP client that is calling this method. + manual_call_template: The call template of the manual to deregister. + """ + if not isinstance(manual_call_template, WebSocketCallTemplate): + return + + provider_key = f"{manual_call_template.name}_{manual_call_template.url}" + await self._cleanup_connection(provider_key) + logger.info(f"Deregistered WebSocket manual '{manual_call_template.name}' (connection closed)") + + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + """REQUIRED + Execute a tool call through WebSocket. + + Provides maximum flexibility to support ANY WebSocket response format: + - If response_format is specified, parses accordingly + - Otherwise, returns the raw response (string or bytes) + - No enforced response structure - works with any WebSocket endpoint + + Args: + caller: The UTCP client that is calling this method. + tool_name: Name of the tool to call. + tool_args: Dictionary of arguments to pass to the tool. + tool_call_template: Call template of the tool to call. + + Returns: + The tool's response (format depends on response_format setting). + """ + if not isinstance(tool_call_template, WebSocketCallTemplate): + raise ValueError("WebSocketCommunicationProtocol can only be used with WebSocketCallTemplate") + + logger.info(f"Calling WebSocket tool '{tool_name}'") + + ws = await self._get_connection(tool_call_template) + + try: + # Prepare tool call request + request_id = f"call_{tool_name}_{id(tool_args)}" + tool_call_message = self._format_tool_call_message(tool_name, tool_args, tool_call_template, request_id) + + await ws.send_str(tool_call_message) + logger.info(f"Sent tool call request for {tool_name}") + + # Wait for response + timeout = tool_call_template.timeout + try: + async with asyncio.timeout(timeout): + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + # Handle response based on response_format + if tool_call_template.response_format == "json": + try: + return json.loads(msg.data) + except json.JSONDecodeError: + logger.warning(f"Expected JSON response but got: {msg.data[:100]}") + return msg.data + elif tool_call_template.response_format == "text": + return msg.data + elif tool_call_template.response_format == "raw": + return msg.data + else: + # No format specified - return raw response (maximum flexibility) + return msg.data + + elif msg.type == aiohttp.WSMsgType.BINARY: + # Return binary data as-is + return msg.data + + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error(f"WebSocket error during tool call: {ws.exception()}") + raise RuntimeError(f"WebSocket error: {ws.exception()}") + + except asyncio.TimeoutError: + logger.error(f"Tool call timeout for {tool_name}") + raise RuntimeError(f"Tool call timeout for {tool_name}") + + except Exception as e: + logger.error(f"Error calling WebSocket tool '{tool_name}': {e}") + raise + + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> AsyncGenerator[Any, None]: + """REQUIRED + Execute a tool call through WebSocket with streaming responses. + + Args: + caller: The UTCP client that is calling this method. + tool_name: Name of the tool to call. + tool_args: Dictionary of arguments to pass to the tool. + tool_call_template: Call template of the tool to call. + + Yields: + Streaming responses from the tool. + """ + if not isinstance(tool_call_template, WebSocketCallTemplate): + raise ValueError("WebSocketCommunicationProtocol can only be used with WebSocketCallTemplate") + + logger.info(f"Calling WebSocket tool '{tool_name}' (streaming)") + + ws = await self._get_connection(tool_call_template) + + try: + # Prepare tool call request + request_id = f"call_{tool_name}_{id(tool_args)}" + tool_call_message = self._format_tool_call_message(tool_name, tool_args, tool_call_template, request_id) + + await ws.send_str(tool_call_message) + logger.info(f"Sent streaming tool call request for {tool_name}") + + # Stream responses + timeout = tool_call_template.timeout + try: + async with asyncio.timeout(timeout): + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + try: + response = json.loads(msg.data) + if (response.get("request_id") == request_id or not response.get("request_id")): + if response.get("type") == "tool_response": + yield response.get("result") + elif response.get("type") == "tool_error": + error_msg = response.get("error", "Unknown error") + logger.error(f"Tool error for {tool_name}: {error_msg}") + raise RuntimeError(f"Tool {tool_name} failed: {error_msg}") + elif response.get("type") == "stream_end": + break + else: + yield msg.data + + except json.JSONDecodeError: + yield msg.data + + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error(f"WebSocket error during streaming: {ws.exception()}") + break + + except asyncio.TimeoutError: + logger.error(f"Streaming timeout for {tool_name}") + raise RuntimeError(f"Streaming timeout for {tool_name}") + + except Exception as e: + logger.error(f"Error streaming WebSocket tool '{tool_name}': {e}") + raise + + async def close(self) -> None: + """Close all WebSocket connections and sessions.""" + for provider_key in list(self._connections.keys()): + await self._cleanup_connection(provider_key) + + self._oauth_tokens.clear() + logger.info("WebSocket communication protocol closed") diff --git a/plugins/communication_protocols/websocket/tests/__init__.py b/plugins/communication_protocols/websocket/tests/__init__.py new file mode 100644 index 0000000..614ce9a --- /dev/null +++ b/plugins/communication_protocols/websocket/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the WebSocket communication protocol plugin.""" diff --git a/plugins/communication_protocols/websocket/tests/test_websocket_call_template.py b/plugins/communication_protocols/websocket/tests/test_websocket_call_template.py new file mode 100644 index 0000000..ae62fd3 --- /dev/null +++ b/plugins/communication_protocols/websocket/tests/test_websocket_call_template.py @@ -0,0 +1,135 @@ +"""Tests for WebSocket call template.""" + +import pytest +from pydantic import ValidationError +from utcp_websocket.websocket_call_template import WebSocketCallTemplate, WebSocketCallTemplateSerializer + + +def test_websocket_call_template_basic(): + """Test basic WebSocket call template creation.""" + template = WebSocketCallTemplate( + name="test_ws", + url="wss://api.example.com/ws" + ) + assert template.name == "test_ws" + assert template.url == "wss://api.example.com/ws" + assert template.call_template_type == "websocket" + assert template.keep_alive is True + assert template.message is None # No message template by default (maximum flexibility) + assert template.response_format is None # No format enforcement by default + assert template.timeout == 30 + + +def test_websocket_call_template_localhost(): + """Test WebSocket call template with localhost URL.""" + template = WebSocketCallTemplate( + name="local_ws", + url="ws://localhost:8080/ws" + ) + assert template.url == "ws://localhost:8080/ws" + + +def test_websocket_call_template_invalid_url(): + """Test WebSocket call template rejects insecure URLs.""" + with pytest.raises(ValidationError) as exc_info: + WebSocketCallTemplate( + name="insecure_ws", + url="ws://remote.example.com/ws" + ) + assert "wss://" in str(exc_info.value) + + +def test_websocket_call_template_with_auth(): + """Test WebSocket call template with authentication.""" + from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth + + template = WebSocketCallTemplate( + name="auth_ws", + url="wss://api.example.com/ws", + auth=ApiKeyAuth( + api_key="test-key", + var_name="Authorization", + location="header" + ) + ) + assert template.auth is not None + assert template.auth.api_key == "test-key" + + +def test_websocket_call_template_with_message_dict(): + """Test WebSocket call template with dict message template.""" + template = WebSocketCallTemplate( + name="dict_ws", + url="wss://api.example.com/ws", + message={"action": "UTCP_ARG_action_UTCP_ARG", "data": "UTCP_ARG_data_UTCP_ARG", "id": "123"} + ) + assert template.message == {"action": "UTCP_ARG_action_UTCP_ARG", "data": "UTCP_ARG_data_UTCP_ARG", "id": "123"} + + +def test_websocket_call_template_with_message_string(): + """Test WebSocket call template with string message template.""" + template = WebSocketCallTemplate( + name="string_ws", + url="wss://api.example.com/ws", + message="CMD:UTCP_ARG_command_UTCP_ARG;VALUE:UTCP_ARG_value_UTCP_ARG" + ) + assert template.message == "CMD:UTCP_ARG_command_UTCP_ARG;VALUE:UTCP_ARG_value_UTCP_ARG" + + +def test_websocket_call_template_serialization(): + """Test WebSocket call template serialization.""" + template = WebSocketCallTemplate( + name="test_ws", + url="wss://api.example.com/ws", + protocol="utcp-v1", + timeout=60, + message={"type": "UTCP_ARG_type_UTCP_ARG"}, + response_format="json" + ) + + serializer = WebSocketCallTemplateSerializer() + data = serializer.to_dict(template) + + assert data["name"] == "test_ws" + assert data["call_template_type"] == "websocket" + assert data["url"] == "wss://api.example.com/ws" + assert data["protocol"] == "utcp-v1" + assert data["timeout"] == 60 + assert data["message"] == {"type": "UTCP_ARG_type_UTCP_ARG"} + assert data["response_format"] == "json" + + # Deserialize + restored = serializer.validate_dict(data) + assert restored.name == template.name + assert restored.url == template.url + assert restored.protocol == template.protocol + assert restored.message == template.message + + +def test_websocket_call_template_with_headers(): + """Test WebSocket call template with custom headers.""" + template = WebSocketCallTemplate( + name="headers_ws", + url="wss://api.example.com/ws", + headers={"X-Custom": "value"}, + header_fields=["user_id"] + ) + assert template.headers == {"X-Custom": "value"} + assert template.header_fields == ["user_id"] + + +def test_websocket_call_template_response_format(): + """Test WebSocket call template with response format specification.""" + template = WebSocketCallTemplate( + name="format_ws", + url="wss://api.example.com/ws", + response_format="json" + ) + assert template.response_format == "json" + + template2 = WebSocketCallTemplate( + name="text_ws", + url="wss://api.example.com/ws", + response_format="text" + ) + assert template2.response_format == "text" diff --git a/src/utcp/client/transport_interfaces/websocket_transport.py b/src/utcp/client/transport_interfaces/websocket_transport.py new file mode 100644 index 0000000..465a7ae --- /dev/null +++ b/src/utcp/client/transport_interfaces/websocket_transport.py @@ -0,0 +1,400 @@ +from typing import Dict, Any, List, Optional, Callable, Union +import asyncio +import json +import logging +import ssl +import aiohttp +from aiohttp import ClientWebSocketResponse, ClientSession +import base64 + +from utcp.client.client_transport_interface import ClientTransportInterface +from utcp.shared.provider import Provider, WebSocketProvider +from utcp.shared.tool import Tool, ToolInputOutputSchema +from utcp.shared.utcp_manual import UtcpManual +from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth + + +class WebSocketClientTransport(ClientTransportInterface): + """ + WebSocket transport implementation for UTCP that provides real-time bidirectional communication. + + This transport supports: + - Tool discovery via initial connection handshake + - Real-time tool execution with streaming responses + - Authentication (API Key, Basic Auth, OAuth2) + - Automatic reconnection and keep-alive + - Protocol subprotocols + """ + + def __init__(self, logger: Optional[Callable[[str], None]] = None): + self._log = logger or (lambda *args, **kwargs: None) + self._oauth_tokens: Dict[str, Dict[str, Any]] = {} + self._connections: Dict[str, ClientWebSocketResponse] = {} + self._sessions: Dict[str, ClientSession] = {} + + def _log_info(self, message: str): + """Log informational messages.""" + self._log(f"[WebSocketTransport] {message}") + + def _log_error(self, message: str): + """Log error messages.""" + logging.error(f"[WebSocketTransport Error] {message}") + + def _format_tool_call_message( + self, + tool_name: str, + arguments: Dict[str, Any], + provider: WebSocketProvider, + request_id: str + ) -> str: + """Format a tool call message based on provider configuration. + + Args: + tool_name: Name of the tool to call + arguments: Arguments for the tool call + provider: The WebSocketProvider with formatting configuration + request_id: Unique request identifier + + Returns: + Formatted message string + """ + # Check if provider specifies a custom message format + if provider.message_format: + # Custom format with placeholders (maintains backward compatibility) + try: + formatted_message = provider.message_format.format( + tool_name=tool_name, + arguments=json.dumps(arguments), + request_id=request_id + ) + return formatted_message + except (KeyError, json.JSONDecodeError) as e: + self._log_error(f"Error formatting custom message: {e}") + # Fall back to default format below + + # Handle request_data_format similar to UDP transport + if provider.request_data_format == "json": + return json.dumps({ + "type": "call_tool", + "request_id": request_id, + "tool_name": tool_name, + "arguments": arguments + }) + elif provider.request_data_format == "text": + # Use template-based formatting + if provider.request_data_template is not None and provider.request_data_template != "": + message = provider.request_data_template + # Replace placeholders with argument values + for arg_name, arg_value in arguments.items(): + placeholder = f"UTCP_ARG_{arg_name}_UTCP_ARG" + if isinstance(arg_value, str): + message = message.replace(placeholder, arg_value) + else: + message = message.replace(placeholder, json.dumps(arg_value)) + # Also replace tool name and request ID if placeholders exist + message = message.replace("UTCP_ARG_tool_name_UTCP_ARG", tool_name) + message = message.replace("UTCP_ARG_request_id_UTCP_ARG", request_id) + return message + else: + # Fallback to simple format + return f"{tool_name} {' '.join([str(v) for k, v in arguments.items()])}" + else: + # Default to JSON format + return json.dumps({ + "type": "call_tool", + "request_id": request_id, + "tool_name": tool_name, + "arguments": arguments + }) + + def _enforce_security(self, url: str): + """Enforce HTTPS/WSS or localhost for security.""" + if not (url.startswith("wss://") or + url.startswith("ws://localhost") or + url.startswith("ws://127.0.0.1")): + raise ValueError( + f"Security error: WebSocket URL must use WSS or start with 'ws://localhost' or 'ws://127.0.0.1'. " + f"Got: {url}. Non-secure URLs are vulnerable to man-in-the-middle attacks." + ) + + async def _handle_oauth2(self, auth: OAuth2Auth) -> str: + """Handle OAuth2 authentication and token management.""" + client_id = auth.client_id + if client_id in self._oauth_tokens: + return self._oauth_tokens[client_id]["access_token"] + + async with aiohttp.ClientSession() as session: + data = { + 'grant_type': 'client_credentials', + 'client_id': client_id, + 'client_secret': auth.client_secret, + 'scope': auth.scope + } + async with session.post(auth.token_url, data=data) as resp: + resp.raise_for_status() + token_response = await resp.json() + self._oauth_tokens[client_id] = token_response + return token_response["access_token"] + + async def _prepare_headers(self, provider: WebSocketProvider) -> Dict[str, str]: + """Prepare headers for WebSocket connection including authentication.""" + headers = provider.headers.copy() if provider.headers else {} + + if provider.auth: + if isinstance(provider.auth, ApiKeyAuth): + if provider.auth.api_key: + if provider.auth.location == "header": + headers[provider.auth.var_name] = provider.auth.api_key + # WebSocket doesn't support query params or cookies in the same way as HTTP + + elif isinstance(provider.auth, BasicAuth): + userpass = f"{provider.auth.username}:{provider.auth.password}" + headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode() + + elif isinstance(provider.auth, OAuth2Auth): + token = await self._handle_oauth2(provider.auth) + headers["Authorization"] = f"Bearer {token}" + + return headers + + async def _get_connection(self, provider: WebSocketProvider) -> ClientWebSocketResponse: + """Get or create a WebSocket connection for the provider.""" + provider_key = f"{provider.name}_{provider.url}" + + # Check if we have an active connection + if provider_key in self._connections: + ws = self._connections[provider_key] + if not ws.closed: + return ws + else: + # Clean up closed connection + await self._cleanup_connection(provider_key) + + # Create new connection + self._enforce_security(provider.url) + headers = await self._prepare_headers(provider) + + session = ClientSession() + self._sessions[provider_key] = session + + try: + ws = await session.ws_connect( + provider.url, + headers=headers, + protocols=[provider.protocol] if provider.protocol else None, + heartbeat=30 if provider.keep_alive else None + ) + self._connections[provider_key] = ws + self._log(f"WebSocket connected to {provider.url}") + return ws + + except Exception as e: + await session.close() + if provider_key in self._sessions: + del self._sessions[provider_key] + self._log_error(f"Failed to connect to WebSocket {provider.url}: {e}") + raise + + async def _cleanup_connection(self, provider_key: str): + """Clean up a specific connection.""" + if provider_key in self._connections: + ws = self._connections[provider_key] + if not ws.closed: + await ws.close() + del self._connections[provider_key] + + if provider_key in self._sessions: + session = self._sessions[provider_key] + await session.close() + del self._sessions[provider_key] + + async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: + """ + Register a WebSocket tool provider by connecting and requesting tool discovery. + + The discovery protocol sends a JSON message: + {"type": "discover", "request_id": "unique_id"} + + Expected response: + {"type": "discovery_response", "request_id": "unique_id", "tools": [...]} + """ + if not isinstance(manual_provider, WebSocketProvider): + raise ValueError("WebSocketClientTransport can only be used with WebSocketProvider") + + ws = await self._get_connection(manual_provider) + + try: + # Send discovery request (matching UDP pattern) + discovery_message = json.dumps({ + "type": "utcp" + }) + await ws.send_str(discovery_message) + self._log_info(f"Registering WebSocket provider '{manual_provider.name}' at {manual_provider.url}") + + # Wait for discovery response + timeout = manual_provider.timeout / 1000.0 # Convert ms to seconds + try: + async with asyncio.timeout(timeout): + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + try: + response_data = json.loads(msg.data) + + # Response data for a /utcp endpoint NEEDS to be a UtcpManual + if isinstance(response_data, dict): + # Check if it's a UtcpManual format with tools + if 'tools' in response_data: + try: + # Parse as UtcpManual + utcp_manual = UtcpManual(**response_data) + tools = utcp_manual.tools + + self._log_info(f"Discovered {len(tools)} tools from WebSocket provider '{manual_provider.name}'") + return tools + except Exception as e: + self._log_error(f"Invalid UtcpManual response from WebSocket provider '{manual_provider.name}': {e}") + return [] + else: + # Try to parse individual tools directly (fallback for backward compatibility) + tools_data = response_data.get('tools', []) + tools = [] + for tool_data in tools_data: + try: + # Tools should come with their own tool_provider + tool = Tool(**tool_data) + tools.append(tool) + except Exception as e: + self._log_error(f"Invalid tool definition in WebSocket provider '{manual_provider.name}': {e}") + continue + + self._log_info(f"Discovered {len(tools)} tools from WebSocket provider '{manual_provider.name}'") + return tools + else: + self._log_info(f"No tools found in WebSocket provider '{manual_provider.name}' response") + return [] + + except json.JSONDecodeError as e: + self._log_error(f"Invalid JSON response from WebSocket provider '{manual_provider.name}': {e}") + + elif msg.type == aiohttp.WSMsgType.ERROR: + self._log_error(f"WebSocket error during discovery: {ws.exception()}") + break + + except asyncio.TimeoutError: + self._log_error(f"Discovery timeout for {manual_provider.url}") + raise ValueError(f"Tool discovery timeout for WebSocket provider {manual_provider.url}") + + except Exception as e: + self._log_error(f"Error registering WebSocket provider '{manual_provider.name}': {e}") + return [] + + return [] + + async def deregister_tool_provider(self, manual_provider: Provider) -> None: + """Deregister a WebSocket provider by closing its connection.""" + if not isinstance(manual_provider, WebSocketProvider): + return + + provider_key = f"{manual_provider.name}_{manual_provider.url}" + await self._cleanup_connection(provider_key) + self._log_info(f"Deregistering WebSocket provider '{manual_provider.name}' (connection closed)") + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any], tool_provider: Provider) -> Any: + """ + Call a tool via WebSocket. + + The format can be customized per tool, but defaults to: + {"type": "call_tool", "request_id": "unique_id", "tool_name": "tool", "arguments": {...}} + + Expected response: + {"type": "tool_response", "request_id": "unique_id", "result": {...}} + or + {"type": "tool_error", "request_id": "unique_id", "error": "error message"} + """ + if not isinstance(tool_provider, WebSocketProvider): + raise ValueError("WebSocketClientTransport can only be used with WebSocketProvider") + + self._log_info(f"Calling WebSocket tool '{tool_name}' on provider '{tool_provider.name}'") + + ws = await self._get_connection(tool_provider) + + try: + # Prepare tool call request using the new formatting method + request_id = f"call_{tool_name}_{id(arguments)}" + tool_call_message = self._format_tool_call_message(tool_name, arguments, tool_provider, request_id) + + # For JSON format, we need to parse it back to add header fields if needed + if tool_provider.request_data_format == "json" or tool_provider.message_format: + try: + call_request = json.loads(tool_call_message) + + # Add any header fields to the request + if tool_provider.header_fields and arguments: + headers = {} + for field in tool_provider.header_fields: + if field in arguments: + headers[field] = arguments[field] + if headers: + call_request["headers"] = headers + + tool_call_message = json.dumps(call_request) + except json.JSONDecodeError: + # Keep the original message if it's not valid JSON + pass + + await ws.send_str(tool_call_message) + self._log_info(f"Sent tool call request for {tool_name}") + + # Wait for response + timeout = tool_provider.timeout / 1000.0 # Convert ms to seconds + try: + async with asyncio.timeout(timeout): + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + try: + response = json.loads(msg.data) + # Check for either new format or backward compatible format + if (response.get("request_id") == request_id or + not response.get("request_id")): # Allow responses without request_id for backward compatibility + if response.get("type") == "tool_response": + return response.get("result") + elif response.get("type") == "tool_error": + error_msg = response.get("error", "Unknown error") + self._log_error(f"Tool error for {tool_name}: {error_msg}") + raise RuntimeError(f"Tool {tool_name} failed: {error_msg}") + else: + # For non-UTCP responses, return the entire response + return msg.data + + except json.JSONDecodeError: + # Return raw response for non-JSON responses + return msg.data + + elif msg.type == aiohttp.WSMsgType.ERROR: + self._log_error(f"WebSocket error during tool call: {ws.exception()}") + break + + except asyncio.TimeoutError: + self._log_error(f"Tool call timeout for {tool_name}") + raise RuntimeError(f"Tool call timeout for {tool_name}") + + except Exception as e: + self._log_error(f"Error calling WebSocket tool '{tool_name}': {e}") + raise + + async def close(self) -> None: + """Close all WebSocket connections and sessions.""" + # Close all connections + for provider_key in list(self._connections.keys()): + await self._cleanup_connection(provider_key) + + # Clear OAuth tokens + self._oauth_tokens.clear() + + self._log_info("WebSocket transport closed") + + def __del__(self): + """Ensure cleanup on object destruction.""" + if self._connections or self._sessions: + # Log warning but can't await in __del__ + logging.warning("WebSocketClientTransport was not properly closed. Call close() explicitly.") \ No newline at end of file diff --git a/test_websocket_manual.py b/test_websocket_manual.py new file mode 100644 index 0000000..a1457c4 --- /dev/null +++ b/test_websocket_manual.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +""" +Manual test script for WebSocket transport implementation. +This tests the core functionality without requiring pytest setup. +""" + +import asyncio +import sys +import os + +# Add src to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) + +from utcp.client.transport_interfaces.websocket_transport import WebSocketClientTransport +from utcp.shared.provider import WebSocketProvider +from utcp.shared.auth import ApiKeyAuth, BasicAuth + + +async def test_basic_functionality(): + """Test basic WebSocket transport functionality""" + print("Testing WebSocket Transport Implementation...") + + transport = WebSocketClientTransport() + + # Test 1: Security enforcement + print("\n1. Testing security enforcement...") + try: + insecure_provider = WebSocketProvider( + name="insecure", + url="ws://example.com/ws" # Should be rejected + ) + await transport.register_tool_provider(insecure_provider) + print("āŒ FAILED: Insecure URL was accepted") + except ValueError as e: + if "Security error" in str(e): + print("āœ… PASSED: Insecure URL properly rejected") + else: + print(f"āŒ FAILED: Wrong error: {e}") + except Exception as e: + print(f"āŒ FAILED: Unexpected error: {e}") + + # Test 2: Provider type validation + print("\n2. Testing provider type validation...") + try: + from utcp.shared.provider import HttpProvider + wrong_provider = HttpProvider(name="wrong", url="https://example.com") + await transport.register_tool_provider(wrong_provider) + print("āŒ FAILED: Wrong provider type was accepted") + except ValueError as e: + if "WebSocketClientTransport can only be used with WebSocketProvider" in str(e): + print("āœ… PASSED: Provider type validation works") + else: + print(f"āŒ FAILED: Wrong error: {e}") + except Exception as e: + print(f"āŒ FAILED: Unexpected error: {e}") + + # Test 3: Authentication header preparation + print("\n3. Testing authentication...") + try: + # Test API Key auth + api_provider = WebSocketProvider( + name="api_test", + url="wss://example.com/ws", + auth=ApiKeyAuth( + var_name="X-API-Key", + api_key="test-key-123", + location="header" + ) + ) + headers = await transport._prepare_headers(api_provider) + if headers.get("X-API-Key") == "test-key-123": + print("āœ… PASSED: API Key authentication headers prepared correctly") + else: + print(f"āŒ FAILED: API Key headers incorrect: {headers}") + + # Test Basic auth + basic_provider = WebSocketProvider( + name="basic_test", + url="wss://example.com/ws", + auth=BasicAuth(username="user", password="pass") + ) + headers = await transport._prepare_headers(basic_provider) + if "Authorization" in headers and headers["Authorization"].startswith("Basic "): + print("āœ… PASSED: Basic authentication headers prepared correctly") + else: + print(f"āŒ FAILED: Basic auth headers incorrect: {headers}") + + except Exception as e: + print(f"āŒ FAILED: Authentication test error: {e}") + + # Test 4: Connection management + print("\n4. Testing connection management...") + try: + localhost_provider = WebSocketProvider( + name="test_provider", + url="ws://localhost:8765/ws" + ) + + # This should fail to connect but not due to security + try: + await transport.register_tool_provider(localhost_provider) + print("āŒ FAILED: Connection should have failed (no server)") + except ValueError as e: + if "Security error" in str(e): + print("āŒ FAILED: Security error on localhost") + else: + print("ā“ UNEXPECTED: Different error occurred") + except Exception as e: + # Expected - connection refused or similar + print("āœ… PASSED: Connection management works (failed to connect as expected)") + + except Exception as e: + print(f"āŒ FAILED: Connection test error: {e}") + + # Test 5: Cleanup + print("\n5. Testing cleanup...") + try: + await transport.close() + if len(transport._connections) == 0 and len(transport._oauth_tokens) == 0: + print("āœ… PASSED: Cleanup successful") + else: + print("āŒ FAILED: Cleanup incomplete") + except Exception as e: + print(f"āŒ FAILED: Cleanup error: {e}") + + print("\nāœ… WebSocket transport basic functionality tests completed!") + + +async def test_with_mock_server(): + """Test with a real WebSocket connection to our mock server""" + print("\n" + "="*50) + print("Testing with Mock WebSocket Server") + print("="*50) + + # Import and start mock server + sys.path.append('tests/client/transport_interfaces') + try: + from mock_websocket_server import create_app + from aiohttp import web + + print("Starting mock WebSocket server...") + app = await create_app() + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8765) + await site.start() + + print("Mock server started on ws://localhost:8765/ws") + + # Test with our transport + transport = WebSocketClientTransport() + provider = WebSocketProvider( + name="test_provider", + url="ws://localhost:8765/ws" + ) + + try: + # Test tool discovery + print("\nTesting tool discovery...") + tools = await transport.register_tool_provider(provider) + print(f"āœ… Discovered {len(tools)} tools:") + for tool in tools: + print(f" - {tool.name}: {tool.description}") + + # Test tool execution + print("\nTesting tool execution...") + result = await transport.call_tool("echo", {"message": "Hello WebSocket!"}, provider) + print(f"āœ… Echo result: {result}") + + result = await transport.call_tool("add_numbers", {"a": 5, "b": 3}, provider) + print(f"āœ… Add result: {result}") + + # Test error handling + print("\nTesting error handling...") + try: + await transport.call_tool("simulate_error", {"error_message": "Test error"}, provider) + print("āŒ FAILED: Error tool should have failed") + except RuntimeError as e: + print(f"āœ… Error properly handled: {e}") + + except Exception as e: + print(f"āŒ Transport test failed: {e}") + finally: + await transport.close() + await runner.cleanup() + print("Mock server stopped") + + except ImportError as e: + print(f"āš ļø Mock server test skipped (missing dependencies): {e}") + except Exception as e: + print(f"āŒ Mock server test failed: {e}") + + +async def main(): + """Run all manual tests""" + await test_basic_functionality() + # await test_with_mock_server() # Uncomment if you want to test with real server + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file