diff --git a/README.md b/README.md index 4a58c6e..6b899ac 100644 --- a/README.md +++ b/README.md @@ -376,12 +376,18 @@ Configuration examples for each protocol. Remember to replace `provider_type` wi "url": "https://api.example.com/users/{user_id}", // Required "http_method": "POST", // Required, default: "GET" "content_type": "application/json", // Optional, default: "application/json" - "auth": { // Optional, example using ApiKeyAuth for a Bearer token. The client must prepend "Bearer " to the token. + "auth": { // Optional, authentication for the HTTP request (example using ApiKeyAuth for Bearer token) "auth_type": "api_key", "api_key": "Bearer $API_KEY", // Required "var_name": "Authorization", // Optional, default: "X-Api-Key" "location": "header" // Optional, default: "header" }, + "auth_tools": { // Optional, authentication for converted tools, if this call template points to an openapi spec that should be automatically converted to a utcp manual (applied only to endpoints requiring auth per OpenAPI spec) + "auth_type": "api_key", + "api_key": "Bearer $TOOL_API_KEY", // Required + "var_name": "Authorization", // Optional, default: "X-Api-Key" + "location": "header" // Optional, default: "header" + }, "headers": { // Optional "X-Custom-Header": "value" }, @@ -473,7 +479,13 @@ Note the name change from `http_stream` to `streamable_http`. 
"name": "my_text_manual", "call_template_type": "text", // Required "file_path": "./manuals/my_manual.json", // Required - "auth": null // Optional (always null for Text) + "auth": null, // Optional (always null for Text) + "auth_tools": { // Optional, authentication for generated tools from OpenAPI specs + "auth_type": "api_key", + "api_key": "Bearer ${API_TOKEN}", + "var_name": "Authorization", + "location": "header" + } } ``` @@ -569,7 +581,13 @@ client = await UtcpClient.create(config={ "manual_call_templates": [{ "name": "github", "call_template_type": "http", - "url": "https://api.github.com/openapi.json" + "url": "https://api.github.com/openapi.json", + "auth_tools": { # Authentication for generated tools requiring auth + "auth_type": "api_key", + "api_key": "Bearer ${GITHUB_TOKEN}", + "var_name": "Authorization", + "location": "header" + } }] }) ``` @@ -579,6 +597,7 @@ client = await UtcpClient.create(config={ - ✅ **Zero Infrastructure**: No servers to deploy or maintain - ✅ **Direct API Calls**: Native performance, no proxy overhead - ✅ **Automatic Conversion**: OpenAPI schemas → UTCP tools +- ✅ **Selective Authentication**: Only protected endpoints get auth, public endpoints remain accessible - ✅ **Authentication Preserved**: API keys, OAuth2, Basic auth supported - ✅ **Multi-format Support**: JSON, YAML, OpenAPI 2.0/3.0 - ✅ **Batch Processing**: Convert multiple APIs simultaneously diff --git a/plugins/communication_protocols/http/pyproject.toml b/plugins/communication_protocols/http/pyproject.toml index 52104c4..42f0951 100644 --- a/plugins/communication_protocols/http/pyproject.toml +++ b/plugins/communication_protocols/http/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-http" -version = "1.0.4" +version = "1.0.5" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/http/src/utcp_http/http_call_template.py 
b/plugins/communication_protocols/http/src/utcp_http/http_call_template.py index b3a9e70..2fac727 100644 --- a/plugins/communication_protocols/http/src/utcp_http/http_call_template.py +++ b/plugins/communication_protocols/http/src/utcp_http/http_call_template.py @@ -1,10 +1,10 @@ from utcp.data.call_template import CallTemplate, CallTemplateSerializer -from utcp.data.auth import Auth +from utcp.data.auth import Auth, AuthSerializer from utcp.interfaces.serializer import Serializer from utcp.exceptions import UtcpSerializerValidationError import traceback -from typing import Optional, Dict, List, Literal -from pydantic import Field +from typing import Optional, Dict, List, Literal, Any +from pydantic import Field, field_serializer, field_validator class HttpCallTemplate(CallTemplate): """REQUIRED @@ -40,6 +40,12 @@ class HttpCallTemplate(CallTemplate): "var_name": "Authorization", "location": "header" }, + "auth_tools": { + "auth_type": "api_key", + "api_key": "Bearer ${TOOL_API_KEY}", + "var_name": "Authorization", + "location": "header" + }, "headers": { "X-Custom-Header": "value" }, @@ -85,7 +91,8 @@ class HttpCallTemplate(CallTemplate): url: The base URL for the HTTP endpoint. Supports path parameters like "https://api.example.com/users/{user_id}/posts/{post_id}". content_type: The Content-Type header for requests. - auth: Optional authentication configuration. + auth: Optional authentication configuration for the HTTP request itself (for a manual call template, this includes fetching the manual or OpenAPI spec URL). + auth_tools: Optional authentication configuration for generated tools. Applied only to endpoints requiring auth per OpenAPI spec. headers: Optional static headers to include in all requests. body_field: Name of the tool argument to map to the HTTP request body. header_fields: List of tool argument names to map to HTTP request headers. 
@@ -96,10 +103,30 @@ class HttpCallTemplate(CallTemplate): url: str content_type: str = Field(default="application/json") auth: Optional[Auth] = None + auth_tools: Optional[Auth] = Field(default=None, description="Authentication configuration for generated tools (applied only to endpoints requiring auth per OpenAPI spec)") headers: Optional[Dict[str, str]] = None body_field: Optional[str] = Field(default="body", description="The name of the single input field to be sent as the request body.") header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers.") + @field_serializer('auth_tools') + def serialize_auth_tools(self, auth_tools: Optional[Auth]) -> Optional[dict]: + """Serialize auth_tools to dictionary.""" + if auth_tools is None: + return None + return AuthSerializer().to_dict(auth_tools) + + @field_validator('auth_tools', mode='before') + @classmethod + def validate_auth_tools(cls, v: Any) -> Optional[Auth]: + """Validate and deserialize auth_tools from dictionary.""" + if v is None: + return None + if isinstance(v, Auth): + return v + if isinstance(v, dict): + return AuthSerializer().validate_dict(v) + raise ValueError(f"auth_tools must be None, Auth instance, or dict, got {type(v)}") + class HttpCallTemplateSerializer(Serializer[HttpCallTemplate]): """REQUIRED diff --git a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py index a62e4d3..191a749 100644 --- a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py @@ -197,7 +197,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R utcp_manual = UtcpManualSerializer().validate_dict(response_data) else: logger.info(f"Assuming OpenAPI spec from '{manual_call_template.name}'. 
Converting to UTCP manual.") - converter = OpenApiConverter(response_data, spec_url=manual_call_template.url, call_template_name=manual_call_template.name) + converter = OpenApiConverter(response_data, spec_url=manual_call_template.url, call_template_name=manual_call_template.name, auth_tools=manual_call_template.auth_tools) utcp_manual = converter.convert() return RegisterManualResult( diff --git a/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py b/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py index be20fe6..c16412a 100644 --- a/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py +++ b/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py @@ -87,7 +87,7 @@ class OpenApiConverter: call_template_name: Normalized name for the call_template derived from the spec. """ - def __init__(self, openapi_spec: Dict[str, Any], spec_url: Optional[str] = None, call_template_name: Optional[str] = None): + def __init__(self, openapi_spec: Dict[str, Any], spec_url: Optional[str] = None, call_template_name: Optional[str] = None, auth_tools: Optional[Auth] = None): """Initializes the OpenAPI converter. Args: @@ -96,9 +96,12 @@ def __init__(self, openapi_spec: Dict[str, Any], spec_url: Optional[str] = None, Used for base URL determination if servers are not specified. call_template_name: Optional custom name for the call_template if the specification title is not provided. + auth_tools: Optional auth configuration for generated tools. + Applied only to endpoints that require authentication per OpenAPI spec. 
""" self.spec = openapi_spec self.spec_url = spec_url + self.auth_tools = auth_tools # Single counter for all placeholder variables self.placeholder_counter = 0 if call_template_name is None: @@ -160,7 +163,10 @@ def convert(self) -> UtcpManual: def _extract_auth(self, operation: Dict[str, Any]) -> Optional[Auth]: """ - Extracts authentication information from OpenAPI operation and global security schemes.""" + Extracts authentication information from OpenAPI operation and global security schemes. + Uses auth_tools configuration when compatible with OpenAPI auth requirements. + Supports both OpenAPI 2.0 and 3.0 security schemes. + """ # First check for operation-level security requirements security_requirements = operation.get("security", []) @@ -168,11 +174,11 @@ def _extract_auth(self, operation: Dict[str, Any]) -> Optional[Auth]: if not security_requirements: security_requirements = self.spec.get("security", []) - # If no security requirements, return None + # If no security requirements, return None (endpoint is public) if not security_requirements: return None - # Get security schemes - support both OpenAPI 2.0 and 3.0 + # Generate auth from OpenAPI security schemes - support both OpenAPI 2.0 and 3.0 security_schemes = self._get_security_schemes() # Process the first security requirement (most common case) @@ -181,9 +187,47 @@ def _extract_auth(self, operation: Dict[str, Any]) -> Optional[Auth]: for scheme_name, scopes in security_req.items(): if scheme_name in security_schemes: scheme = security_schemes[scheme_name] - return self._create_auth_from_scheme(scheme, scheme_name) + openapi_auth = self._create_auth_from_scheme(scheme, scheme_name) + + # If compatible with auth_tools, use actual values from manual call template + if self._is_auth_compatible(openapi_auth, self.auth_tools): + return self.auth_tools + else: + return openapi_auth # Use placeholder from OpenAPI scheme return None + + def _is_auth_compatible(self, openapi_auth: Optional[Auth], auth_tools: 
Optional[Auth]) -> bool: + """ + Checks if auth_tools configuration is compatible with OpenAPI auth requirements. + + Args: + openapi_auth: Auth generated from OpenAPI security scheme + auth_tools: Auth configuration from manual call template + + Returns: + True if compatible and auth_tools should be used, False otherwise + """ + if not openapi_auth or not auth_tools: + return False + + # Must be same auth type + if type(openapi_auth) != type(auth_tools): + return False + + # For API Key auth, check header name and location compatibility + if hasattr(openapi_auth, 'var_name') and hasattr(auth_tools, 'var_name'): + openapi_var = openapi_auth.var_name.lower() if openapi_auth.var_name else "" + tools_var = auth_tools.var_name.lower() if auth_tools.var_name else "" + + if openapi_var != tools_var: + return False + + if hasattr(openapi_auth, 'location') and hasattr(auth_tools, 'location'): + if openapi_auth.location != auth_tools.location: + return False + + return True def _get_security_schemes(self) -> Dict[str, Any]: """ diff --git a/plugins/communication_protocols/http/tests/test_auth_tools.py b/plugins/communication_protocols/http/tests/test_auth_tools.py new file mode 100644 index 0000000..f806930 --- /dev/null +++ b/plugins/communication_protocols/http/tests/test_auth_tools.py @@ -0,0 +1,250 @@ +""" +Tests for auth_tools functionality in OpenAPI converter. + +Tests the new auth_tools feature that allows manual call templates to provide +authentication configuration for generated tools, with compatibility checking +against OpenAPI security schemes. 
+""" + +import pytest +from utcp_http.openapi_converter import OpenApiConverter +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth +from utcp.data.auth_implementations.basic_auth import BasicAuth + + +def test_compatible_api_key_auth(): + """Test auth_tools with compatible API key authentication.""" + openapi_spec = { + "swagger": "2.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "host": "api.test.com", + "securityDefinitions": { + "api_key": { + "type": "apiKey", + "name": "Authorization", + "in": "header" + } + }, + "paths": { + "/protected": { + "get": { + "operationId": "getProtected", + "security": [{"api_key": []}], + "responses": {"200": {"description": "success"}} + } + } + } + } + + # Compatible auth_tools (same header name and location) + auth_tools = ApiKeyAuth( + api_key="Bearer token-123", + var_name="Authorization", + location="header" + ) + + converter = OpenApiConverter(openapi_spec, auth_tools=auth_tools) + manual = converter.convert() + + assert len(manual.tools) == 1 + tool = manual.tools[0] + + # Should use auth_tools values since they're compatible + assert tool.tool_call_template.auth is not None + assert isinstance(tool.tool_call_template.auth, ApiKeyAuth) + assert tool.tool_call_template.auth.api_key == "Bearer token-123" + assert tool.tool_call_template.auth.var_name == "Authorization" + assert tool.tool_call_template.auth.location == "header" + + +def test_incompatible_api_key_auth(): + """Test auth_tools with incompatible API key authentication.""" + openapi_spec = { + "swagger": "2.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "host": "api.test.com", + "securityDefinitions": { + "custom_key": { + "type": "apiKey", + "name": "X-API-Key", # Different header name + "in": "header" + } + }, + "paths": { + "/protected": { + "get": { + "operationId": "getProtected", + "security": [{"custom_key": []}], + "responses": {"200": {"description": "success"}} + } + } + } + } + + # Incompatible auth_tools 
(different header name) + auth_tools = ApiKeyAuth( + api_key="Bearer token-123", + var_name="Authorization", # Different from OpenAPI + location="header" + ) + + converter = OpenApiConverter(openapi_spec, auth_tools=auth_tools) + manual = converter.convert() + + assert len(manual.tools) == 1 + tool = manual.tools[0] + + # Should use OpenAPI scheme with placeholder since incompatible + assert tool.tool_call_template.auth is not None + assert isinstance(tool.tool_call_template.auth, ApiKeyAuth) + assert tool.tool_call_template.auth.api_key.startswith("${") # Placeholder + assert tool.tool_call_template.auth.var_name == "X-API-Key" # From OpenAPI + assert tool.tool_call_template.auth.location == "header" + + +def test_case_insensitive_header_matching(): + """Test that header name matching is case-insensitive.""" + openapi_spec = { + "swagger": "2.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "host": "api.test.com", + "securityDefinitions": { + "api_key": { + "type": "apiKey", + "name": "authorization", # lowercase + "in": "header" + } + }, + "paths": { + "/protected": { + "get": { + "operationId": "getProtected", + "security": [{"api_key": []}], + "responses": {"200": {"description": "success"}} + } + } + } + } + + # auth_tools with different case + auth_tools = ApiKeyAuth( + api_key="Bearer token-123", + var_name="Authorization", # uppercase + location="header" + ) + + converter = OpenApiConverter(openapi_spec, auth_tools=auth_tools) + manual = converter.convert() + + tool = manual.tools[0] + + # Should be compatible despite case difference + assert tool.tool_call_template.auth.api_key == "Bearer token-123" + + +def test_different_auth_types_incompatible(): + """Test that different auth types are incompatible.""" + openapi_spec = { + "swagger": "2.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "host": "api.test.com", + "securityDefinitions": { + "basic_auth": { + "type": "basic" + } + }, + "paths": { + "/protected": { + "get": { + 
"operationId": "getProtected", + "security": [{"basic_auth": []}], + "responses": {"200": {"description": "success"}} + } + } + } + } + + # Different auth type (API key vs Basic) + auth_tools = ApiKeyAuth( + api_key="Bearer token-123", + var_name="Authorization", + location="header" + ) + + converter = OpenApiConverter(openapi_spec, auth_tools=auth_tools) + manual = converter.convert() + + tool = manual.tools[0] + + # Should use OpenAPI scheme since types don't match + assert isinstance(tool.tool_call_template.auth, BasicAuth) + assert tool.tool_call_template.auth.username.startswith("${") # Placeholder + + +def test_public_endpoint_no_auth(): + """Test that public endpoints remain public regardless of auth_tools.""" + openapi_spec = { + "swagger": "2.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "host": "api.test.com", + "paths": { + "/public": { + "get": { + "operationId": "getPublic", + # No security field - public endpoint + "responses": {"200": {"description": "success"}} + } + } + } + } + + auth_tools = ApiKeyAuth( + api_key="Bearer token-123", + var_name="Authorization", + location="header" + ) + + converter = OpenApiConverter(openapi_spec, auth_tools=auth_tools) + manual = converter.convert() + + tool = manual.tools[0] + + # Should have no auth since endpoint is public + assert tool.tool_call_template.auth is None + + +def test_no_auth_tools_uses_openapi_scheme(): + """Test fallback to OpenAPI scheme when no auth_tools provided.""" + openapi_spec = { + "swagger": "2.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "host": "api.test.com", + "securityDefinitions": { + "api_key": { + "type": "apiKey", + "name": "X-API-Key", + "in": "header" + } + }, + "paths": { + "/protected": { + "get": { + "operationId": "getProtected", + "security": [{"api_key": []}], + "responses": {"200": {"description": "success"}} + } + } + } + } + + # No auth_tools provided + converter = OpenApiConverter(openapi_spec, auth_tools=None) + manual = 
converter.convert() + + tool = manual.tools[0] + + # Should use OpenAPI scheme with placeholder + assert tool.tool_call_template.auth is not None + assert isinstance(tool.tool_call_template.auth, ApiKeyAuth) + assert tool.tool_call_template.auth.api_key.startswith("${") + assert tool.tool_call_template.auth.var_name == "X-API-Key" diff --git a/plugins/communication_protocols/http/tests/test_http_communication_protocol.py b/plugins/communication_protocols/http/tests/test_http_communication_protocol.py index 753ec8e..518b8df 100644 --- a/plugins/communication_protocols/http/tests/test_http_communication_protocol.py +++ b/plugins/communication_protocols/http/tests/test_http_communication_protocol.py @@ -142,11 +142,6 @@ async def error_handler(request): return app -@pytest_asyncio.fixture -async def aiohttp_client(aiohttp_client, app): - """Create a test client for our app.""" - return await aiohttp_client(app) - @pytest_asyncio.fixture async def http_transport(): @@ -155,48 +150,52 @@ async def http_transport(): @pytest_asyncio.fixture -async def http_call_template(aiohttp_client): +async def http_call_template(aiohttp_client, app): """Create a basic HTTP call template for testing.""" + client = await aiohttp_client(app) return HttpCallTemplate( name="test_call_template", - url=f"http://localhost:{aiohttp_client.port}/tools", + url=f"http://localhost:{client.port}/tools", http_method="GET" ) @pytest_asyncio.fixture -async def api_key_call_template(aiohttp_client): +async def api_key_call_template(aiohttp_client, app): """Create an HTTP call template with API key auth.""" + client = await aiohttp_client(app) return HttpCallTemplate( name="api-key-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", auth=ApiKeyAuth(api_key="test-api-key", var_name="X-API-Key", location="header") ) @pytest_asyncio.fixture -async def basic_auth_call_template(aiohttp_client): +async def 
basic_auth_call_template(aiohttp_client, app): """Create an HTTP call template with Basic auth.""" + client = await aiohttp_client(app) return HttpCallTemplate( name="basic-auth-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", auth=BasicAuth(username="user", password="pass") ) @pytest_asyncio.fixture -async def oauth2_call_template(aiohttp_client): +async def oauth2_call_template(aiohttp_client, app): """Create an HTTP call template with OAuth2 auth.""" + client = await aiohttp_client(app) return HttpCallTemplate( name="oauth2-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", auth=OAuth2Auth( client_id="client-id", client_secret="client-secret", - token_url=f"http://localhost:{aiohttp_client.port}/token", + token_url=f"http://localhost:{client.port}/token", scope="read write" ) ) @@ -232,12 +231,13 @@ async def test_register_manual(http_transport: HttpCommunicationProtocol, http_c # Test error handling when registering a manual @pytest.mark.asyncio -async def test_register_manual_http_error(http_transport, aiohttp_client): +async def test_register_manual_http_error(http_transport, aiohttp_client, app): """Test error handling when registering a manual.""" # Create a call template that points to our error endpoint + client = await aiohttp_client(app) error_call_template = HttpCallTemplate( name="error-call-template", - url=f"http://localhost:{aiohttp_client.port}/error", + url=f"http://localhost:{client.port}/error", http_method="GET" ) @@ -263,12 +263,13 @@ async def test_deregister_manual(http_transport, http_call_template): # Test call_tool_basic @pytest.mark.asyncio -async def test_call_tool_basic(http_transport, http_call_template, aiohttp_client): +async def test_call_tool_basic(http_transport, http_call_template, aiohttp_client, app): """Test calling a tool with basic configuration.""" # 
Update call template URL to point to our /tool endpoint + client = await aiohttp_client(app) tool_call_template = HttpCallTemplate( name=http_call_template.name, - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET" ) @@ -314,17 +315,18 @@ async def test_call_tool_with_oauth2(http_transport, oauth2_call_template): @pytest.mark.asyncio -async def test_call_tool_with_oauth2_header_auth(http_transport, aiohttp_client): +async def test_call_tool_with_oauth2_header_auth(http_transport, aiohttp_client, app): """Test calling a tool with OAuth2 authentication (credentials in header).""" # This call template points to an endpoint that expects Basic Auth for the token + client = await aiohttp_client(app) oauth2_header_call_template = HttpCallTemplate( name="oauth2-header-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", auth=OAuth2Auth( client_id="client-id", client_secret="client-secret", - token_url=f"http://localhost:{aiohttp_client.port}/token_header_auth", + token_url=f"http://localhost:{client.port}/token_header_auth", scope="read write" ) ) @@ -339,12 +341,13 @@ async def test_call_tool_with_oauth2_header_auth(http_transport, aiohttp_client) # Test call_tool_with_body_field @pytest.mark.asyncio -async def test_call_tool_with_body_field(http_transport, aiohttp_client): +async def test_call_tool_with_body_field(http_transport, aiohttp_client, app): """Test calling a tool with a body field.""" # Create call template with body field + client = await aiohttp_client(app) call_template = HttpCallTemplate( name="body-field-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="POST", body_field="data" ) @@ -363,12 +366,13 @@ async def test_call_tool_with_body_field(http_transport, aiohttp_client): # Test call_tool_with_path_params @pytest.mark.asyncio 
-async def test_call_tool_with_path_params(http_transport, aiohttp_client): +async def test_call_tool_with_path_params(http_transport, aiohttp_client, app): """Test calling a tool with path parameters.""" # Create call template with path params in URL + client = await aiohttp_client(app) call_template = HttpCallTemplate( name="path-params-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool/{{param1}}", + url=f"http://localhost:{client.port}/tool/{{param1}}", http_method="GET" ) @@ -386,12 +390,13 @@ async def test_call_tool_with_path_params(http_transport, aiohttp_client): # Test call_tool_with_custom_headers @pytest.mark.asyncio -async def test_call_tool_with_custom_headers(http_transport, aiohttp_client): +async def test_call_tool_with_custom_headers(http_transport, aiohttp_client, app): """Test calling a tool with custom headers.""" # Create call template with custom headers + client = await aiohttp_client(app) call_template = HttpCallTemplate( name="custom-headers-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", additional_headers={"X-Custom-Header": "custom-value"} ) @@ -527,11 +532,12 @@ async def path_param_handler(request): @pytest.mark.asyncio -async def test_call_tool_streaming_basic(http_transport, http_call_template, aiohttp_client): +async def test_call_tool_streaming_basic(http_transport, http_call_template, aiohttp_client, app): """Streaming basic call should yield one result identical to call_tool.""" + client = await aiohttp_client(app) tool_call_template = HttpCallTemplate( name=http_call_template.name, - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", ) stream = http_transport.call_tool_streaming(None, "test_tool", {"param1": "value1"}, tool_call_template) @@ -564,16 +570,17 @@ async def test_call_tool_streaming_with_oauth2(http_transport, oauth2_call_templ @pytest.mark.asyncio 
-async def test_call_tool_streaming_with_oauth2_header_auth(http_transport, aiohttp_client): +async def test_call_tool_streaming_with_oauth2_header_auth(http_transport, aiohttp_client, app): """Streaming with OAuth2 (credentials in header) yields one aggregated result.""" + client = await aiohttp_client(app) oauth2_header_call_template = HttpCallTemplate( name="oauth2-header-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", auth=OAuth2Auth( client_id="client-id", client_secret="client-secret", - token_url=f"http://localhost:{aiohttp_client.port}/token_header_auth", + token_url=f"http://localhost:{client.port}/token_header_auth", scope="read write", ), ) @@ -583,11 +590,12 @@ async def test_call_tool_streaming_with_oauth2_header_auth(http_transport, aioht @pytest.mark.asyncio -async def test_call_tool_streaming_with_body_field(http_transport, aiohttp_client): +async def test_call_tool_streaming_with_body_field(http_transport, aiohttp_client, app): """Streaming POST with body_field yields one aggregated result.""" + client = await aiohttp_client(app) call_template = HttpCallTemplate( name="body-field-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="POST", body_field="data", ) @@ -602,11 +610,12 @@ async def test_call_tool_streaming_with_body_field(http_transport, aiohttp_clien @pytest.mark.asyncio -async def test_call_tool_streaming_with_path_params(http_transport, aiohttp_client): +async def test_call_tool_streaming_with_path_params(http_transport, aiohttp_client, app): """Streaming with URL path params yields one aggregated result.""" + client = await aiohttp_client(app) call_template = HttpCallTemplate( name="path-params-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool/{{param1}}", + url=f"http://localhost:{client.port}/tool/{{param1}}", http_method="GET", ) stream = 
http_transport.call_tool_streaming( @@ -620,11 +629,12 @@ async def test_call_tool_streaming_with_path_params(http_transport, aiohttp_clie @pytest.mark.asyncio -async def test_call_tool_streaming_with_custom_headers(http_transport, aiohttp_client): +async def test_call_tool_streaming_with_custom_headers(http_transport, aiohttp_client, app): """Streaming with additional headers yields one aggregated result.""" + client = await aiohttp_client(app) call_template = HttpCallTemplate( name="custom-headers-call-template", - url=f"http://localhost:{aiohttp_client.port}/tool", + url=f"http://localhost:{client.port}/tool", http_method="GET", additional_headers={"X-Custom-Header": "custom-value"}, ) @@ -694,3 +704,35 @@ async def test_call_tool_openlibrary_style_url(http_transport): expected_remaining = {"format": "json"} http_transport._build_url_with_path_params(call_template.url, arguments) assert arguments == expected_remaining + + +def test_auth_tools_integration(): + """Test that auth_tools field is properly integrated in HttpCallTemplate.""" + from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth + from utcp_http.http_call_template import HttpCallTemplateSerializer + + # Create auth_tools configuration + auth_tools = ApiKeyAuth( + api_key="Bearer test-token", + var_name="Authorization", + location="header" + ) + + # Create HttpCallTemplate with auth_tools + call_template = HttpCallTemplate( + name="test-auth-tools", + url="https://api.example.com/spec.json", + auth_tools=auth_tools + ) + + # Verify auth_tools is stored correctly + assert call_template.auth_tools is not None + assert call_template.auth_tools.api_key == "Bearer test-token" + assert call_template.auth_tools.var_name == "Authorization" + assert call_template.auth_tools.location == "header" + + # Verify it can be serialized (auth_type is included for security) + serializer = HttpCallTemplateSerializer() + serialized = serializer.to_dict(call_template) + assert "auth_tools" in serialized + 
assert serialized["auth_tools"]["auth_type"] == "api_key" diff --git a/plugins/communication_protocols/http/tests/test_openapi_converter.py b/plugins/communication_protocols/http/tests/test_openapi_converter.py index 77382c7..aa3f3cb 100644 --- a/plugins/communication_protocols/http/tests/test_openapi_converter.py +++ b/plugins/communication_protocols/http/tests/test_openapi_converter.py @@ -3,6 +3,7 @@ import sys from utcp_http.openapi_converter import OpenApiConverter from utcp.data.utcp_manual import UtcpManual +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth @pytest.mark.asyncio @@ -28,7 +29,30 @@ async def test_openai_spec_conversion(): assert sample_tool.tool_call_template.http_method == "POST" body_schema = sample_tool.inputs.properties.get('body') assert body_schema is not None - assert body_schema.properties is not None - assert "messages" in body_schema.properties - assert "model" in body_schema.properties - assert "choices" in sample_tool.outputs.properties + + +@pytest.mark.asyncio +async def test_openapi_converter_with_auth_tools(): + """Test OpenAPI converter with auth_tools parameter.""" + url = "https://api.apis.guru/v2/specs/openai.com/1.2.0/openapi.json" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + response.raise_for_status() + openapi_spec = await response.json() + + # Test with auth_tools parameter + auth_tools = ApiKeyAuth( + api_key="Bearer test-token", + var_name="Authorization", + location="header" + ) + + converter = OpenApiConverter(openapi_spec, spec_url=url, auth_tools=auth_tools) + utcp_manual = converter.convert() + + assert isinstance(utcp_manual, UtcpManual) + assert len(utcp_manual.tools) > 0 + + # Verify auth_tools is stored + assert converter.auth_tools == auth_tools diff --git a/plugins/communication_protocols/text/README.md b/plugins/communication_protocols/text/README.md index 2057bf8..27f8525 100644 --- a/plugins/communication_protocols/text/README.md +++ 
b/plugins/communication_protocols/text/README.md @@ -8,9 +8,11 @@ A simple, file-based resource plugin for UTCP. This plugin allows you to define - **Local File Content**: Define tools that read and return the content of local files. - **UTCP Manual Discovery**: Load tool definitions from local UTCP manual files in JSON or YAML format. +- **OpenAPI Support**: Automatically converts local OpenAPI specs to UTCP tools with optional authentication. - **Static & Simple**: Ideal for returning mock data, configuration, or any static text content from a file. - **Version Control**: Tool definitions and their corresponding content files can be versioned with your code. -- **No Authentication**: Designed for simple, local file access without authentication. +- **No File Authentication**: Designed for simple, local file access without authentication for file reading. +- **Tool Authentication**: Supports authentication for generated tools from OpenAPI specs via `auth_tools`. ## Installation diff --git a/plugins/communication_protocols/text/pyproject.toml b/plugins/communication_protocols/text/pyproject.toml index 3780c57..c624e8c 100644 --- a/plugins/communication_protocols/text/pyproject.toml +++ b/plugins/communication_protocols/text/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-text" -version = "1.0.2" +version = "1.0.3" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/text/src/utcp_text/text_call_template.py b/plugins/communication_protocols/text/src/utcp_text/text_call_template.py index 23ba009..a090817 100644 --- a/plugins/communication_protocols/text/src/utcp_text/text_call_template.py +++ b/plugins/communication_protocols/text/src/utcp_text/text_call_template.py @@ -1,7 +1,8 @@ -from typing import Literal -from pydantic import Field +from typing import Literal, Optional, Any +from pydantic import Field, field_serializer, field_validator from utcp.data.call_template import 
CallTemplate +from utcp.data.auth import Auth, AuthSerializer from utcp.interfaces.serializer import Serializer from utcp.exceptions import UtcpSerializerValidationError import traceback @@ -16,12 +17,33 @@ class TextCallTemplate(CallTemplate): Attributes: call_template_type: Always "text" for text file call templates. file_path: Path to the file containing the UTCP manual or tool definitions. - auth: Always None - text call templates don't support authentication. + auth: Always None - text call templates don't support authentication for file access. + auth_tools: Optional authentication to apply to generated tools from OpenAPI specs. """ call_template_type: Literal["text"] = "text" file_path: str = Field(..., description="The path to the file containing the UTCP manual or tool definitions.") auth: None = None + auth_tools: Optional[Auth] = Field(None, description="Authentication to apply to generated tools from OpenAPI specs.") + + @field_serializer('auth_tools') + def serialize_auth_tools(self, auth_tools: Optional[Auth]) -> Optional[dict]: + """Serialize auth_tools to dictionary.""" + if auth_tools is None: + return None + return AuthSerializer().to_dict(auth_tools) + + @field_validator('auth_tools', mode='before') + @classmethod + def validate_auth_tools(cls, v: Any) -> Optional[Auth]: + """Validate and deserialize auth_tools from dictionary.""" + if v is None: + return None + if isinstance(v, Auth): + return v + if isinstance(v, dict): + return AuthSerializer().validate_dict(v) + raise ValueError(f"auth_tools must be None, Auth instance, or dict, got {type(v)}") class TextCallTemplateSerializer(Serializer[TextCallTemplate]): diff --git a/plugins/communication_protocols/text/src/utcp_text/text_communication_protocol.py b/plugins/communication_protocols/text/src/utcp_text/text_communication_protocol.py index 0d672dd..cdd49ae 100644 --- a/plugins/communication_protocols/text/src/utcp_text/text_communication_protocol.py +++ 
b/plugins/communication_protocols/text/src/utcp_text/text_communication_protocol.py @@ -70,7 +70,12 @@ async def register_manual(self, caller: 'UtcpClient', manual_call_template: Call utcp_manual: UtcpManual if isinstance(data, dict) and ("openapi" in data or "swagger" in data or "paths" in data): self._log_info("Detected OpenAPI specification. Converting to UTCP manual.") - converter = OpenApiConverter(data, spec_url=file_path.as_uri(), call_template_name=manual_call_template.name) + converter = OpenApiConverter( + data, + spec_url=file_path.as_uri(), + call_template_name=manual_call_template.name, + auth_tools=manual_call_template.auth_tools + ) utcp_manual = converter.convert() else: # Try to validate as UTCP manual directly diff --git a/plugins/communication_protocols/text/tests/test_text_communication_protocol.py b/plugins/communication_protocols/text/tests/test_text_communication_protocol.py index 0b7dffb..179b34c 100644 --- a/plugins/communication_protocols/text/tests/test_text_communication_protocol.py +++ b/plugins/communication_protocols/text/tests/test_text_communication_protocol.py @@ -12,6 +12,7 @@ from utcp_text.text_call_template import TextCallTemplate from utcp.data.call_template import CallTemplate from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth from utcp.utcp_client import UtcpClient @pytest_asyncio.fixture @@ -356,3 +357,45 @@ async def test_call_tool_streaming(text_protocol: TextCommunicationProtocol, sam assert chunks == [content] finally: Path(temp_file).unlink() + + +@pytest.mark.asyncio +async def test_text_call_template_with_auth_tools(): + """Test that TextCallTemplate can be created with auth_tools.""" + auth_tools = ApiKeyAuth(api_key="test-key", var_name="Authorization", location="header") + + template = TextCallTemplate( + name="test-template", + file_path="test.json", + auth_tools=auth_tools + ) + + assert template.auth_tools == auth_tools + 
@pytest.mark.asyncio
async def test_text_call_template_auth_tools_serialization():
    """Round-trip ``auth_tools`` through dict input and ``model_dump`` output."""
    auth_payload = {
        "auth_type": "api_key",
        "api_key": "test-key",
        "var_name": "Authorization",
        "location": "header",
    }

    # Build the template from plain keyword data, passing auth_tools as a dict.
    template = TextCallTemplate(
        name="test-template",
        call_template_type="text",
        file_path="test.json",
        auth_tools=auth_payload,
    )

    parsed_auth = template.auth_tools
    assert parsed_auth is not None
    assert parsed_auth.api_key == "test-key"
    assert parsed_auth.var_name == "Authorization"

    # Dumping the model must give back a plain dict for auth_tools.
    dumped_auth = template.model_dump()["auth_tools"]
    assert dumped_auth["auth_type"] == "api_key"
    assert dumped_auth["api_key"] == "test-key"
+ +While the plugin works without these packages (using a simple character frequency-based fallback), installing them provides significant benefits: + +- **Enhanced Semantic Understanding**: The `sentence-transformers` package provides pre-trained models that convert text into high-quality vector embeddings, capturing the semantic meaning of text rather than just keywords. + +- **Better Search Results**: With these packages installed, the search can understand conceptual similarity between queries and tools, even when they don't share exact keywords. + +- **Performance**: The default model (all-MiniLM-L6-v2) offers a good balance between quality and performance for semantic search applications. + +- **Fallback Mechanism**: Without these packages, the plugin automatically falls back to a simpler text similarity method, which works but with reduced accuracy. + +## How it works + +When installed, this package exposes an entry point under `utcp.plugins` so the UTCP core can auto-discover and register the `in_mem_embeddings` strategy. + +The embeddings are cached in memory for improved performance during repeated searches. diff --git a/plugins/tool_search/in_mem_embeddings/pyproject.toml b/plugins/tool_search/in_mem_embeddings/pyproject.toml new file mode 100644 index 0000000..3010572 --- /dev/null +++ b/plugins/tool_search/in_mem_embeddings/pyproject.toml @@ -0,0 +1,44 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "utcp-in-mem-embeddings" +version = "1.0.0" +authors = [ + { name = "UTCP Contributors" }, +] +description = "UTCP plugin providing in-memory embedding-based semantic tool search." 
def register():
    """Register the in-memory embeddings search strategy with UTCP.

    Passes a serializer instance to ``register_tool_search_strategy`` under
    the strategy type name ``"in_mem_embeddings"``.
    """
    serializer = InMemEmbeddingsSearchStrategyConfigSerializer()
    register_tool_search_strategy("in_mem_embeddings", serializer)
class InMemEmbeddingsSearchStrategy(ToolSearchStrategy):
    """In-memory semantic search strategy using sentence embeddings.

    Tool descriptions and search queries are converted into numerical
    embeddings, and tools are ranked by cosine similarity to the query.
    When ``sentence-transformers`` cannot be imported, or the model fails to
    load, a character-frequency fallback embedding is used instead.
    Tool embeddings are optionally cached in memory, keyed by tool name.
    """

    tool_search_strategy_type: Literal["in_mem_embeddings"] = "in_mem_embeddings"

    # Configuration parameters
    model_name: str = Field(
        default="all-MiniLM-L6-v2",
        description="Sentence transformer model name to use for embeddings. "
                   "Accepts any model from Hugging Face sentence-transformers library. "
                   "Popular options: 'all-MiniLM-L6-v2' (fast, good quality), "
                   "'all-mpnet-base-v2' (slower, higher quality), "
                   "'paraphrase-MiniLM-L6-v2' (paraphrase detection). "
                   "See https://huggingface.co/sentence-transformers for full list."
    )
    similarity_threshold: float = Field(default=0.3, description="Minimum similarity score to consider a match")
    max_workers: int = Field(default=4, description="Maximum number of worker threads for embedding generation")
    cache_embeddings: bool = Field(default=True, description="Whether to cache tool embeddings for performance")

    # Private attributes
    _embedding_model: Optional[Any] = PrivateAttr(default=None)
    _tool_embeddings_cache: Dict[str, np.ndarray] = PrivateAttr(default_factory=dict)
    _executor: Optional[ThreadPoolExecutor] = PrivateAttr(default=None)
    # Set once the current executor has been shut down by __aexit__.
    _executor_closed: bool = PrivateAttr(default=False)
    _model_loaded: bool = PrivateAttr(default=False)

    def __init__(self, **data):
        """Validate the configuration and create the worker thread pool."""
        super().__init__(**data)
        self._executor = ThreadPoolExecutor(max_workers=self.max_workers)

    def _get_executor(self) -> ThreadPoolExecutor:
        """Return a usable worker pool, replacing one that was shut down.

        Without this, any search issued after leaving the async context
        manager would fail inside ``run_in_executor`` and silently degrade to
        the character-frequency fallback embedding.
        """
        if self._executor is None or self._executor_closed:
            self._executor = ThreadPoolExecutor(max_workers=self.max_workers)
            self._executor_closed = False
        return self._executor

    async def _ensure_model_loaded(self):
        """Load the sentence-transformers model once, off the event loop.

        Failure is not fatal: on ``ImportError`` or any load error the model
        is left as ``None`` so that embedding falls back to
        ``_simple_text_embedding``. Either way the load is attempted only once.
        """
        if self._model_loaded:
            return

        try:
            # Import sentence-transformers here to avoid dependency issues
            from sentence_transformers import SentenceTransformer

            # Load the model in a thread to avoid blocking
            loop = asyncio.get_running_loop()
            self._embedding_model = await loop.run_in_executor(
                self._get_executor(),
                SentenceTransformer,
                self.model_name
            )
            self._model_loaded = True
            logger.info(f"Loaded embedding model: {self.model_name}")

        except ImportError:
            logger.warning("sentence-transformers not available, falling back to simple text similarity")
            self._embedding_model = None
            self._model_loaded = True
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            self._embedding_model = None
            self._model_loaded = True

    async def _get_text_embedding(self, text: str) -> np.ndarray:
        """Generate an embedding for ``text``.

        Args:
            text: The text to embed.

        Returns:
            A zero vector of length 384 for empty text; otherwise the model's
            encoding, or the character-frequency fallback when no model is
            loaded or encoding raises.
        """
        if not text:
            return np.zeros(384)  # Default dimension for all-MiniLM-L6-v2

        if self._embedding_model is None:
            # Fallback to simple text similarity
            return self._simple_text_embedding(text)

        try:
            # get_running_loop() matches _ensure_model_loaded and is the
            # recommended call from inside a coroutine.
            loop = asyncio.get_running_loop()
            embedding = await loop.run_in_executor(
                self._get_executor(),
                self._embedding_model.encode,
                text
            )
            return embedding
        except Exception as e:
            # NOTE: the fallback vector returned here is cached by
            # _get_tool_embedding like any other embedding.
            logger.warning(f"Failed to generate embedding for text: {e}")
            return self._simple_text_embedding(text)

    def _simple_text_embedding(self, text: str) -> np.ndarray:
        """Build a 384-dimensional fallback embedding from character codes.

        Each character's code point (scaled by 1/1000) is accumulated into
        the slot given by its position modulo 384, then the vector is
        L2-normalised when non-zero.
        """
        # This is a fallback when sentence-transformers is not available
        embedding = np.zeros(384)
        text_lower = text.lower()

        # Simple character frequency-based embedding
        for i, char in enumerate(text_lower):
            embedding[i % 384] += ord(char) / 1000.0

        # Normalize
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm

        return embedding

    async def _get_tool_embedding(self, tool: Tool) -> np.ndarray:
        """Return the embedding for ``tool``, using the cache when enabled.

        The embedded text is the tool's name, description and space-joined
        tags. Cache entries are keyed by ``tool.name``.
        """
        if not self.cache_embeddings or tool.name not in self._tool_embeddings_cache:
            # Create text representation of the tool
            tool_text = f"{tool.name} {tool.description} {' '.join(tool.tags)}"
            embedding = await self._get_text_embedding(tool_text)

            if self.cache_embeddings:
                self._tool_embeddings_cache[tool.name] = embedding

            return embedding

        return self._tool_embeddings_cache[tool.name]

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Calculate cosine similarity between two vectors.

        Returns:
            The similarity as a plain ``float``; ``0.0`` when either vector
            has zero norm or the computation raises (for example on a shape
            mismatch).
        """
        try:
            dot_product = np.dot(a, b)
            norm_a = np.linalg.norm(a)
            norm_b = np.linalg.norm(b)

            if norm_a == 0 or norm_b == 0:
                return 0.0

            # Convert the NumPy scalar so the declared return type holds.
            return float(dot_product / (norm_a * norm_b))
        except Exception as e:
            logger.warning(f"Error calculating cosine similarity: {e}")
            return 0.0

    async def search_tools(
        self,
        tool_repository: ConcurrentToolRepository,
        query: str,
        limit: int = 10,
        any_of_tags_required: Optional[List[str]] = None
    ) -> List[Tool]:
        """Search for tools using semantic similarity.

        Args:
            tool_repository: The tool repository to search within.
            query: The search query string.
            limit: Maximum number of tools to return; ``0`` means no limit.
            any_of_tags_required: Optional list of tags where one of them must
                be present (compared case-insensitively).

        Returns:
            Tools whose similarity is at least ``similarity_threshold``,
            ordered from most to least similar.

        Raises:
            ValueError: If ``limit`` is negative.
        """
        if limit < 0:
            raise ValueError("limit must be non-negative")

        # Ensure the embedding model is loaded
        await self._ensure_model_loaded()

        # Get all tools
        tools: List[Tool] = await tool_repository.get_tools()

        # Filter by required tags if specified
        if any_of_tags_required:
            required_tags = {tag.lower() for tag in any_of_tags_required}
            tools = [
                tool for tool in tools
                if any(tag.lower() in required_tags for tag in tool.tags)
            ]

        if not tools:
            return []

        # Generate query embedding
        query_embedding = await self._get_text_embedding(query)

        # Calculate similarity scores for all tools
        tool_scores: List[Tuple[Tool, float]] = []

        for tool in tools:
            try:
                tool_embedding = await self._get_tool_embedding(tool)
                similarity = self._cosine_similarity(query_embedding, tool_embedding)

                if similarity >= self.similarity_threshold:
                    tool_scores.append((tool, similarity))

            except Exception as e:
                # Best effort: one failing tool must not abort the search.
                logger.warning(f"Error processing tool {tool.name}: {e}")
                continue

        # Sort by similarity score (descending)
        sorted_tools = [
            tool for tool, score in sorted(
                tool_scores,
                key=lambda x: x[1],
                reverse=True
            )
        ]

        # Return up to 'limit' tools
        return sorted_tools[:limit] if limit > 0 else sorted_tools

    async def __aenter__(self):
        """Async context manager entry: make sure the model is loaded."""
        await self._ensure_model_loaded()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: shut the worker pool down without waiting."""
        if self._executor:
            self._executor.shutdown(wait=False)
            # Lets _get_executor() build a fresh pool if the strategy is reused.
            self._executor_closed = True
class MockToolRepository:
    """Minimal stand-in for a tool repository, backed by a fixed list.

    Only ``get_tools`` is provided, which is the one repository call made by
    the search strategy's ``search_tools``.
    """

    def __init__(self, tools: "List[Tool]"):
        # Stored as given; no copy is made.
        self.tools = tools

    async def get_tools(self) -> "List[Tool]":
        """Return the list supplied at construction time."""
        return self.tools
@pytest.fixture
def in_mem_embeddings_strategy():
    """Provide a strategy with the default model, 0.3 threshold, two workers and caching on."""
    settings = {
        "model_name": "all-MiniLM-L6-v2",
        "similarity_threshold": 0.3,
        "max_workers": 2,
        "cache_embeddings": True,
    }
    return InMemEmbeddingsSearchStrategy(**settings)
@pytest.mark.asyncio
async def test_cosine_similarity_calculation(in_mem_embeddings_strategy):
    """Check cosine similarity for identical, orthogonal and zero vectors."""
    cosine = in_mem_embeddings_strategy._cosine_similarity
    unit_x = np.array([1.0, 0.0, 0.0])

    # Identical direction -> 1
    assert cosine(unit_x, np.array([1.0, 0.0, 0.0])) == pytest.approx(1.0)

    # Orthogonal direction -> 0
    assert cosine(unit_x, np.array([0.0, 1.0, 0.0])) == pytest.approx(0.0)

    # A zero vector has no direction; the method returns exactly 0.0
    assert cosine(unit_x, np.zeros(3)) == 0.0
@pytest.mark.asyncio
async def test_search_tools_with_tag_filtering(in_mem_embeddings_strategy, sample_tools):
    """Search with ``any_of_tags_required`` returns only tools carrying one of the tags.

    Embedding and similarity are mocked so every tool scores 0.8, which puts
    all of them above the threshold and leaves the tag filter as the only
    thing that can exclude a tool.
    """
    tool_repo = MockToolRepository(sample_tools)
    required_tags = ["cooking", "kitchen"]

    with patch.object(in_mem_embeddings_strategy, '_get_text_embedding') as mock_query_embed, \
         patch.object(in_mem_embeddings_strategy, '_get_tool_embedding') as mock_tool_embed, \
         patch.object(in_mem_embeddings_strategy, '_cosine_similarity') as mock_sim:

        mock_query_embed.return_value = np.random.rand(384)
        mock_tool_embed.return_value = np.random.rand(384)
        mock_sim.return_value = 0.8

        # Search with required tags
        results = await in_mem_embeddings_strategy.search_tools(
            tool_repo,
            "cooking",
            limit=10,
            any_of_tags_required=required_tags
        )

        # Compute the expected survivors from the fixture itself, so an empty
        # result can no longer satisfy the test vacuously.
        expected_names = [
            tool.name for tool in sample_tools
            if any(tag in required_tags for tag in tool.tags)
        ]
        assert expected_names, "fixture must contain at least one matching tool"
        assert [tool.name for tool in results] == expected_names

        # Should only return tools with cooking or kitchen tags
        assert all(
            any(tag in required_tags for tag in tool.tags)
            for tool in results
        )
@pytest.mark.asyncio
async def test_search_tools_limit_respected(in_mem_embeddings_strategy, sample_tools):
    """``limit`` caps the result count, and ``limit=0`` returns every match."""
    strategy = in_mem_embeddings_strategy
    repo = MockToolRepository(sample_tools)

    with patch.object(strategy, '_get_text_embedding') as query_embed, \
         patch.object(strategy, '_get_tool_embedding') as tool_embed, \
         patch.object(strategy, '_cosine_similarity') as similarity:

        query_embed.return_value = np.random.rand(384)
        tool_embed.return_value = np.random.rand(384)
        similarity.return_value = 0.8

        # (limit, expected number of results); 0 means "no limit" -> all 3 tools
        for limit, expected_count in ((1, 1), (0, 3)):
            found = await strategy.search_tools(repo, "test", limit=limit)
            assert len(found) == expected_count
@pytest.mark.asyncio
async def test_context_manager_behavior(in_mem_embeddings_strategy):
    """Entering the context loads the model; leaving it shuts the executor down."""
    async with in_mem_embeddings_strategy as strategy:
        assert strategy._model_loaded is True

    # Executor should be shut down. Check this through the documented
    # contract (submit() raises RuntimeError after shutdown) rather than the
    # private ThreadPoolExecutor._shutdown flag.
    closed_executor = strategy._executor
    with pytest.raises(RuntimeError):
        closed_executor.submit(lambda: None)
# Add paths
# This file lives in <repo>/plugins/tool_search/in_mem_embeddings/tests/, so
# the plugin sources are one level above this directory and the core sources
# hang off the repository root, three levels above the plugin root. The depths
# match the path setup in the sibling test_in_mem_embeddings_search.py.
_tests_dir = Path(__file__).resolve().parent
_plugin_root = _tests_dir.parent
plugin_src = _plugin_root / "src"
core_src = _plugin_root.parent.parent.parent / "core" / "src"
sys.path.insert(0, str(plugin_src))
sys.path.insert(0, str(core_src))
@pytest.mark.asyncio
async def test_strategy_creation_through_core(register_plugin):
    """Build the strategy from a config dict via the core serializer and check its fields."""
    from utcp.interfaces.tool_search_strategy import ToolSearchStrategyConfigSerializer

    config = {
        "tool_search_strategy_type": "in_mem_embeddings",
        "model_name": "all-MiniLM-L6-v2",
        "similarity_threshold": 0.3,
    }
    created = ToolSearchStrategyConfigSerializer().validate_dict(config)

    # Every configured value must be reflected on the resulting strategy object.
    assert created.tool_search_strategy_type == "in_mem_embeddings"
    assert created.model_name == "all-MiniLM-L6-v2"
    assert created.similarity_threshold == 0.3
@pytest.mark.asyncio
async def test_search_limit_parameter(register_plugin, tool_repository):
    """Check that search never returns more tools than the requested limit."""
    from utcp.interfaces.tool_search_strategy import ToolSearchStrategyConfigSerializer

    config = {
        "tool_search_strategy_type": "in_mem_embeddings",
        "model_name": "all-MiniLM-L6-v2",
        "similarity_threshold": 0.1,  # Lower threshold to get more results
    }
    strategy = ToolSearchStrategyConfigSerializer().validate_dict(config)

    # Run the same query with limit=1 and then limit=2; each result list must
    # stay within its own limit.
    for limit in (1, 2):
        results = await strategy.search_tools(tool_repository, "test", limit=limit)
        assert len(results) <= limit, f"Should respect limit={limit}"
@pytest.mark.asyncio
async def test_performance():
    """Time searches over 100 generated tools and enforce latency budgets.

    Three measurements are taken against the same strategy instance:
    the first search (labelled "cold start"), a second search (labelled
    "warm cache"), and the average over five further queries.

    Exceptions and assertion failures are deliberately left to propagate so
    that pytest reports the original exception type and traceback, instead of
    being caught and re-raised as a generic ``assert False``.
    """
    from utcp_in_mem_embeddings.in_mem_embeddings_search import InMemEmbeddingsSearchStrategy
    from utcp.data.tool import Tool, JsonSchema
    from utcp.data.call_template import CallTemplate

    print("⚡ Testing Performance...")

    # Create strategy
    strategy = InMemEmbeddingsSearchStrategy(
        model_name="all-MiniLM-L6-v2",
        similarity_threshold=0.3,
        max_workers=2,
        cache_embeddings=True
    )

    # Create many tools
    print("1. Creating 100 test tools...")
    tools = [
        Tool(
            name=f"test_tool{i}",
            description=f"Test tool {i} for various purposes like cooking, coding, data analysis",
            inputs=JsonSchema(),
            outputs=JsonSchema(),
            tags=["test", f"category{i % 5}"],
            tool_call_template=CallTemplate(
                name=f"test_tool{i}",
                description=f"Test tool {i}",
                call_template_type="default"
            )
        )
        for i in range(100)
    ]

    class MockRepo:
        """Minimal repository stand-in exposing only ``get_tools``."""

        def __init__(self, tools):
            self.tools = tools

        async def get_tools(self):
            return self.tools

    repo = MockRepo(tools)

    # Test 1: First search (cold start)
    print("2. Testing cold start performance...")
    start_time = time.perf_counter()
    results1 = await strategy.search_tools(repo, "cooking tools", limit=10)
    cold_time = time.perf_counter() - start_time
    print(f"   ⏱️  Cold start: {cold_time:.3f}s, found {len(results1)} results")

    # Test 2: Second search (warm cache)
    print("3. Testing warm cache performance...")
    start_time = time.perf_counter()
    results2 = await strategy.search_tools(repo, "coding tools", limit=10)
    warm_time = time.perf_counter() - start_time
    print(f"   ⏱️  Warm cache: {warm_time:.3f}s, found {len(results2)} results")

    # Test 3: Multiple searches
    print("4. Testing multiple searches...")
    queries = ["cooking", "programming", "data analysis", "testing", "utilities"]
    start_time = time.perf_counter()

    for query in queries:
        await strategy.search_tools(repo, query, limit=5)

    total_time = time.perf_counter() - start_time
    avg_time = total_time / len(queries)
    print(f"   ⏱️  Average per search: {avg_time:.3f}s")

    # Performance assertions
    assert cold_time < 10.0, f"Cold start too slow: {cold_time}s"  # Allow more time for model loading
    assert warm_time < 1.0, f"Warm cache too slow: {warm_time}s"
    assert avg_time < 0.5, f"Average search too slow: {avg_time}s"

    print("\n🎉 Performance test passed!")
Testing strategy creation...") + strategy = InMemEmbeddingsSearchStrategy( + model_name="all-MiniLM-L6-v2", + similarity_threshold=0.3, + max_workers=2, + cache_embeddings=True + ) + print(f" ✅ Strategy created: {strategy.tool_search_strategy_type}") + + # Test 3: Test registration function + print("3. Testing registration...") + register() + print(" ✅ Registration function works") + + # Test 4: Test basic functionality + print("4. Testing basic functionality...") + + # Create mock tools + from utcp.data.tool import Tool, JsonSchema + from utcp.data.call_template import CallTemplate + + tools = [ + Tool( + name="cooking.spatula", + description="A kitchen utensil for flipping food", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["cooking", "kitchen"], + tool_call_template=CallTemplate( + name="cooking.spatula", + description="Spatula tool", + call_template_type="default" + ) + ), + Tool( + name="dev.code_review", + description="Review source code for quality", + inputs=JsonSchema(), + outputs=JsonSchema(), + tags=["programming", "development"], + tool_call_template=CallTemplate( + name="dev.code_review", + description="Code review tool", + call_template_type="default" + ) + ) + ] + + # Create mock repository + class MockRepo: + def __init__(self, tools): + self.tools = tools + + async def get_tools(self): + return self.tools + + repo = MockRepo(tools) + + # Test search + results = await strategy.search_tools(repo, "cooking utensils", limit=2) + print(f" ✅ Search completed, found {len(results)} results") + + if results: + print(f" 📋 Top result: {results[0].name}") + + print("\n🎉 All tests passed! Plugin is working correctly.") + + except Exception as e: + print(f"❌ Test failed: {e}") + import traceback + traceback.print_exc() + assert False, f"Plugin test failed: {e}"