From ebe7b1ccebb453a9edf2354c3cde47e156791514 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 19 Aug 2025 17:01:14 +0200 Subject: [PATCH] feat: Add stateless MCP operation support and improve documentation - Introduced stateless method callback and provider classes for Complete, Prompt, Resource, and Tool operations using . - Added and method callbacks and providers for all MCP operation types. - Updated README to document stateless support, new callback/provider classes, and usage examples. - Improved Spring integration to support stateless MCP operations. - Added comprehensive tests for stateless method callbacks and providers. - Fixed minor issues and improved parameter validation in method callbacks. - Updated dependency to MCP Java SDK 0.11.2. Signed-off-by: Christian Tzolov --- README.md | 243 +++++- .../spring/AsyncMcpAnnotationProvider.java | 58 ++ .../mcp/spring/SyncMcpAnnotationProvider.java | 58 ++ .../mcp/annotation/PromptAdaptor.java | 2 +- ...yncStatelessMcpCompleteMethodCallback.java | 193 +++++ ...yncStatelessMcpCompleteMethodCallback.java | 180 ++++ .../AbstractMcpPromptMethodCallback.java | 6 +- .../prompt/AsyncMcpPromptMethodCallback.java | 2 +- ...AsyncStatelessMcpPromptMethodCallback.java | 128 +++ .../prompt/SyncMcpPromptMethodCallback.java | 2 +- .../SyncStatelessMcpPromptMethodCallback.java | 117 +++ .../AbstractMcpResourceMethodCallback.java | 12 +- .../AsyncMcpResourceMethodCallback.java | 2 +- ...yncStatelessMcpResourceMethodCallback.java | 166 ++++ .../SyncMcpResourceMethodCallback.java | 2 +- ...yncStatelessMcpResourceMethodCallback.java | 137 ++++ .../AbstractAsyncMcpToolMethodCallback.java | 265 ++++++ .../AbstractSyncMcpToolMethodCallback.java | 178 ++++ .../tool/AsyncMcpToolMethodCallback.java | 200 +---- .../AsyncStatelessMcpToolMethodCallback.java | 80 ++ .../tool/SyncMcpToolMethodCallback.java | 115 +-- .../SyncStatelessMcpToolMethodCallback.java | 68 ++ .../AsyncStatelessMcpCompleteProvider.java | 109 +++ .../AsyncStatelessMcpPromptProvider.java | 109 +++ .../AsyncStatelessMcpResourceProvider.java | 130 +++ .../AsyncStatelessMcpToolProvider.java | 158 ++++ .../SyncStatelessMcpCompleteProvider.java | 105 +++ .../SyncStatelessMcpPromptProvider.java | 105 +++ .../SyncStatelessMcpResourceProvider.java | 126 +++ .../SyncStatelessMcpToolProvider.java | 147 ++++ ...atelessMcpCompleteMethodCallbackTests.java | 635 +++++++++++++++ ...atelessMcpCompleteMethodCallbackTests.java | 468 +++++++++++ ...StatelessMcpPromptMethodCallbackTests.java | 687 ++++++++++++++++ ...StatelessMcpPromptMethodCallbackTests.java | 478 +++++++++++ .../AsyncMcpResourceMethodCallbackTests.java | 1 - ...atelessMcpResourceMethodCallbackTests.java | 646 +++++++++++++++ ...atelessMcpResourceMethodCallbackTests.java | 640 +++++++++++++++ ...ncStatelessMcpToolMethodCallbackTests.java | 770 ++++++++++++++++++ ...ncStatelessMcpToolMethodCallbackTests.java | 538 ++++++++++++ ...syncStatelessMcpCompleteProviderTests.java | 468 +++++++++++ .../AsyncStatelessMcpPromptProviderTests.java | 566 +++++++++++++ ...syncStatelessMcpResourceProviderTests.java | 493 +++++++++++ ...SyncStatelessMcpCompleteProviderTests.java | 451 ++++++++++ .../SyncStatelessMcpPromptProviderTests.java | 556 +++++++++++++ ...SyncStatelessMcpResourceProviderTests.java | 451 ++++++++++ pom.xml | 2 +- 46 files changed, 10754 insertions(+), 299 deletions(-) create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProvider.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProvider.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProvider.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpToolProvider.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProvider.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProvider.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProvider.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpToolProvider.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallbackTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProviderTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProviderTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProviderTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProviderTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProviderTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProviderTests.java diff --git a/README.md b/README.md index 4ebe542..8905d34 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ The MCP Annotations project provides annotation-based method handling for [Model This project consists of two main modules: 1. **mcp-annotations** - Core annotations and method handling for MCP operations. Depends only on MCP Java SDK. -2. **spring-ai-mcp-annotations** - Spring AI integration for MCP annotations +2. **mcp-annotations-spring** - Spring AI integration for MCP annotations ## Overview @@ -46,8 +46,8 @@ To use the Spring integration module, add the following dependency: ```xml - corg.springaicommunity - spring-ai-mcp-annotations + org.springaicommunity + mcp-annotations-spring 0.2.0-SNAPSHOT ``` @@ -96,7 +96,7 @@ The core module provides a set of annotations and callback implementations for p Each operation type has both synchronous and asynchronous implementations, allowing for flexible integration with different application architectures. -### Spring Integration Module (spring-ai-mcp-annotations) +### Spring Integration Module (mcp-annotations-spring) The Spring integration module provides seamless integration with Spring AI and Spring Framework applications. It handles Spring-specific concerns such as AOP proxies and integrates with Spring AI's model abstractions. @@ -122,16 +122,22 @@ The modules provide callback implementations for each operation type: - `AbstractMcpCompleteMethodCallback` - Base class for complete method callbacks - `SyncMcpCompleteMethodCallback` - Synchronous implementation - `AsyncMcpCompleteMethodCallback` - Asynchronous implementation using Reactor's Mono +- `SyncStatelessMcpCompleteMethodCallback` - Synchronous stateless implementation using `McpTransportContext` +- `AsyncStatelessMcpCompleteMethodCallback` - Asynchronous stateless implementation using `McpTransportContext` #### Prompt - `AbstractMcpPromptMethodCallback` - Base class for prompt method callbacks - `SyncMcpPromptMethodCallback` - Synchronous implementation - `AsyncMcpPromptMethodCallback` - Asynchronous implementation using Reactor's Mono +- `SyncStatelessMcpPromptMethodCallback` - Synchronous stateless implementation using `McpTransportContext` +- `AsyncStatelessMcpPromptMethodCallback` - Asynchronous stateless implementation using `McpTransportContext` #### Resource - `AbstractMcpResourceMethodCallback` - Base class for resource method callbacks - `SyncMcpResourceMethodCallback` - Synchronous implementation - `AsyncMcpResourceMethodCallback` - Asynchronous implementation using Reactor's Mono +- `SyncStatelessMcpResourceMethodCallback` - Synchronous stateless implementation using `McpTransportContext` +- `AsyncStatelessMcpResourceMethodCallback` - Asynchronous stateless implementation using `McpTransportContext` #### Logging Consumer - `AbstractMcpLoggingConsumerMethodCallback` - Base class for logging consumer method callbacks @@ -139,8 +145,12 @@ The modules provide callback implementations for each operation type: - `AsyncMcpLoggingConsumerMethodCallback` - Asynchronous implementation using Reactor's Mono #### Tool -- `SyncMcpToolMethodCallback` - Synchronous implementation for tool method callbacks -- `AsyncMcpToolMethodCallback` - Asynchronous implementation using Reactor's Mono +- `AbstractSyncMcpToolMethodCallback` - Base class for synchronous tool method callbacks +- `AbstractAsyncMcpToolMethodCallback` - Base class for asynchronous tool method callbacks +- `SyncMcpToolMethodCallback` - Synchronous implementation for tool method callbacks with server exchange +- `AsyncMcpToolMethodCallback` - Asynchronous implementation using Reactor's Mono with server exchange +- `SyncStatelessMcpToolMethodCallback` - Synchronous stateless implementation for tool method callbacks +- `AsyncStatelessMcpToolMethodCallback` - Asynchronous stateless implementation using Reactor's Mono #### Sampling - `AbstractMcpSamplingMethodCallback` - Base class for sampling method callbacks @@ -156,6 +166,7 @@ The modules provide callback implementations for each operation type: The project includes provider classes that scan for annotated methods and create appropriate callbacks: +#### Stateful Providers (using McpSyncServerExchange/McpAsyncServerExchange) - `SyncMcpCompletionProvider` - Processes `@McpComplete` annotations for synchronous operations - `SyncMcpPromptProvider` - Processes `@McpPrompt` annotations for synchronous operations - `SyncMcpResourceProvider` - Processes `@McpResource` annotations for synchronous operations @@ -168,6 +179,16 @@ The project includes provider classes that scan for annotated methods and create - `SyncMcpElicitationProvider` - Processes `@McpElicitation` annotations for synchronous operations - `AsyncMcpElicitationProvider` - Processes `@McpElicitation` annotations for asynchronous operations +#### Stateless Providers (using McpTransportContext) +- `SyncStatelessMcpCompleteProvider` - Processes `@McpComplete` annotations for synchronous stateless operations +- `AsyncStatelessMcpCompleteProvider` - Processes `@McpComplete` annotations for asynchronous stateless operations +- `SyncStatelessMcpPromptProvider` - Processes `@McpPrompt` annotations for synchronous stateless operations +- `AsyncStatelessMcpPromptProvider` - Processes `@McpPrompt` annotations for asynchronous stateless operations +- `SyncStatelessMcpResourceProvider` - Processes `@McpResource` annotations for synchronous stateless operations +- `AsyncStatelessMcpResourceProvider` - Processes `@McpResource` annotations for asynchronous stateless operations +- `SyncStatelessMcpToolProvider` - Processes `@McpTool` annotations for synchronous stateless operations +- `AsyncStatelessMcpToolProvider` - Processes `@McpTool` annotations for asynchronous stateless operations + ### Spring Integration The Spring integration module provides: @@ -782,6 +803,188 @@ public class MyMcpClient { ``` +### Stateless Examples + +The library supports stateless implementations that use `McpTransportContext` instead of `McpSyncServerExchange` or `McpAsyncServerExchange`. This is useful for scenarios where you don't need the full server exchange context. + +#### Stateless Complete Example + +```java +public class StatelessAutocompleteProvider { + + private final Map> cityDatabase = new HashMap<>(); + + public StatelessAutocompleteProvider() { + // Initialize with sample data + cityDatabase.put("l", List.of("Lagos", "Lima", "Lisbon", "London", "Los Angeles")); + // Add more data... + } + + @McpComplete(prompt = "travel-planner") + public List completeCityName(McpTransportContext context, CompleteRequest.CompleteArgument argument) { + String prefix = argument.value().toLowerCase(); + String firstLetter = prefix.substring(0, 1); + List cities = cityDatabase.getOrDefault(firstLetter, List.of()); + + return cities.stream() + .filter(city -> city.toLowerCase().startsWith(prefix)) + .toList(); + } + + // Stateless method without context parameter + @McpComplete(prompt = "simple-complete") + public List simpleComplete(String value) { + return List.of("option1", "option2", "option3") + .stream() + .filter(option -> option.startsWith(value.toLowerCase())) + .toList(); + } +} +``` + +#### Stateless Prompt Example + +```java +public class StatelessPromptProvider { + + @McpPrompt(name = "simple-greeting", description = "Generate a simple greeting") + public GetPromptResult simpleGreeting( + @McpArg(name = "name", description = "The user's name", required = true) String name) { + + String message = "Hello, " + name + "! How can I help you today?"; + + return new GetPromptResult("Simple Greeting", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent(message)))); + } + + @McpPrompt(name = "contextual-greeting", description = "Generate a greeting with context") + public GetPromptResult contextualGreeting( + McpTransportContext context, + @McpArg(name = "name", description = "The user's name", required = true) String name) { + + // You can access transport context if needed + String message = "Hello, " + name + "! Welcome to our stateless MCP server."; + + return new GetPromptResult("Contextual Greeting", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent(message)))); + } +} +``` + +#### Stateless Resource Example + +```java +public class StatelessResourceProvider { + + private final Map resourceData = new HashMap<>(); + + public StatelessResourceProvider() { + resourceData.put("config", "server.port=8080\nserver.host=localhost"); + resourceData.put("readme", "# Welcome\nThis is a sample resource."); + } + + @McpResource(uri = "config://{key}", + name = "Configuration", + description = "Provides configuration data") + public String getConfig(String key) { + return resourceData.getOrDefault(key, "Configuration not found"); + } + + @McpResource(uri = "data://{id}", + name = "Data Resource", + description = "Provides data with transport context") + public ReadResourceResult getData(McpTransportContext context, String id) { + String data = resourceData.getOrDefault(id, "Data not found for ID: " + id); + + return new ReadResourceResult(List.of( + new TextResourceContents("data://" + id, "text/plain", data) + )); + } +} +``` + +#### Stateless Tool Example + +```java +public class StatelessCalculatorProvider { + + @McpTool(name = "add-stateless", description = "Add two numbers (stateless)") + public int addStateless( + @McpToolParam(description = "First number", required = true) int a, + @McpToolParam(description = "Second number", required = true) int b) { + return a + b; + } + + @McpTool(name = "multiply-with-context", description = "Multiply with transport context") + public double multiplyWithContext( + McpTransportContext context, + @McpToolParam(description = "First number", required = true) double x, + @McpToolParam(description = "Second number", required = true) double y) { + // Access transport context if needed + return x * y; + } + + // Async stateless tool + @McpTool(name = "async-divide", description = "Divide two numbers asynchronously") + public Mono asyncDivide( + @McpToolParam(description = "Dividend", required = true) double dividend, + @McpToolParam(description = "Divisor", required = true) double divisor) { + + return Mono.fromCallable(() -> { + if (divisor == 0) { + throw new IllegalArgumentException("Division by zero"); + } + return dividend / divisor; + }); + } +} +``` + +#### Using Stateless Providers + +```java +public class StatelessMcpServerFactory { + + public McpSyncServer createStatelessServer( + StatelessAutocompleteProvider completeProvider, + StatelessPromptProvider promptProvider, + StatelessResourceProvider resourceProvider, + StatelessCalculatorProvider toolProvider) { + + // Create stateless specifications + List completionSpecs = + new SyncStatelessMcpCompleteProvider(List.of(completeProvider)).getCompleteSpecifications(); + + List promptSpecs = + new SyncStatelessMcpPromptProvider(List.of(promptProvider)).getPromptSpecifications(); + + List resourceSpecs = + new SyncStatelessMcpResourceProvider(List.of(resourceProvider)).getResourceSpecifications(); + + List toolSpecs = + new SyncStatelessMcpToolProvider(List.of(toolProvider)).getToolSpecifications(); + + // Create a stateless server + McpSyncServer syncServer = McpServer.sync(transportProvider) + .serverInfo("stateless-server", "1.0.0") + .capabilities(ServerCapabilities.builder() + .tools(true) + .resources(true) + .prompts(true) + .completions() + .logging() + .build()) + .statelessTools(toolSpecs) + .statelessResources(resourceSpecs) + .statelessPrompts(promptSpecs) + .statelessCompletions(completionSpecs) + .build(); + + return syncServer; + } +} +``` + ### Spring Integration Example ```java @@ -847,6 +1050,26 @@ public class McpConfig { List asyncElicitationHandlers) { return SpringAiMcpAnnotationProvider.createAsyncElicitationHandler(asyncElicitationHandlers); } + + // Stateless Spring Integration Examples + + @Bean + public List syncStatelessToolSpecifications( + List statelessToolProviders) { + return SpringAiMcpAnnotationProvider.createSyncStatelessToolSpecifications(statelessToolProviders); + } + + @Bean + public List syncStatelessPromptSpecifications( + List statelessPromptProviders) { + return SpringAiMcpAnnotationProvider.createSyncStatelessPromptSpecifications(statelessPromptProviders); + } + + @Bean + public List syncStatelessResourceSpecifications( + List statelessResourceProviders) { + return SpringAiMcpAnnotationProvider.createSyncStatelessResourceSpecifications(statelessResourceProviders); + } } ``` @@ -854,21 +1077,23 @@ public class McpConfig { - **Annotation-based method handling** - Simplifies the creation and registration of MCP methods - **Support for both synchronous and asynchronous operations** - Flexible integration with different application architectures +- **Stateful and stateless implementations** - Choose between full server exchange context (`McpSyncServerExchange`/`McpAsyncServerExchange`) or lightweight transport context (`McpTransportContext`) for all MCP operations +- **Comprehensive stateless support** - All MCP operations (Complete, Prompt, Resource, Tool) support stateless implementations for scenarios where full server context is not needed - **Builder pattern for callback creation** - Clean and fluent API for creating method callbacks - **Comprehensive validation** - Ensures method signatures are compatible with MCP operations - **URI template support** - Powerful URI template handling for resource and completion operations - **Tool support with automatic JSON schema generation** - Create MCP tools with automatic input/output schema generation from method signatures - **Logging consumer support** - Handle logging message notifications from MCP servers - **Sampling support** - Handle sampling requests from MCP servers -- **Spring integration** - Seamless integration with Spring Framework and Spring AI +- **Spring integration** - Seamless integration with Spring Framework and Spring AI, including support for both stateful and stateless operations - **AOP proxy support** - Proper handling of Spring AOP proxies when processing annotations ## Requirements - Java 17 or higher - Reactor Core (for async operations) -- MCP Java SDK 0.10.0 or higher -- Spring Framework and Spring AI (for spring-ai-mcp-annotations module) +- MCP Java SDK 0.11.2 or higher +- Spring Framework and Spring AI (for mcp-annotations-spring module) ## Building from Source diff --git a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java index 67dc3ce..f164139 100644 --- a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java +++ b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java @@ -23,8 +23,12 @@ import org.springaicommunity.mcp.provider.AsyncMcpLoggingConsumerProvider; import org.springaicommunity.mcp.provider.AsyncMcpSamplingProvider; import org.springaicommunity.mcp.provider.AsyncMcpToolProvider; +import org.springaicommunity.mcp.provider.AsyncStatelessMcpPromptProvider; +import org.springaicommunity.mcp.provider.AsyncStatelessMcpResourceProvider; +import org.springaicommunity.mcp.provider.AsyncStatelessMcpToolProvider; import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; @@ -89,6 +93,45 @@ protected Method[] doGetClassMethods(Object bean) { } + private static class SpringAiAsyncStatelessMcpToolProvider extends AsyncStatelessMcpToolProvider { + + public SpringAiAsyncStatelessMcpToolProvider(List toolObjects) { + super(toolObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + + private static class SpringAiAsyncStatelessPromptProvider extends AsyncStatelessMcpPromptProvider { + + public SpringAiAsyncStatelessPromptProvider(List promptObjects) { + super(promptObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + + private static class SpringAiAsyncStatelessResourceProvider extends AsyncStatelessMcpResourceProvider { + + public SpringAiAsyncStatelessResourceProvider(List resourceObjects) { + super(resourceObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + public static List>> createAsyncLoggingConsumers( List loggingObjects) { return new SpringAiAsyncMcpLoggingConsumerProvider(loggingObjects).getLoggingConsumers(); @@ -108,4 +151,19 @@ public static List createAsyncToolSpecifications(List createAsyncStatelessToolSpecifications( + List toolObjects) { + return new SpringAiAsyncStatelessMcpToolProvider(toolObjects).getToolSpecifications(); + } + + public static List createAsyncStatelessPromptSpecifications( + List promptObjects) { + return new SpringAiAsyncStatelessPromptProvider(promptObjects).getPromptSpecifications(); + } + + public static List createAsyncStatelessResourceSpecifications( + List resourceObjects) { + return new SpringAiAsyncStatelessResourceProvider(resourceObjects).getResourceSpecifications(); + } + } diff --git a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java index 7ac4738..3bbe604 100644 --- a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java +++ b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java @@ -27,11 +27,15 @@ import org.springaicommunity.mcp.provider.SyncMcpResourceProvider; import org.springaicommunity.mcp.provider.SyncMcpSamplingProvider; import org.springaicommunity.mcp.provider.SyncMcpToolProvider; +import org.springaicommunity.mcp.provider.SyncStatelessMcpPromptProvider; +import org.springaicommunity.mcp.provider.SyncStatelessMcpResourceProvider; +import org.springaicommunity.mcp.provider.SyncStatelessMcpToolProvider; import io.modelcontextprotocol.server.McpServerFeatures.SyncCompletionSpecification; import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; import io.modelcontextprotocol.server.McpServerFeatures.SyncResourceSpecification; import io.modelcontextprotocol.server.McpServerFeatures.SyncToolSpecification; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; @@ -69,6 +73,19 @@ protected Method[] doGetClassMethods(Object bean) { } + private static class SpringAiSyncStatelessToolProvider extends SyncStatelessMcpToolProvider { + + public SpringAiSyncStatelessToolProvider(List toolObjects) { + super(toolObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + private static class SpringAiSyncMcpPromptProvider extends SyncMcpPromptProvider { public SpringAiSyncMcpPromptProvider(List promptObjects) { @@ -82,6 +99,19 @@ protected Method[] doGetClassMethods(Object bean) { }; + private static class SpringAiSyncStatelessPromptProvider extends SyncStatelessMcpPromptProvider { + + public SpringAiSyncStatelessPromptProvider(List promptObjects) { + super(promptObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + private static class SpringAiSyncMcpResourceProvider extends SyncMcpResourceProvider { public SpringAiSyncMcpResourceProvider(List resourceObjects) { @@ -95,6 +125,19 @@ protected Method[] doGetClassMethods(Object bean) { } + private static class SpringAiSyncStatelessResourceProvider extends SyncStatelessMcpResourceProvider { + + public SpringAiSyncStatelessResourceProvider(List resourceObjects) { + super(resourceObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + private static class SpringAiSyncMcpLoggingConsumerProvider extends SyncMcpLoggingConsumerProvider { public SpringAiSyncMcpLoggingConsumerProvider(List loggingObjects) { @@ -138,6 +181,11 @@ public static List createSyncToolSpecifications(List createSyncStatelessToolSpecifications( + List toolObjects) { + return new SpringAiSyncStatelessToolProvider(toolObjects).getToolSpecifications(); + } + public static List createSyncCompleteSpecifications(List completeObjects) { return new SpringAiSyncMcpCompletionProvider(completeObjects).getCompleteSpecifications(); } @@ -146,10 +194,20 @@ public static List createSyncPromptSpecifications(List< return new SpringAiSyncMcpPromptProvider(promptObjects).getPromptSpecifications(); } + public static List createSyncStatelessPromptSpecifications( + List promptObjects) { + return new SpringAiSyncStatelessPromptProvider(promptObjects).getPromptSpecifications(); + } + public static List createSyncResourceSpecifications(List resourceObjects) { return new SpringAiSyncMcpResourceProvider(resourceObjects).getResourceSpecifications(); } + public static List createSyncStatelessResourceSpecifications( + List resourceObjects) { + return new SpringAiSyncStatelessResourceProvider(resourceObjects).getResourceSpecifications(); + } + public static List> createSyncLoggingConsumers(List loggingObjects) { return new SpringAiSyncMcpLoggingConsumerProvider(loggingObjects).getLoggingConsumers(); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/PromptAdaptor.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/PromptAdaptor.java index 9fef332..acd4dcf 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/PromptAdaptor.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/PromptAdaptor.java @@ -44,7 +44,7 @@ public static McpSchema.Prompt asPrompt(McpPrompt mcpPrompt, Method method) { private static String getName(McpPrompt promptAnnotation, Method method) { Assert.notNull(method, "method cannot be null"); - if (promptAnnotation == null || (promptAnnotation.name() == null)) { + if (promptAnnotation == null || (promptAnnotation.name() == null) || promptAnnotation.name().isEmpty()) { return method.getName(); } return promptAnnotation.name(); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallback.java new file mode 100644 index 0000000..be8c320 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallback.java @@ -0,0 +1,193 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.complete; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpComplete; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import reactor.core.publisher.Mono; + +/** + * Class for creating BiFunction callbacks around complete methods with asynchronous + * processing for stateless contexts. + * + * This class provides a way to convert methods annotated with {@link McpComplete} into + * callback functions that can be used to handle completion requests asynchronously in + * stateless environments. It supports various method signatures and return types, and + * handles both prompt and URI template completions. + * + * @author Christian Tzolov + */ +public final class AsyncStatelessMcpCompleteMethodCallback extends AbstractMcpCompleteMethodCallback + implements BiFunction> { + + private AsyncStatelessMcpCompleteMethodCallback(Builder builder) { + super(builder.method, builder.bean, builder.prompt, builder.uri, builder.uriTemplateManagerFactory); + this.validateMethod(this.method); + } + + /** + * Apply the callback to the given context and request. + *

+ * This method builds the arguments for the method call, invokes the method, and + * converts the result to a CompleteResult. + * @param context The transport context, may be null if the method doesn't require it + * @param request The complete request, must not be null + * @return A Mono that emits the complete result + * @throws McpCompleteMethodException if there is an error invoking the complete + * method + * @throws IllegalArgumentException if the request is null + */ + @Override + public Mono apply(McpTransportContext context, CompleteRequest request) { + if (request == null) { + return Mono.error(new IllegalArgumentException("Request must not be null")); + } + + return Mono.defer(() -> { + try { + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, context, request); + + // Invoke the method + this.method.setAccessible(true); + Object result = this.method.invoke(this.bean, args); + + // Handle the result based on its type + if (result instanceof Mono) { + // If the result is already a Mono, map it to a CompleteResult + return ((Mono) result).map(r -> convertToCompleteResult(r)); + } + else { + // Otherwise, convert the result to a CompleteResult and wrap in a + // Mono + return Mono.just(convertToCompleteResult(result)); + } + } + catch (Exception e) { + return Mono.error( + new McpCompleteMethodException("Error invoking complete method: " + this.method.getName(), e)); + } + }); + } + + /** + * Converts a result object to a CompleteResult. + * @param result The result object + * @return The CompleteResult + */ + private CompleteResult convertToCompleteResult(Object result) { + if (result == null) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + if (result instanceof CompleteResult) { + return (CompleteResult) result; + } + + if (result instanceof CompleteCompletion) { + return new CompleteResult((CompleteCompletion) result); + } + + if (result instanceof List) { + List list = (List) result; + List values = new ArrayList<>(); + + for (Object item : list) { + if (item instanceof String) { + values.add((String) item); + } + else { + throw new IllegalArgumentException("List items must be of type String"); + } + } + + return new CompleteResult(new CompleteCompletion(values, values.size(), false)); + } + + if (result instanceof String) { + return new CompleteResult(new CompleteCompletion(List.of((String) result), 1, false)); + } + + throw new IllegalArgumentException("Unsupported return type: " + result.getClass().getName()); + } + + /** + * Builder for creating AsyncStatelessMcpCompleteMethodCallback instances. + *

+ * This builder provides a fluent API for constructing + * AsyncStatelessMcpCompleteMethodCallback instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Constructor for Builder. + */ + public Builder() { + this.uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + } + + /** + * Build the callback. + * @return A new AsyncStatelessMcpCompleteMethodCallback instance + */ + @Override + public AsyncStatelessMcpCompleteMethodCallback build() { + validate(); + return new AsyncStatelessMcpCompleteMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Validates that the method return type is compatible with the complete callback. + * @param method The method to validate + * @throws IllegalArgumentException if the return type is not compatible + */ + @Override + protected void validateReturnType(Method method) { + Class returnType = method.getReturnType(); + + boolean validReturnType = CompleteResult.class.isAssignableFrom(returnType) + || CompleteCompletion.class.isAssignableFrom(returnType) || List.class.isAssignableFrom(returnType) + || String.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType); + + if (!validReturnType) { + throw new IllegalArgumentException( + "Method must return either CompleteResult, CompleteCompletion, List, " + + "String, or Mono: " + method.getName() + " in " + method.getDeclaringClass().getName() + + " returns " + returnType.getName()); + } + } + + /** + * Checks if a parameter type is compatible with the exchange type. + * @param paramType The parameter type to check + * @return true if the parameter type is compatible with the exchange type, false + * otherwise + */ + @Override + protected boolean isExchangeType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallback.java new file mode 100644 index 0000000..4669c9f --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallback.java @@ -0,0 +1,180 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.complete; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpComplete; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; + +/** + * Class for creating BiFunction callbacks around complete methods for stateless contexts. + * + * This class provides a way to convert methods annotated with {@link McpComplete} into + * callback functions that can be used to handle completion requests in stateless + * environments. It supports various method signatures and return types, and handles both + * prompt and URI template completions. + * + * @author Christian Tzolov + */ +public final class SyncStatelessMcpCompleteMethodCallback extends AbstractMcpCompleteMethodCallback + implements BiFunction { + + private SyncStatelessMcpCompleteMethodCallback(Builder builder) { + super(builder.method, builder.bean, builder.prompt, builder.uri, builder.uriTemplateManagerFactory); + this.validateMethod(this.method); + } + + /** + * Apply the callback to the given context and request. + *

+ * This method builds the arguments for the method call, invokes the method, and + * converts the result to a CompleteResult. + * @param context The transport context, may be null if the method doesn't require it + * @param request The complete request, must not be null + * @return The complete result + * @throws McpCompleteMethodException if there is an error invoking the complete + * method + * @throws IllegalArgumentException if the request is null + */ + @Override + public CompleteResult apply(McpTransportContext context, CompleteRequest request) { + if (request == null) { + throw new IllegalArgumentException("Request must not be null"); + } + + try { + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, context, request); + + // Invoke the method + this.method.setAccessible(true); + Object result = this.method.invoke(this.bean, args); + + // Convert the result to a CompleteResult + return convertToCompleteResult(result); + } + catch (Exception e) { + throw new McpCompleteMethodException("Error invoking complete method: " + this.method.getName(), e); + } + } + + /** + * Converts the method result to a CompleteResult. + * @param result The method result + * @return The CompleteResult + */ + private CompleteResult convertToCompleteResult(Object result) { + if (result == null) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + if (result instanceof CompleteResult) { + return (CompleteResult) result; + } + + if (result instanceof CompleteCompletion) { + return new CompleteResult((CompleteCompletion) result); + } + + if (result instanceof List) { + List list = (List) result; + List values = new ArrayList<>(); + + for (Object item : list) { + if (item instanceof String) { + values.add((String) item); + } + else { + throw new IllegalArgumentException("List items must be of type String"); + } + } + + return new CompleteResult(new CompleteCompletion(values, values.size(), false)); + } + + if (result instanceof String) { + return new CompleteResult(new CompleteCompletion(List.of((String) result), 1, false)); + } + + throw new IllegalArgumentException("Unsupported return type: " + result.getClass().getName()); + } + + /** + * Builder for creating SyncStatelessMcpCompleteMethodCallback instances. + *

+ * This builder provides a fluent API for constructing + * SyncStatelessMcpCompleteMethodCallback instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Constructor for Builder. + */ + public Builder() { + this.uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + } + + /** + * Build the callback. + * @return A new SyncStatelessMcpCompleteMethodCallback instance + */ + @Override + public SyncStatelessMcpCompleteMethodCallback build() { + validate(); + return new SyncStatelessMcpCompleteMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Validates that the method return type is compatible with the complete callback. + * @param method The method to validate + * @throws IllegalArgumentException if the return type is not compatible + */ + @Override + protected void validateReturnType(Method method) { + Class returnType = method.getReturnType(); + + boolean validReturnType = CompleteResult.class.isAssignableFrom(returnType) + || CompleteCompletion.class.isAssignableFrom(returnType) || List.class.isAssignableFrom(returnType) + || String.class.isAssignableFrom(returnType); + + if (!validReturnType) { + throw new IllegalArgumentException( + "Method must return either CompleteResult, CompleteCompletion, List, " + "or String: " + + method.getName() + " in " + method.getDeclaringClass().getName() + " returns " + + returnType.getName()); + } + } + + /** + * Checks if a parameter type is compatible with the exchange type. + * @param paramType The parameter type to check + * @return true if the parameter type is compatible with the exchange type, false + * otherwise + */ + @Override + protected boolean isExchangeType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java index ad3bdee..7f9625c 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AbstractMcpPromptMethodCallback.java @@ -73,7 +73,7 @@ protected void validateMethod(Method method) { * @return true if the parameter type is compatible with the exchange type, false * otherwise */ - protected abstract boolean isExchangeType(Class paramType); + protected abstract boolean isExchangeOrContextType(Class paramType); /** * Validates method parameters. @@ -91,7 +91,7 @@ protected void validateParameters(Method method) { for (java.lang.reflect.Parameter param : parameters) { Class paramType = param.getType(); - if (isExchangeType(paramType)) { + if (isExchangeOrContextType(paramType)) { if (hasExchangeParam) { throw new IllegalArgumentException("Method cannot have more than one exchange parameter: " + method.getName() + " in " + method.getDeclaringClass().getName()); @@ -134,7 +134,7 @@ protected Object[] buildArgs(Method method, Object exchange, GetPromptRequest re java.lang.reflect.Parameter param = parameters[i]; Class paramType = param.getType(); - if (isExchangeType(paramType)) { + if (isExchangeOrContextType(paramType)) { args[i] = exchange; } else if (GetPromptRequest.class.isAssignableFrom(paramType)) { diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncMcpPromptMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncMcpPromptMethodCallback.java index df92962..198c981 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncMcpPromptMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncMcpPromptMethodCallback.java @@ -76,7 +76,7 @@ public Mono apply(McpAsyncServerExchange exchange, GetPromptReq } @Override - protected boolean isExchangeType(Class paramType) { + protected boolean isExchangeOrContextType(Class paramType) { return McpAsyncServerExchange.class.isAssignableFrom(paramType); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallback.java new file mode 100644 index 0000000..7f41c9b --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallback.java @@ -0,0 +1,128 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.prompt; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpPrompt; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import reactor.core.publisher.Mono; + +/** + * Class for creating BiFunction callbacks around prompt methods with asynchronous + * processing for stateless contexts. + * + * This class provides a way to convert methods annotated with {@link McpPrompt} into + * callback functions that can be used to handle prompt requests asynchronously in + * stateless environments. It supports various method signatures and return types. + * + * @author Christian Tzolov + */ +public final class AsyncStatelessMcpPromptMethodCallback extends AbstractMcpPromptMethodCallback + implements BiFunction> { + + private AsyncStatelessMcpPromptMethodCallback(Builder builder) { + super(builder.method, builder.bean, builder.prompt); + } + + /** + * Apply the callback to the given context and request. + *

+ * This method builds the arguments for the method call, invokes the method, and + * converts the result to a GetPromptResult. + * @param context The transport context, may be null if the method doesn't require it + * @param request The prompt request, must not be null + * @return A Mono that emits the prompt result + * @throws McpPromptMethodException if there is an error invoking the prompt method + * @throws IllegalArgumentException if the request is null + */ + @Override + public Mono apply(McpTransportContext context, GetPromptRequest request) { + if (request == null) { + return Mono.error(new IllegalArgumentException("Request must not be null")); + } + + return Mono.defer(() -> { + try { + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, context, request); + + // Invoke the method + this.method.setAccessible(true); + Object result = this.method.invoke(this.bean, args); + + // Handle the result based on its type + if (result instanceof Mono) { + // If the result is already a Mono, map it to a GetPromptResult + return ((Mono) result).map(r -> convertToGetPromptResult(r)); + } + else { + // Otherwise, convert the result to a GetPromptResult and wrap in a + // Mono + return Mono.just(convertToGetPromptResult(result)); + } + } + catch (Exception e) { + return Mono + .error(new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e)); + } + }); + } + + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + + @Override + protected void validateReturnType(Method method) { + Class returnType = method.getReturnType(); + + boolean validReturnType = GetPromptResult.class.isAssignableFrom(returnType) + || List.class.isAssignableFrom(returnType) || PromptMessage.class.isAssignableFrom(returnType) + || String.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType); + + if (!validReturnType) { + throw new IllegalArgumentException("Method must return either GetPromptResult, List, " + + "List, PromptMessage, String, or Mono: " + method.getName() + " in " + + method.getDeclaringClass().getName() + " returns " + returnType.getName()); + } + } + + /** + * Builder for creating AsyncStatelessMcpPromptMethodCallback instances. + *

+ * This builder provides a fluent API for constructing + * AsyncStatelessMcpPromptMethodCallback instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Build the callback. + * @return A new AsyncStatelessMcpPromptMethodCallback instance + */ + @Override + public AsyncStatelessMcpPromptMethodCallback build() { + validate(); + return new AsyncStatelessMcpPromptMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallback.java index 138bddc..4f4c6a4 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncMcpPromptMethodCallback.java @@ -67,7 +67,7 @@ public GetPromptResult apply(McpSyncServerExchange exchange, GetPromptRequest re } @Override - protected boolean isExchangeType(Class paramType) { + protected boolean isExchangeOrContextType(Class paramType) { return McpSyncServerExchange.class.isAssignableFrom(paramType); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallback.java new file mode 100644 index 0000000..2ac2aae --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallback.java @@ -0,0 +1,117 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.prompt; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpPrompt; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; + +/** + * Class for creating BiFunction callbacks around prompt methods for stateless contexts. + * + * This class provides a way to convert methods annotated with {@link McpPrompt} into + * callback functions that can be used to handle prompt requests in stateless + * environments. It supports various method signatures and return types. + * + * @author Christian Tzolov + */ +public final class SyncStatelessMcpPromptMethodCallback extends AbstractMcpPromptMethodCallback + implements BiFunction { + + private SyncStatelessMcpPromptMethodCallback(Builder builder) { + super(builder.method, builder.bean, builder.prompt); + } + + /** + * Apply the callback to the given context and request. + *

+ * This method builds the arguments for the method call, invokes the method, and + * converts the result to a GetPromptResult. + * @param context The transport context, may be null if the method doesn't require it + * @param request The prompt request, must not be null + * @return The prompt result + * @throws McpPromptMethodException if there is an error invoking the prompt method + * @throws IllegalArgumentException if the request is null + */ + @Override + public GetPromptResult apply(McpTransportContext context, GetPromptRequest request) { + if (request == null) { + throw new IllegalArgumentException("Request must not be null"); + } + + try { + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, context, request); + + // Invoke the method + this.method.setAccessible(true); + Object result = this.method.invoke(this.bean, args); + + // Convert the result to a GetPromptResult + GetPromptResult promptResult = this.convertToGetPromptResult(result); + + return promptResult; + } + catch (Exception e) { + throw new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e); + } + } + + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + + @Override + protected void validateReturnType(Method method) { + Class returnType = method.getReturnType(); + + boolean validReturnType = GetPromptResult.class.isAssignableFrom(returnType) + || List.class.isAssignableFrom(returnType) || PromptMessage.class.isAssignableFrom(returnType) + || String.class.isAssignableFrom(returnType); + + if (!validReturnType) { + throw new IllegalArgumentException("Method must return either GetPromptResult, List, " + + "List, PromptMessage, or String: " + method.getName() + " in " + + method.getDeclaringClass().getName() + " returns " + returnType.getName()); + } + } + + /** + * Builder for creating SyncStatelessMcpPromptMethodCallback instances. + *

+ * This builder provides a fluent API for constructing + * SyncStatelessMcpPromptMethodCallback instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Build the callback. + * @return A new SyncStatelessMcpPromptMethodCallback instance + */ + @Override + public SyncStatelessMcpPromptMethodCallback build() { + validate(); + return new SyncStatelessMcpPromptMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java index 1e16c8d..7d3ec89 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AbstractMcpResourceMethodCallback.java @@ -158,7 +158,7 @@ protected void validateParametersWithoutUriVariables(Method method) { for (Parameter param : parameters) { Class paramType = param.getType(); - if (isExchangeType(paramType)) { + if (isExchangeOrContextType(paramType)) { if (hasExchangeParam) { throw new IllegalArgumentException("Method cannot have more than one exchange parameter: " + method.getName() + " in " + method.getDeclaringClass().getName()); @@ -205,7 +205,7 @@ protected void validateParametersWithUriVariables(Method method) { for (Parameter param : parameters) { Class paramType = param.getType(); - if (isExchangeType(paramType)) { + if (isExchangeOrContextType(paramType)) { exchangeParamCount++; } else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { @@ -240,7 +240,7 @@ else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { // Check that all non-special parameters are String type (for URI variables) for (Parameter param : parameters) { Class paramType = param.getType(); - if (!isExchangeType(paramType) && !ReadResourceRequest.class.isAssignableFrom(paramType) + if (!isExchangeOrContextType(paramType) && !ReadResourceRequest.class.isAssignableFrom(paramType) && !String.class.isAssignableFrom(paramType)) { throw new IllegalArgumentException("URI variable parameters must be of type String: " + method.getName() + " in " + method.getDeclaringClass().getName() + ", parameter of type " + paramType.getName() @@ -293,7 +293,7 @@ protected void buildArgsWithUriVariables(Parameter[] parameters, Object[] args, // First pass: assign special parameters (exchange and request) for (int i = 0; i < parameters.length; i++) { Class paramType = parameters[i].getType(); - if (isExchangeType(paramType)) { + if (isExchangeOrContextType(paramType)) { args[i] = exchange; } else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { @@ -339,7 +339,7 @@ protected void buildArgsWithoutUriVariables(Parameter[] parameters, Object[] arg Parameter param = parameters[i]; Class paramType = param.getType(); - if (isExchangeType(paramType)) { + if (isExchangeOrContextType(paramType)) { args[i] = exchange; } else if (ReadResourceRequest.class.isAssignableFrom(paramType)) { @@ -361,7 +361,7 @@ else if (String.class.isAssignableFrom(paramType)) { * @return true if the parameter type is compatible with the exchange type, false * otherwise */ - protected abstract boolean isExchangeType(Class paramType); + protected abstract boolean isExchangeOrContextType(Class paramType); /** * Returns the content type of the resource. diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallback.java index 5790e9e..aa4ae06 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallback.java @@ -159,7 +159,7 @@ protected void validateReturnType(Method method) { * otherwise */ @Override - protected boolean isExchangeType(Class paramType) { + protected boolean isExchangeOrContextType(Class paramType) { return McpAsyncServerExchange.class.isAssignableFrom(paramType); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallback.java new file mode 100644 index 0000000..eff3594 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallback.java @@ -0,0 +1,166 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ +package org.springaicommunity.mcp.method.resource; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpResource; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; +import reactor.core.publisher.Mono; + +/** + * Class for creating BiFunction callbacks around resource methods with asynchronous + * processing for stateless contexts. + * + * This class provides a way to convert methods annotated with {@link McpResource} into + * callback functions that can be used to handle resource requests asynchronously in + * stateless environments. It supports various method signatures and return types, and + * handles URI template variables. + * + * @author Christian Tzolov + */ +public final class AsyncStatelessMcpResourceMethodCallback extends AbstractMcpResourceMethodCallback + implements BiFunction> { + + private AsyncStatelessMcpResourceMethodCallback(Builder builder) { + super(builder.method, builder.bean, builder.uri, builder.name, builder.description, builder.mimeType, + builder.resultConverter, builder.uriTemplateManagerFactory, builder.contentType); + this.validateMethod(this.method); + } + + /** + * Apply the callback to the given context and request. + *

+ * This method extracts URI variable values from the request URI, builds the arguments + * for the method call, invokes the method, and converts the result to a + * ReadResourceResult. + * @param context The transport context, may be null if the method doesn't require it + * @param request The resource request, must not be null + * @return A Mono that emits the resource result + * @throws McpResourceMethodException if there is an error invoking the resource + * method + * @throws IllegalArgumentException if the request is null or if URI variable + * extraction fails + */ + @Override + public Mono apply(McpTransportContext context, ReadResourceRequest request) { + if (request == null) { + return Mono.error(new IllegalArgumentException("Request must not be null")); + } + + return Mono.defer(() -> { + try { + // Extract URI variable values from the request URI + Map uriVariableValues = this.uriTemplateManager.extractVariableValues(request.uri()); + + // Verify all URI variables were extracted if URI variables are expected + if (!this.uriVariables.isEmpty() && uriVariableValues.size() != this.uriVariables.size()) { + return Mono + .error(new IllegalArgumentException("Failed to extract all URI variables from request URI: " + + request.uri() + ". Expected variables: " + this.uriVariables + ", but found: " + + uriVariableValues.keySet())); + } + + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, context, request, uriVariableValues); + + // Invoke the method + this.method.setAccessible(true); + Object result = this.method.invoke(this.bean, args); + + // Handle the result based on its type + if (result instanceof Mono) { + // If the result is already a Mono, use it + return ((Mono) result).map(r -> this.resultConverter.convertToReadResourceResult(r, + request.uri(), this.mimeType, this.contentType)); + } + else { + // Otherwise, convert the result to a ReadResourceResult and wrap in a + // Mono + return Mono.just(this.resultConverter.convertToReadResourceResult(result, request.uri(), + this.mimeType, this.contentType)); + } + } + catch (Exception e) { + return Mono.error( + new McpResourceMethodException("Error invoking resource method: " + this.method.getName(), e)); + } + }); + } + + /** + * Builder for creating AsyncStatelessMcpResourceMethodCallback instances. + *

+ * This builder provides a fluent API for constructing + * AsyncStatelessMcpResourceMethodCallback instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Constructor for Builder. + */ + public Builder() { + this.resultConverter = new DefaultMcpReadResourceResultConverter(); + } + + /** + * Build the callback. + * @return A new AsyncStatelessMcpResourceMethodCallback instance + */ + @Override + public AsyncStatelessMcpResourceMethodCallback build() { + validate(); + return new AsyncStatelessMcpResourceMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Validates that the method return type is compatible with the resource callback. + * @param method The method to validate + * @throws IllegalArgumentException if the return type is not compatible + */ + @Override + protected void validateReturnType(Method method) { + Class returnType = method.getReturnType(); + + boolean validReturnType = ReadResourceResult.class.isAssignableFrom(returnType) + || List.class.isAssignableFrom(returnType) || ResourceContents.class.isAssignableFrom(returnType) + || String.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType); + + if (!validReturnType) { + throw new IllegalArgumentException( + "Method must return either ReadResourceResult, List, List, " + + "ResourceContents, String, or Mono: " + method.getName() + " in " + + method.getDeclaringClass().getName() + " returns " + returnType.getName()); + } + } + + /** + * Checks if a parameter type is compatible with the exchange type. + * @param paramType The parameter type to check + * @return true if the parameter type is compatible with the exchange type, false + * otherwise + */ + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java index 37ab4bd..f0b2d25 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncMcpResourceMethodCallback.java @@ -129,7 +129,7 @@ protected void validateReturnType(Method method) { } @Override - protected boolean isExchangeType(Class paramType) { + protected boolean isExchangeOrContextType(Class paramType) { return McpSyncServerExchange.class.isAssignableFrom(paramType); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallback.java new file mode 100644 index 0000000..46c7ea6 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallback.java @@ -0,0 +1,137 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.resource; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpResource; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; + +/** + * Class for creating BiFunction callbacks around resource methods for stateless contexts. + * + * This class provides a way to convert methods annotated with {@link McpResource} into + * callback functions that can be used to handle resource requests in stateless + * environments. It supports various method signatures and return types, and handles URI + * template variables. + * + * @author Christian Tzolov + */ +public final class SyncStatelessMcpResourceMethodCallback extends AbstractMcpResourceMethodCallback + implements BiFunction { + + private SyncStatelessMcpResourceMethodCallback(Builder builder) { + super(builder.method, builder.bean, builder.uri, builder.name, builder.description, builder.mimeType, + builder.resultConverter, builder.uriTemplateManagerFactory, builder.contentType); + this.validateMethod(this.method); + } + + /** + * Apply the callback to the given context and request. + *

+ * This method extracts URI variable values from the request URI, builds the arguments + * for the method call, invokes the method, and converts the result to a + * ReadResourceResult. + * @param context The transport context, may be null if the method doesn't require it + * @param request The resource request, must not be null + * @return The resource result + * @throws McpResourceMethodException if there is an error invoking the resource + * method + * @throws IllegalArgumentException if the request is null or if URI variable + * extraction fails + */ + @Override + public ReadResourceResult apply(McpTransportContext context, ReadResourceRequest request) { + if (request == null) { + throw new IllegalArgumentException("Request must not be null"); + } + + try { + // Extract URI variable values from the request URI + Map uriVariableValues = this.uriTemplateManager.extractVariableValues(request.uri()); + + // Verify all URI variables were extracted if URI variables are expected + if (!this.uriVariables.isEmpty() && uriVariableValues.size() != this.uriVariables.size()) { + throw new IllegalArgumentException("Failed to extract all URI variables from request URI: " + + request.uri() + ". Expected variables: " + this.uriVariables + ", but found: " + + uriVariableValues.keySet()); + } + + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, context, request, uriVariableValues); + + // Invoke the method + this.method.setAccessible(true); + Object result = this.method.invoke(this.bean, args); + + // Convert the result to a ReadResourceResult using the converter + return this.resultConverter.convertToReadResourceResult(result, request.uri(), this.mimeType, + this.contentType); + } + catch (Exception e) { + throw new McpResourceMethodException("Access error invoking resource method: " + this.method.getName(), e); + } + } + + /** + * Builder for creating SyncStatelessMcpResourceMethodCallback instances. + *

+ * This builder provides a fluent API for constructing + * SyncStatelessMcpResourceMethodCallback instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Constructor for Builder. + */ + private Builder() { + this.resultConverter = new DefaultMcpReadResourceResultConverter(); + } + + @Override + public SyncStatelessMcpResourceMethodCallback build() { + validate(); + return new SyncStatelessMcpResourceMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + @Override + protected void validateReturnType(Method method) { + Class returnType = method.getReturnType(); + + boolean validReturnType = ReadResourceResult.class.isAssignableFrom(returnType) + || List.class.isAssignableFrom(returnType) || ResourceContents.class.isAssignableFrom(returnType) + || String.class.isAssignableFrom(returnType); + + if (!validReturnType) { + throw new IllegalArgumentException( + "Method must return either ReadResourceResult, List, List, " + + "ResourceContents, or String: " + method.getName() + " in " + + method.getDeclaringClass().getName() + " returns " + returnType.getName()); + } + } + + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java new file mode 100644 index 0000000..6e5b100 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java @@ -0,0 +1,265 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.util.Map; +import java.util.stream.Stream; + +import org.reactivestreams.Publisher; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Abstract base class for creating Function callbacks around async tool methods. + * + * This class provides common functionality for converting methods annotated with + * {@link McpTool} into callback functions that can be used to handle tool requests + * asynchronously. + * + * @param The type of the context parameter (e.g., McpAsyncServerExchange or + * McpTransportContext) + * @author Christian Tzolov + */ +public abstract class AbstractAsyncMcpToolMethodCallback { + + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { + // No implementation needed + }; + + protected final Method toolMethod; + + protected final Object toolObject; + + protected final ReturnMode returnMode; + + protected AbstractAsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { + this.toolMethod = toolMethod; + this.toolObject = toolObject; + this.returnMode = returnMode; + } + + /** + * Invokes the tool method with the provided arguments. + * @param methodArguments The arguments to pass to the method + * @return The result of the method invocation + * @throws IllegalStateException if the method cannot be accessed + * @throws RuntimeException if there's an error invoking the method + */ + protected Object callMethod(Object[] methodArguments) { + this.toolMethod.setAccessible(true); + + Object result; + try { + result = this.toolMethod.invoke(this.toolObject, methodArguments); + } + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); + } + catch (InvocationTargetException ex) { + throw new RuntimeException("Error invoking method: " + this.toolMethod.getName(), ex); + } + return result; + } + + /** + * Builds the method arguments from the context and tool input arguments. + * @param exchangeOrContext The exchange or context object (e.g., + * McpAsyncServerExchange or McpTransportContext) + * @param toolInputArguments The input arguments from the tool request + * @return An array of method arguments + */ + protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments) { + return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + Object rawArgument = toolInputArguments.get(parameter.getName()); + + if (isExchangeOrContextType(parameter.getType())) { + return exchangeOrContext; + } + return buildTypedArgument(rawArgument, parameter.getParameterizedType()); + }).toArray(); + } + + /** + * Builds a typed argument from a raw value and type information. + * @param value The raw value + * @param type The target type + * @return The typed argument + */ + protected Object buildTypedArgument(Object value, Type type) { + if (value == null) { + return null; + } + + if (type instanceof Class) { + return JsonParser.toTypedObject(value, (Class) type); + } + + // For generic types, use the fromJson method that accepts Type + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, type); + } + + /** + * Convert reactive types to Mono + * @param result The result from the method invocation + * @return A Mono representing the processed result + */ + protected Mono convertToCallToolResult(Object result) { + // Handle Mono types + if (result instanceof Mono) { + + Mono monoResult = (Mono) result; + + // Check if the Mono contains CallToolResult + if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { + return (Mono) monoResult; + } + + // Handle Mono for VOID return type + if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { + return monoResult + .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); + } + + // Handle other Mono types - map the emitted value to CallToolResult + return monoResult.map(this::mapValueToCallToolResult) + .onErrorResume(e -> Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build())); + } + + // Handle Flux by taking the first element + if (result instanceof Flux) { + Flux fluxResult = (Flux) result; + + // Check if the Flux contains CallToolResult + if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { + return ((Flux) fluxResult).next(); + } + + // Handle Mono for VOID return type + if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { + return fluxResult + .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); + } + + // Handle other Flux types by taking the first element and mapping + return fluxResult.next() + .map(this::mapValueToCallToolResult) + .onErrorResume(e -> Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build())); + } + + // Handle other Publisher types + if (result instanceof Publisher) { + Publisher publisherResult = (Publisher) result; + Mono monoFromPublisher = Mono.from(publisherResult); + + // Check if the Publisher contains CallToolResult + if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { + return (Mono) monoFromPublisher; + } + + // Handle Mono for VOID return type + if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { + return monoFromPublisher + .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); + } + + // Handle other Publisher types by mapping the emitted value + return monoFromPublisher.map(this::mapValueToCallToolResult) + .onErrorResume(e -> Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build())); + } + + // This should not happen in async context, but handle as fallback + throw new IllegalStateException( + "Expected reactive return type but got: " + (result != null ? result.getClass().getName() : "null")); + } + + /** + * Map individual values to CallToolResult + * @param value The value to map + * @return A CallToolResult representing the mapped value + */ + protected CallToolResult mapValueToCallToolResult(Object value) { + if (value instanceof CallToolResult) { + return (CallToolResult) value; + } + + if (returnMode == ReturnMode.VOID) { + return CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build(); + } + else if (this.returnMode == ReturnMode.STRUCTURED) { + String jsonOutput = JsonParser.toJson(value); + Map structuredOutput = JsonParser.fromJson(jsonOutput, MAP_TYPE_REFERENCE); + return CallToolResult.builder().structuredContent(structuredOutput).build(); + } + + // Default to text output + return CallToolResult.builder().addTextContent(value != null ? value.toString() : "null").build(); + } + + /** + * Creates an error result for exceptions that occur during method invocation. + * @param e The exception that occurred + * @return A Mono representing the error + */ + protected Mono createErrorResult(Exception e) { + return Mono.just(CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build()); + } + + /** + * Validates that the request is not null. + * @param request The request to validate + * @return A Mono error if the request is null, otherwise Mono.empty() + */ + protected Mono validateRequest(CallToolRequest request) { + if (request == null) { + return Mono.error(new IllegalArgumentException("Request must not be null")); + } + return Mono.empty(); + } + + /** + * Determines if the given parameter type is an exchange or context type that should + * be injected. Subclasses must implement this method to specify which types are + * considered exchange or context types. + * @param paramType The parameter type to check + * @return true if the parameter type is an exchange or context type, false otherwise + */ + protected abstract boolean isExchangeOrContextType(Class paramType); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java new file mode 100644 index 0000000..6a72b8d --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java @@ -0,0 +1,178 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.util.Map; +import java.util.stream.Stream; + +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; + +/** + * Abstract base class for creating Function callbacks around tool methods. + * + * This class provides common functionality for converting methods annotated with + * {@link McpTool} into callback functions that can be used to handle tool requests. + * + * @param The type of the context parameter (e.g., McpTransportContext or + * McpSyncServerExchange) + * @author Christian Tzolov + */ +public abstract class AbstractSyncMcpToolMethodCallback { + + private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { + // No implementation needed + }; + + protected final Method toolMethod; + + protected final Object toolObject; + + protected final ReturnMode returnMode; + + protected AbstractSyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { + this.toolMethod = toolMethod; + this.toolObject = toolObject; + this.returnMode = returnMode; + } + + /** + * Invokes the tool method with the provided arguments. + * @param methodArguments The arguments to pass to the method + * @return The result of the method invocation + * @throws IllegalStateException if the method cannot be accessed + * @throws RuntimeException if there's an error invoking the method + */ + protected Object callMethod(Object[] methodArguments) { + this.toolMethod.setAccessible(true); + + Object result; + try { + result = this.toolMethod.invoke(this.toolObject, methodArguments); + } + catch (IllegalAccessException ex) { + throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); + } + catch (InvocationTargetException ex) { + throw new RuntimeException("Error invoking method: " + this.toolMethod.getName(), ex); + } + return result; + } + + /** + * Builds the method arguments from the context and tool input arguments. + * @param exchangeOrContext The exchange or context object (e.g., + * McpSyncServerExchange or McpTransportContext) + * @param toolInputArguments The input arguments from the tool request + * @return An array of method arguments + */ + protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments) { + return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + Object rawArgument = toolInputArguments.get(parameter.getName()); + + if (isExchangeOrContextType(parameter.getType())) { + return exchangeOrContext; + } + return buildTypedArgument(rawArgument, parameter.getParameterizedType()); + }).toArray(); + } + + /** + * Builds a typed argument from a raw value and type information. + * @param value The raw value + * @param type The target type + * @return The typed argument + */ + protected Object buildTypedArgument(Object value, Type type) { + if (value == null) { + return null; + } + + if (type instanceof Class) { + return JsonParser.toTypedObject(value, (Class) type); + } + + // For generic types, use the fromJson method that accepts Type + String json = JsonParser.toJson(value); + return JsonParser.fromJson(json, type); + } + + /** + * Processes the result of the method invocation and converts it to a CallToolResult. + * @param result The result from the method invocation + * @return A CallToolResult representing the processed result + */ + protected CallToolResult processResult(Object result) { + // Return the result if it's already a CallToolResult + if (result instanceof CallToolResult) { + return (CallToolResult) result; + } + + if (returnMode == ReturnMode.VOID) { + return CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build(); + } + else if (this.returnMode == ReturnMode.STRUCTURED) { + String jsonOutput = JsonParser.toJson(result); + Map structuredOutput = JsonParser.fromJson(jsonOutput, MAP_TYPE_REFERENCE); + return CallToolResult.builder().structuredContent(structuredOutput).build(); + } + + // Default to text output + return CallToolResult.builder().addTextContent(result != null ? result.toString() : "null").build(); + } + + /** + * Creates an error result for exceptions that occur during method invocation. + * @param e The exception that occurred + * @return A CallToolResult representing the error + */ + protected CallToolResult createErrorResult(Exception e) { + return CallToolResult.builder() + .isError(true) + .addTextContent("Error invoking method: %s".formatted(e.getMessage())) + .build(); + } + + /** + * Validates that the request is not null. + * @param request The request to validate + * @throws IllegalArgumentException if the request is null + */ + protected void validateRequest(CallToolRequest request) { + if (request == null) { + throw new IllegalArgumentException("Request must not be null"); + } + } + + /** + * Determines if the given parameter type is an exchange or context type that should + * be injected. Subclasses must implement this method to specify which types are + * considered exchange or context types. + * @param paramType The parameter type to check + * @return true if the parameter type is an exchange or context type, false otherwise + */ + protected abstract boolean isExchangeOrContextType(Class paramType); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java index 05a7ad2..15e4308 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java @@ -16,23 +16,14 @@ package org.springaicommunity.mcp.method.tool; -import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.lang.reflect.Type; -import java.util.Map; import java.util.function.BiFunction; -import java.util.stream.Stream; -import org.reactivestreams.Publisher; import org.springaicommunity.mcp.annotation.McpTool; -import org.springaicommunity.mcp.method.tool.utils.JsonParser; - -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -43,23 +34,26 @@ * * @author Christian Tzolov */ -public final class AsyncMcpToolMethodCallback +public final class AsyncMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback implements BiFunction> { - private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { - // No implementation needed - }; - - private final Method toolMethod; - - private final Object toolObject; + public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { + super(returnMode, toolMethod, toolObject); + } - private ReturnMode returnMode; + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpAsyncServerExchange.class.isAssignableFrom(paramType); + } - public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { - this.toolMethod = toolMethod; - this.toolObject = toolObject; - this.returnMode = returnMode; + /** + * Public method for backward compatibility with tests. Delegates to the protected + * isExchangeOrContextType method. + * @param paramType The parameter type to check + * @return true if the parameter type is an exchange type, false otherwise + */ + public boolean isExchangeType(Class paramType) { + return isExchangeOrContextType(paramType); } /** @@ -67,17 +61,14 @@ public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Obje *

* This method builds the arguments for the method call, invokes the method, and * returns the result. + * @param exchange The server exchange context * @param request The tool call request, must not be null * @return The result of the method invocation */ @Override public Mono apply(McpAsyncServerExchange exchange, CallToolRequest request) { - if (request == null) { - return Mono.error(new IllegalArgumentException("Request must not be null")); - } - - return Mono.defer(() -> { + return validateRequest(request).then(Mono.defer(() -> { try { // Build arguments for the method call Object[] args = this.buildMethodArguments(exchange, request.arguments()); @@ -90,160 +81,9 @@ public Mono apply(McpAsyncServerExchange exchange, CallToolReque } catch (Exception e) { - return Mono.just(CallToolResult.builder() - .isError(true) - .addTextContent("Error invoking method: %s".formatted(e.getMessage())) - .build()); - } - }); - } - - /** - * Convert reactive types to Mono - */ - private Mono convertToCallToolResult(Object result) { - // Handle Mono types - if (result instanceof Mono) { - - Mono monoResult = (Mono) result; - - // Check if the Mono contains CallToolResult - if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { - return (Mono) monoResult; - } - - // Handle Mono for VOID return type - if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { - return monoResult - .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); + return this.createErrorResult(e); } - - // Handle other Mono types - map the emitted value to CallToolResult - return monoResult.map(this::mapValueToCallToolResult) - .onErrorResume(e -> Mono.just(CallToolResult.builder() - .isError(true) - .addTextContent("Error invoking method: %s".formatted(e.getMessage())) - .build())); - } - - // Handle Flux by taking the first element - if (result instanceof Flux) { - Flux fluxResult = (Flux) result; - - // Check if the Flux contains CallToolResult - if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { - return ((Flux) fluxResult).next(); - } - - // Handle Mono for VOID return type - if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { - return fluxResult - .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); - } - - // Handle other Flux types by taking the first element and mapping - return fluxResult.next() - .map(this::mapValueToCallToolResult) - .onErrorResume(e -> Mono.just(CallToolResult.builder() - .isError(true) - .addTextContent("Error invoking method: %s".formatted(e.getMessage())) - .build())); - } - - // Handle other Publisher types - if (result instanceof Publisher) { - Publisher publisherResult = (Publisher) result; - Mono monoFromPublisher = Mono.from(publisherResult); - - // Check if the Publisher contains CallToolResult - if (ReactiveUtils.isReactiveReturnTypeOfCallToolResult(this.toolMethod)) { - return (Mono) monoFromPublisher; - } - - // Handle Mono for VOID return type - if (ReactiveUtils.isReactiveReturnTypeOfVoid(this.toolMethod)) { - return monoFromPublisher - .then(Mono.just(CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build())); - } - - // Handle other Publisher types by mapping the emitted value - return monoFromPublisher.map(this::mapValueToCallToolResult) - .onErrorResume(e -> Mono.just(CallToolResult.builder() - .isError(true) - .addTextContent("Error invoking method: %s".formatted(e.getMessage())) - .build())); - } - - // This should not happen in async context, but handle as fallback - throw new IllegalStateException( - "Expected reactive return type but got: " + (result != null ? result.getClass().getName() : "null")); - } - - /** - * Map individual values to CallToolResult - */ - private CallToolResult mapValueToCallToolResult(Object value) { - if (value instanceof CallToolResult) { - return (CallToolResult) value; - } - - if (returnMode == ReturnMode.VOID) { - return CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build(); - } - else if (this.returnMode == ReturnMode.STRUCTURED) { - String jsonOutput = JsonParser.toJson(value); - Map structuredOutput = JsonParser.fromJson(jsonOutput, MAP_TYPE_REFERENCE); - return CallToolResult.builder().structuredContent(structuredOutput).build(); - } - - // Default to text output - return CallToolResult.builder().addTextContent(value != null ? value.toString() : "null").build(); - } - - private Object callMethod(Object[] methodArguments) { - - this.toolMethod.setAccessible(true); - - Object result; - try { - result = this.toolMethod.invoke(this.toolObject, methodArguments); - } - catch (IllegalAccessException ex) { - throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); - } - catch (InvocationTargetException ex) { - throw new RuntimeException("Error invoking method: " + this.toolMethod.getName(), ex); - // throw new ToolExecutionException(this.toolDefinition, ex.getCause()); - } - return result; - } - - private Object[] buildMethodArguments(McpAsyncServerExchange exchange, Map toolInputArguments) { - return Stream.of(this.toolMethod.getParameters()).map(parameter -> { - Object rawArgument = toolInputArguments.get(parameter.getName()); - if (isExchangeType(parameter.getType())) { - return exchange; - } - return buildTypedArgument(rawArgument, parameter.getParameterizedType()); - }).toArray(); - } - - private Object buildTypedArgument(Object value, Type type) { - if (value == null) { - return null; - } - - if (type instanceof Class) { - return JsonParser.toTypedObject(value, (Class) type); - } - - // For generic types, use the fromJson method that accepts Type - String json = JsonParser.toJson(value); - return JsonParser.fromJson(json, type); - } - - protected boolean isExchangeType(Class paramType) { - return McpAsyncServerExchange.class.isAssignableFrom(paramType); + })); } } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java new file mode 100644 index 0000000..513a23a --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java @@ -0,0 +1,80 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpTool; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import reactor.core.publisher.Mono; + +/** + * Class for creating Function callbacks around async stateless tool methods. + * + * This class provides a way to convert methods annotated with {@link McpTool} into + * callback functions that can be used to handle tool requests asynchronously in a + * stateless manner using McpTransportContext. + * + * @author Christian Tzolov + */ +public final class AsyncStatelessMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback + implements BiFunction> { + + public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, + Object toolObject) { + super(returnMode, toolMethod, toolObject); + } + + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + + /** + * Apply the callback to the given request. + *

+ * This method builds the arguments for the method call, invokes the method, and + * returns the result asynchronously. + * @param mcpTransportContext The transport context + * @param request The tool call request, must not be null + * @return A Mono containing the result of the method invocation + */ + @Override + public Mono apply(McpTransportContext mcpTransportContext, CallToolRequest request) { + + return validateRequest(request).then(Mono.defer(() -> { + try { + // Build arguments for the method call + Object[] args = this.buildMethodArguments(mcpTransportContext, request.arguments()); + + // Invoke the method + Object result = this.callMethod(args); + + // Handle reactive types - method return types should always be reactive + return this.convertToCallToolResult(result); + + } + catch (Exception e) { + return this.createErrorResult(e); + } + })); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java index 63b7637..9195048 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java @@ -16,17 +16,9 @@ package org.springaicommunity.mcp.method.tool; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.lang.reflect.Type; -import java.util.Map; import java.util.function.BiFunction; -import java.util.stream.Stream; import org.springaicommunity.mcp.annotation.McpTool; -import org.springaicommunity.mcp.method.tool.utils.JsonParser; - -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -40,23 +32,26 @@ * * @author Christian Tzolov */ -public final class SyncMcpToolMethodCallback +public final class SyncMcpToolMethodCallback extends AbstractSyncMcpToolMethodCallback implements BiFunction { - private static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { - // No implementation needed - }; - - private final Method toolMethod; - - private final Object toolObject; + public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject) { + super(returnMode, toolMethod, toolObject); + } - private ReturnMode returnMode; + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpSyncServerExchange.class.isAssignableFrom(paramType); + } - public SyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { - this.toolMethod = toolMethod; - this.toolObject = toolObject; - this.returnMode = returnMode; + /** + * Public method for backward compatibility with tests. Delegates to the protected + * isExchangeOrContextType method. + * @param paramType The parameter type to check + * @return true if the parameter type is an exchange type, false otherwise + */ + public boolean isExchangeType(Class paramType) { + return isExchangeOrContextType(paramType); } /** @@ -64,15 +59,13 @@ public SyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Objec *

* This method builds the arguments for the method call, invokes the method, and * returns the result. + * @param exchange The server exchange context * @param request The tool call request, must not be null * @return The result of the method invocation */ @Override public CallToolResult apply(McpSyncServerExchange exchange, CallToolRequest request) { - - if (request == null) { - throw new IllegalArgumentException("Request must not be null"); - } + validateRequest(request); try { // Build arguments for the method call @@ -81,78 +74,12 @@ public CallToolResult apply(McpSyncServerExchange exchange, CallToolRequest requ // Invoke the method Object result = this.callMethod(args); - // Return the result - if (result instanceof CallToolResult) { - return (CallToolResult) result; - } - - if (returnMode == ReturnMode.VOID) { - return CallToolResult.builder().addTextContent(JsonParser.toJson("Done")).build(); - } - else if (this.returnMode == ReturnMode.STRUCTURED) { - - String jsonOutput = JsonParser.toJson(result); - Map structuredOutput = JsonParser.fromJson(jsonOutput, MAP_TYPE_REFERENCE); - - return CallToolResult.builder().structuredContent(structuredOutput).build(); - } - - // Default to text output - return CallToolResult.builder().addTextContent(result != null ? result.toString() : "null").build(); - + // Return the processed result + return this.processResult(result); } catch (Exception e) { - return CallToolResult.builder() - .isError(true) - .addTextContent("Error invoking method: %s".formatted(e.getMessage())) - .build(); + return this.createErrorResult(e); } } - private Object callMethod(Object[] methodArguments) { - - this.toolMethod.setAccessible(true); - - Object result; - try { - result = this.toolMethod.invoke(this.toolObject, methodArguments); - } - catch (IllegalAccessException ex) { - throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex); - } - catch (InvocationTargetException ex) { - throw new RuntimeException("Error invoking method: " + this.toolMethod.getName(), ex); - // throw new ToolExecutionException(this.toolDefinition, ex.getCause()); - } - return result; - } - - private Object[] buildMethodArguments(McpSyncServerExchange exchange, Map toolInputArguments) { - return Stream.of(this.toolMethod.getParameters()).map(parameter -> { - Object rawArgument = toolInputArguments.get(parameter.getName()); - if (isExchangeType(parameter.getType())) { - return exchange; - } - return buildTypedArgument(rawArgument, parameter.getParameterizedType()); - }).toArray(); - } - - private Object buildTypedArgument(Object value, Type type) { - if (value == null) { - return null; - } - - if (type instanceof Class) { - return JsonParser.toTypedObject(value, (Class) type); - } - - // For generic types, use the fromJson method that accepts Type - String json = JsonParser.toJson(value); - return JsonParser.fromJson(json, type); - } - - protected boolean isExchangeType(Class paramType) { - return McpSyncServerExchange.class.isAssignableFrom(paramType); - } - } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java new file mode 100644 index 0000000..771e6a7 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.util.function.BiFunction; + +import org.springaicommunity.mcp.annotation.McpTool; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; + +/** + * Class for creating Function callbacks around tool methods. + * + * This class provides a way to convert methods annotated with {@link McpTool} into + * callback functions that can be used to handle tool requests. + * + * @author James Ward + * @author Christian Tzolov + */ +public final class SyncStatelessMcpToolMethodCallback extends AbstractSyncMcpToolMethodCallback + implements BiFunction { + + public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, + Object toolObject) { + super(returnMode, toolMethod, toolObject); + } + + @Override + protected boolean isExchangeOrContextType(Class paramType) { + return McpTransportContext.class.isAssignableFrom(paramType); + } + + @Override + public CallToolResult apply(McpTransportContext mcpTransportContext, CallToolRequest callToolRequest) { + validateRequest(callToolRequest); + + try { + // Build arguments for the method call + Object[] args = this.buildMethodArguments(mcpTransportContext, callToolRequest.arguments()); + + // Invoke the method + Object result = this.callMethod(args); + + // Return the processed result + return this.processResult(result); + } + catch (Exception e) { + return this.createErrorResult(e); + } + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProvider.java new file mode 100644 index 0000000..437c169 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProvider.java @@ -0,0 +1,109 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.CompleteAdapter; +import org.springaicommunity.mcp.annotation.McpComplete; +import org.springaicommunity.mcp.method.complete.AsyncStatelessMcpCompleteMethodCallback; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncCompletionSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Provider for asynchronous stateless MCP complete methods. + * + * This provider creates completion specifications for methods annotated with + * {@link McpComplete} that are designed to work in a stateless manner using + * {@link McpTransportContext} and return reactive types. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpCompleteProvider { + + private static final Logger logger = LoggerFactory.getLogger(AsyncStatelessMcpCompleteProvider.class); + + private final List completeObjects; + + /** + * Create a new AsyncStatelessMcpCompleteProvider. + * @param completeObjects the objects containing methods annotated with + * {@link McpComplete} + */ + public AsyncStatelessMcpCompleteProvider(List completeObjects) { + Assert.notNull(completeObjects, "completeObjects cannot be null"); + this.completeObjects = completeObjects; + } + + /** + * Get the async stateless completion specifications. + * @return the list of async stateless completion specifications + */ + public List getCompleteSpecifications() { + + List completeSpecs = this.completeObjects.stream() + .map(completeObject -> Stream.of(doGetClassMethods(completeObject)) + .filter(method -> method.isAnnotationPresent(McpComplete.class)) + .filter(method -> Mono.class.isAssignableFrom(method.getReturnType()) + || Flux.class.isAssignableFrom(method.getReturnType()) + || Publisher.class.isAssignableFrom(method.getReturnType())) + .map(mcpCompleteMethod -> { + var completeAnnotation = mcpCompleteMethod.getAnnotation(McpComplete.class); + var completeRef = CompleteAdapter.asCompleteReference(completeAnnotation, mcpCompleteMethod); + + BiFunction> methodCallback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(mcpCompleteMethod) + .bean(completeObject) + .complete(completeAnnotation) + .build(); + + return new AsyncCompletionSpecification(completeRef, methodCallback); + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (completeSpecs.isEmpty()) { + logger.warn("No complete methods found in the provided complete objects: {}", this.completeObjects); + } + + return completeSpecs; + } + + /** + * Returns the methods of the given bean class. + * @param bean the bean instance + * @return the methods of the bean class + */ + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProvider.java new file mode 100644 index 0000000..5f17ff9 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProvider.java @@ -0,0 +1,109 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpPrompt; +import org.springaicommunity.mcp.annotation.PromptAdaptor; +import org.springaicommunity.mcp.method.prompt.AsyncStatelessMcpPromptMethodCallback; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncPromptSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Provider for asynchronous stateless MCP prompt methods. + * + * This provider creates prompt specifications for methods annotated with + * {@link McpPrompt} that are designed to work in a stateless manner using + * {@link McpTransportContext} and return reactive types. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpPromptProvider { + + private static final Logger logger = LoggerFactory.getLogger(AsyncStatelessMcpPromptProvider.class); + + private final List promptObjects; + + /** + * Create a new AsyncStatelessMcpPromptProvider. + * @param promptObjects the objects containing methods annotated with + * {@link McpPrompt} + */ + public AsyncStatelessMcpPromptProvider(List promptObjects) { + Assert.notNull(promptObjects, "promptObjects cannot be null"); + this.promptObjects = promptObjects; + } + + /** + * Get the async stateless prompt specifications. + * @return the list of async stateless prompt specifications + */ + public List getPromptSpecifications() { + + List promptSpecs = this.promptObjects.stream() + .map(promptObject -> Stream.of(doGetClassMethods(promptObject)) + .filter(method -> method.isAnnotationPresent(McpPrompt.class)) + .filter(method -> Mono.class.isAssignableFrom(method.getReturnType()) + || Flux.class.isAssignableFrom(method.getReturnType()) + || Publisher.class.isAssignableFrom(method.getReturnType())) + .map(mcpPromptMethod -> { + var promptAnnotation = mcpPromptMethod.getAnnotation(McpPrompt.class); + var mcpPrompt = PromptAdaptor.asPrompt(promptAnnotation, mcpPromptMethod); + + BiFunction> methodCallback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(mcpPromptMethod) + .bean(promptObject) + .prompt(mcpPrompt) + .build(); + + return new AsyncPromptSpecification(mcpPrompt, methodCallback); + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (promptSpecs.isEmpty()) { + logger.warn("No prompt methods found in the provided prompt objects: {}", this.promptObjects); + } + + return promptSpecs; + } + + /** + * Returns the methods of the given bean class. + * @param bean the bean instance + * @return the methods of the bean class + */ + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProvider.java new file mode 100644 index 0000000..826e8bb --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProvider.java @@ -0,0 +1,130 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpResource; +import org.springaicommunity.mcp.method.resource.AsyncStatelessMcpResourceMethodCallback; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncResourceSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Provider for asynchronous stateless MCP resource methods. + * + * This provider creates resource specifications for methods annotated with + * {@link McpResource} that are designed to work in a stateless manner using + * {@link McpTransportContext} and return reactive types. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpResourceProvider { + + private static final Logger logger = LoggerFactory.getLogger(AsyncStatelessMcpResourceProvider.class); + + private final List resourceObjects; + + /** + * Create a new AsyncStatelessMcpResourceProvider. + * @param resourceObjects the objects containing methods annotated with + * {@link McpResource} + */ + public AsyncStatelessMcpResourceProvider(List resourceObjects) { + Assert.notNull(resourceObjects, "resourceObjects cannot be null"); + this.resourceObjects = resourceObjects; + } + + /** + * Get the async stateless resource specifications. + * @return the list of async stateless resource specifications + */ + public List getResourceSpecifications() { + + List resourceSpecs = this.resourceObjects.stream() + .map(resourceObject -> Stream.of(doGetClassMethods(resourceObject)) + .filter(method -> method.isAnnotationPresent(McpResource.class)) + .filter(method -> Mono.class.isAssignableFrom(method.getReturnType()) + || Flux.class.isAssignableFrom(method.getReturnType()) + || Publisher.class.isAssignableFrom(method.getReturnType())) + .map(mcpResourceMethod -> { + + var resourceAnnotation = doGetMcpResourceAnnotation(mcpResourceMethod); + + var uri = resourceAnnotation.uri(); + var name = getName(mcpResourceMethod, resourceAnnotation); + var description = resourceAnnotation.description(); + var mimeType = resourceAnnotation.mimeType(); + + var mcpResource = McpSchema.Resource.builder() + .uri(uri) + .name(name) + .description(description) + .mimeType(mimeType) + .build(); + + BiFunction> methodCallback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(mcpResourceMethod) + .bean(resourceObject) + .resource(mcpResource) + .build(); + + var resourceSpec = new AsyncResourceSpecification(mcpResource, methodCallback); + + return resourceSpec; + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (resourceSpecs.isEmpty()) { + logger.warn("No resource methods found in the provided resource objects: {}", this.resourceObjects); + } + + return resourceSpecs; + } + + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + + protected McpResource doGetMcpResourceAnnotation(Method method) { + return method.getAnnotation(McpResource.class); + } + + private static String getName(Method method, McpResource resource) { + Assert.notNull(method, "method cannot be null"); + if (resource == null || resource.name() == null || resource.name().isEmpty()) { + return method.getName(); + } + return resource.name(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpToolProvider.java new file mode 100644 index 0000000..23edf1c --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpToolProvider.java @@ -0,0 +1,158 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.AsyncStatelessMcpToolMethodCallback; +import org.springaicommunity.mcp.method.tool.ReactiveUtils; +import org.springaicommunity.mcp.method.tool.ReturnMode; +import org.springaicommunity.mcp.method.tool.utils.ClassUtils; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncToolSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Provider for asynchronous stateless MCP tool methods. + * + * This provider creates tool specifications for methods annotated with {@link McpTool} + * that are designed to work in a stateless manner using {@link McpTransportContext} and + * return reactive types. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpToolProvider { + + private static final Logger logger = LoggerFactory.getLogger(AsyncStatelessMcpToolProvider.class); + + private final List toolObjects; + + /** + * Create a new AsyncStatelessMcpToolProvider. + * @param toolObjects the objects containing methods annotated with {@link McpTool} + */ + public AsyncStatelessMcpToolProvider(List toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + this.toolObjects = toolObjects; + } + + /** + * Get the async stateless tool specifications. + * @return the list of async stateless tool specifications + */ + public List getToolSpecifications() { + + List toolSpecs = this.toolObjects.stream() + .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) + .filter(method -> method.isAnnotationPresent(McpTool.class)) + .filter(method -> Mono.class.isAssignableFrom(method.getReturnType()) + || Flux.class.isAssignableFrom(method.getReturnType()) + || Publisher.class.isAssignableFrom(method.getReturnType())) + .map(mcpToolMethod -> { + + var toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + + String toolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name() + : mcpToolMethod.getName(); + + String toolDescrption = toolAnnotation.description(); + + String inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); + + var toolBuilder = McpSchema.Tool.builder() + .name(toolName) + .description(toolDescrption) + .inputSchema(inputSchema); + + // Tool annotations + if (toolAnnotation.annotations() != null) { + var toolAnnotations = toolAnnotation.annotations(); + toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(), + toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(), + toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null)); + } + + // Generate Output Schema from the method return type. + // Output schema is not generated for primitive types, void, + // CallToolResult, simple value types (String, etc.) + // or if generateOutputSchema attribute is set to false. + + if (toolAnnotation.generateOutputSchema() + && !ReactiveUtils.isReactiveReturnTypeOfVoid(mcpToolMethod) + && !ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod)) { + + ReactiveUtils.getReactiveReturnTypeArgument(mcpToolMethod).ifPresent(typeArgument -> { + Class methodReturnType = typeArgument instanceof Class ? (Class) typeArgument + : null; + if (!ClassUtils.isPrimitiveOrWrapper(methodReturnType) + && !ClassUtils.isSimpleValueType(methodReturnType)) { + toolBuilder + .outputSchema(JsonSchemaGenerator.generateFromClass((Class) typeArgument)); + } + }); + } + var tool = toolBuilder.build(); + + ReturnMode returnMode = tool.outputSchema() != null ? ReturnMode.STRUCTURED + : ReactiveUtils.isReactiveReturnTypeOfVoid(mcpToolMethod) ? ReturnMode.VOID + : ReturnMode.TEXT; + + BiFunction> methodCallback = new AsyncStatelessMcpToolMethodCallback( + returnMode, mcpToolMethod, toolObject); + + AsyncToolSpecification toolSpec = AsyncToolSpecification.builder() + .tool(tool) + .callHandler(methodCallback) + .build(); + + return toolSpec; + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (toolSpecs.isEmpty()) { + logger.warn("No tool methods found in the provided tool objects: {}", this.toolObjects); + } + + return toolSpecs; + } + + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + + protected McpTool doGetMcpToolAnnotation(Method method) { + return method.getAnnotation(McpTool.class); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProvider.java new file mode 100644 index 0000000..83ea95b --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProvider.java @@ -0,0 +1,105 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.CompleteAdapter; +import org.springaicommunity.mcp.annotation.McpComplete; +import org.springaicommunity.mcp.method.complete.SyncStatelessMcpCompleteMethodCallback; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncCompletionSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * Provider for synchronous stateless MCP complete methods. + * + * This provider creates completion specifications for methods annotated with + * {@link McpComplete} that are designed to work in a stateless manner using + * {@link McpTransportContext}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpCompleteProvider { + + private static final Logger logger = LoggerFactory.getLogger(SyncStatelessMcpCompleteProvider.class); + + private final List completeObjects; + + /** + * Create a new SyncStatelessMcpCompleteProvider. + * @param completeObjects the objects containing methods annotated with + * {@link McpComplete} + */ + public SyncStatelessMcpCompleteProvider(List completeObjects) { + Assert.notNull(completeObjects, "completeObjects cannot be null"); + this.completeObjects = completeObjects; + } + + /** + * Get the stateless completion specifications. + * @return the list of stateless completion specifications + */ + public List getCompleteSpecifications() { + + List completeSpecs = this.completeObjects.stream() + .map(completeObject -> Stream.of(doGetClassMethods(completeObject)) + .filter(method -> method.isAnnotationPresent(McpComplete.class)) + .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .map(mcpCompleteMethod -> { + var completeAnnotation = mcpCompleteMethod.getAnnotation(McpComplete.class); + var completeRef = CompleteAdapter.asCompleteReference(completeAnnotation, mcpCompleteMethod); + + BiFunction methodCallback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(mcpCompleteMethod) + .bean(completeObject) + .complete(completeAnnotation) + .build(); + + return new SyncCompletionSpecification(completeRef, methodCallback); + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (completeSpecs.isEmpty()) { + logger.warn("No complete methods found in the provided complete objects: {}", this.completeObjects); + } + + return completeSpecs; + } + + /** + * Returns the methods of the given bean class. + * @param bean the bean instance + * @return the methods of the bean class + */ + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProvider.java new file mode 100644 index 0000000..03f83ad --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProvider.java @@ -0,0 +1,105 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpPrompt; +import org.springaicommunity.mcp.annotation.PromptAdaptor; +import org.springaicommunity.mcp.method.prompt.SyncStatelessMcpPromptMethodCallback; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * Provider for synchronous stateless MCP prompt methods. + * + * This provider creates prompt specifications for methods annotated with + * {@link McpPrompt} that are designed to work in a stateless manner using + * {@link McpTransportContext}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpPromptProvider { + + private static final Logger logger = LoggerFactory.getLogger(SyncStatelessMcpPromptProvider.class); + + private final List promptObjects; + + /** + * Create a new SyncStatelessMcpPromptProvider. + * @param promptObjects the objects containing methods annotated with + * {@link McpPrompt} + */ + public SyncStatelessMcpPromptProvider(List promptObjects) { + Assert.notNull(promptObjects, "promptObjects cannot be null"); + this.promptObjects = promptObjects; + } + + /** + * Get the stateless prompt specifications. + * @return the list of stateless prompt specifications + */ + public List getPromptSpecifications() { + + List promptSpecs = this.promptObjects.stream() + .map(promptObject -> Stream.of(doGetClassMethods(promptObject)) + .filter(method -> method.isAnnotationPresent(McpPrompt.class)) + .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .map(mcpPromptMethod -> { + var promptAnnotation = mcpPromptMethod.getAnnotation(McpPrompt.class); + var mcpPrompt = PromptAdaptor.asPrompt(promptAnnotation, mcpPromptMethod); + + BiFunction methodCallback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(mcpPromptMethod) + .bean(promptObject) + .prompt(mcpPrompt) + .build(); + + return new SyncPromptSpecification(mcpPrompt, methodCallback); + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (promptSpecs.isEmpty()) { + logger.warn("No prompt methods found in the provided prompt objects: {}", this.promptObjects); + } + + return promptSpecs; + } + + /** + * Returns the methods of the given bean class. + * @param bean the bean instance + * @return the methods of the bean class + */ + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProvider.java new file mode 100644 index 0000000..2a5630e --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProvider.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpResource; +import org.springaicommunity.mcp.method.resource.SyncStatelessMcpResourceMethodCallback; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncResourceSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * Provider for synchronous stateless MCP resource methods. + * + * This provider creates resource specifications for methods annotated with + * {@link McpResource} that are designed to work in a stateless manner using + * {@link McpTransportContext}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpResourceProvider { + + private static final Logger logger = LoggerFactory.getLogger(SyncStatelessMcpResourceProvider.class); + + private final List resourceObjects; + + /** + * Create a new SyncStatelessMcpResourceProvider. + * @param resourceObjects the objects containing methods annotated with + * {@link McpResource} + */ + public SyncStatelessMcpResourceProvider(List resourceObjects) { + Assert.notNull(resourceObjects, "resourceObjects cannot be null"); + this.resourceObjects = resourceObjects; + } + + /** + * Get the stateless resource specifications. + * @return the list of stateless resource specifications + */ + public List getResourceSpecifications() { + + List resourceSpecs = this.resourceObjects.stream() + .map(resourceObject -> Stream.of(doGetClassMethods(resourceObject)) + .filter(method -> method.isAnnotationPresent(McpResource.class)) + .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .map(mcpResourceMethod -> { + + var resourceAnnotation = doGetMcpResourceAnnotation(mcpResourceMethod); + + var uri = resourceAnnotation.uri(); + var name = getName(mcpResourceMethod, resourceAnnotation); + var description = resourceAnnotation.description(); + var mimeType = resourceAnnotation.mimeType(); + + var mcpResource = McpSchema.Resource.builder() + .uri(uri) + .name(name) + .description(description) + .mimeType(mimeType) + .build(); + + BiFunction methodCallback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(mcpResourceMethod) + .bean(resourceObject) + .resource(mcpResource) + .build(); + + var resourceSpec = new SyncResourceSpecification(mcpResource, methodCallback); + + return resourceSpec; + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (resourceSpecs.isEmpty()) { + logger.warn("No resource methods found in the provided resource objects: {}", this.resourceObjects); + } + + return resourceSpecs; + } + + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + + protected McpResource doGetMcpResourceAnnotation(Method method) { + return method.getAnnotation(McpResource.class); + } + + private static String getName(Method method, McpResource resource) { + Assert.notNull(method, "method cannot be null"); + if (resource == null || resource.name() == null || resource.name().isEmpty()) { + return method.getName(); + } + return resource.name(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpToolProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpToolProvider.java new file mode 100644 index 0000000..8b12761 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncStatelessMcpToolProvider.java @@ -0,0 +1,147 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.method.tool.ReactiveUtils; +import org.springaicommunity.mcp.method.tool.ReturnMode; +import org.springaicommunity.mcp.method.tool.SyncStatelessMcpToolMethodCallback; +import org.springaicommunity.mcp.method.tool.utils.ClassUtils; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncToolSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; + +/** + * Provider for synchronous stateless MCP tool methods. + * + * This provider creates tool specifications for methods annotated with {@link McpTool} + * that are designed to work in a stateless manner using {@link McpTransportContext}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpToolProvider { + + private static final Logger logger = LoggerFactory.getLogger(SyncStatelessMcpToolProvider.class); + + private final List toolObjects; + + /** + * Create a new SyncStatelessMcpToolProvider. + * @param toolObjects the objects containing methods annotated with {@link McpTool} + */ + public SyncStatelessMcpToolProvider(List toolObjects) { + Assert.notNull(toolObjects, "toolObjects cannot be null"); + this.toolObjects = toolObjects; + } + + /** + * Get the stateless tool specifications. + * @return the list of stateless tool specifications + */ + public List getToolSpecifications() { + + List toolSpecs = this.toolObjects.stream() + .map(toolObject -> Stream.of(doGetClassMethods(toolObject)) + .filter(method -> method.isAnnotationPresent(McpTool.class)) + .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .map(mcpToolMethod -> { + + var toolAnnotation = doGetMcpToolAnnotation(mcpToolMethod); + + String toolName = Utils.hasText(toolAnnotation.name()) ? toolAnnotation.name() + : mcpToolMethod.getName(); + + String toolDescrption = toolAnnotation.description(); + + String inputSchema = JsonSchemaGenerator.generateForMethodInput(mcpToolMethod); + + var toolBuilder = McpSchema.Tool.builder() + .name(toolName) + .description(toolDescrption) + .inputSchema(inputSchema); + + // Tool annotations + if (toolAnnotation.annotations() != null) { + var toolAnnotations = toolAnnotation.annotations(); + toolBuilder.annotations(new McpSchema.ToolAnnotations(toolAnnotations.title(), + toolAnnotations.readOnlyHint(), toolAnnotations.destructiveHint(), + toolAnnotations.idempotentHint(), toolAnnotations.openWorldHint(), null)); + } + + ReactiveUtils.isReactiveReturnTypeOfCallToolResult(mcpToolMethod); + // Generate Output Schema from the method return type. + // Output schema is not generated for primitive types, void, + // CallToolResult, simple value types (String, etc.) + // or if generateOutputSchema attribute is set to false. + Class methodReturnType = mcpToolMethod.getReturnType(); + if (toolAnnotation.generateOutputSchema() && methodReturnType != null + && methodReturnType != CallToolResult.class && methodReturnType != Void.class + && methodReturnType != void.class && !ClassUtils.isPrimitiveOrWrapper(methodReturnType) + && !ClassUtils.isSimpleValueType(methodReturnType)) { + + toolBuilder.outputSchema(JsonSchemaGenerator.generateFromClass(methodReturnType)); + } + + var tool = toolBuilder.build(); + + boolean useStructuredOtput = tool.outputSchema() != null; + + ReturnMode returnMode = useStructuredOtput ? ReturnMode.STRUCTURED + : (methodReturnType == Void.TYPE || methodReturnType == void.class ? ReturnMode.VOID + : ReturnMode.TEXT); + + BiFunction methodCallback = new SyncStatelessMcpToolMethodCallback( + returnMode, mcpToolMethod, toolObject); + + var toolSpec = SyncToolSpecification.builder().tool(tool).callHandler(methodCallback).build(); + + return toolSpec; + }) + .toList()) + .flatMap(List::stream) + .toList(); + + if (toolSpecs.isEmpty()) { + logger.warn("No tool methods found in the provided tool objects: {}", this.toolObjects); + } + + return toolSpecs; + } + + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + + protected McpTool doGetMcpToolAnnotation(Method method) { + return method.getAnnotation(McpTool.class); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallbackTests.java new file mode 100644 index 0000000..9834187 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/AsyncStatelessMcpCompleteMethodCallbackTests.java @@ -0,0 +1,635 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.complete; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpComplete; + +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link AsyncStatelessMcpCompleteMethodCallback}. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpCompleteMethodCallbackTests { + + private static class TestAsyncStatelessCompleteProvider { + + public Mono getCompletionWithRequest(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async stateless completion for " + request.argument().value()), 1, false))); + } + + public Mono getCompletionWithContext(McpTransportContext context, CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async stateless completion with context for " + request.argument().value()), 1, false))); + } + + public Mono getCompletionWithArgument(CompleteRequest.CompleteArgument argument) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async stateless completion from argument: " + argument.value()), 1, false))); + } + + public Mono getCompletionWithValue(String value) { + return Mono.just(new CompleteResult( + new CompleteCompletion(List.of("Async stateless completion from value: " + value), 1, false))); + } + + @McpComplete(prompt = "test-prompt") + public Mono getCompletionWithPrompt(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async stateless completion for prompt with: " + request.argument().value()), 1, false))); + } + + @McpComplete(uri = "test://{variable}") + public Mono getCompletionWithUri(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async stateless completion for URI with: " + request.argument().value()), 1, false))); + } + + public Mono getCompletionObject(CompleteRequest request) { + return Mono.just(new CompleteCompletion( + List.of("Async stateless completion object for: " + request.argument().value()), 1, false)); + } + + public Mono> getCompletionList(CompleteRequest request) { + return Mono.just(List.of("Async stateless list item 1 for: " + request.argument().value(), + "Async stateless list item 2 for: " + request.argument().value())); + } + + public Mono getCompletionString(CompleteRequest request) { + return Mono.just("Async stateless string completion for: " + request.argument().value()); + } + + // Non-reactive methods + public CompleteResult getDirectCompletionResult(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Direct stateless completion for " + request.argument().value()), 1, false)); + } + + public CompleteCompletion getDirectCompletionObject(CompleteRequest request) { + return new CompleteCompletion( + List.of("Direct stateless completion object for: " + request.argument().value()), 1, false); + } + + public List getDirectCompletionList(CompleteRequest request) { + return List.of("Direct stateless list item 1 for: " + request.argument().value(), + "Direct stateless list item 2 for: " + request.argument().value()); + } + + public String getDirectCompletionString(CompleteRequest request) { + return "Direct stateless string completion for: " + request.argument().value(); + } + + public void invalidReturnType(CompleteRequest request) { + // Invalid return type + } + + public Mono invalidParameters(int value) { + return Mono.just(new CompleteResult(new CompleteCompletion(List.of(), 0, false))); + } + + public Mono tooManyParameters(McpTransportContext context, CompleteRequest request, + String extraParam, String extraParam2) { + return Mono.just(new CompleteResult(new CompleteCompletion(List.of(), 0, false))); + } + + public Mono invalidParameterType(Object invalidParam) { + return Mono.just(new CompleteResult(new CompleteCompletion(List.of(), 0, false))); + } + + public Mono duplicateContextParameters(McpTransportContext context1, + McpTransportContext context2) { + return Mono.just(new CompleteResult(new CompleteCompletion(List.of(), 0, false))); + } + + public Mono duplicateRequestParameters(CompleteRequest request1, CompleteRequest request2) { + return Mono.just(new CompleteResult(new CompleteCompletion(List.of(), 0, false))); + } + + public Mono duplicateArgumentParameters(CompleteRequest.CompleteArgument arg1, + CompleteRequest.CompleteArgument arg2) { + return Mono.just(new CompleteResult(new CompleteCompletion(List.of(), 0, false))); + } + + } + + @Test + public void testCallbackWithRequestParameter() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithRequest", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Async stateless completion for value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithContextAndRequestParameters() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithContext", + McpTransportContext.class, CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)) + .isEqualTo("Async stateless completion with context for value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithArgumentParameter() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithArgument", + CompleteRequest.CompleteArgument.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)) + .isEqualTo("Async stateless completion from argument: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithValueParameter() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithValue", String.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Async stateless completion from value: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithPromptAnnotation() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithPrompt", + CompleteRequest.class); + McpComplete completeAnnotation = method.getAnnotation(McpComplete.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .complete(completeAnnotation) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)) + .isEqualTo("Async stateless completion for prompt with: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithUriAnnotation() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithUri", + CompleteRequest.class); + McpComplete completeAnnotation = method.getAnnotation(McpComplete.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .complete(completeAnnotation) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new ResourceReference("test://value"), + new CompleteRequest.CompleteArgument("variable", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Async stateless completion for URI with: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithCompletionObject() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionObject", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Async stateless completion object for: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithCompletionList() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionList", CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(2); + assertThat(result.completion().values().get(0)).isEqualTo("Async stateless list item 1 for: value"); + assertThat(result.completion().values().get(1)).isEqualTo("Async stateless list item 2 for: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithCompletionString() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionString", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Async stateless string completion for: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithDirectCompletionResult() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getDirectCompletionResult", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Direct stateless completion for value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithDirectCompletionObject() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getDirectCompletionObject", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Direct stateless completion object for: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithDirectCompletionList() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getDirectCompletionList", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(2); + assertThat(result.completion().values().get(0)).isEqualTo("Direct stateless list item 1 for: value"); + assertThat(result.completion().values().get(1)).isEqualTo("Direct stateless list item 2 for: value"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithDirectCompletionString() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getDirectCompletionString", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Direct stateless string completion for: value"); + }).verifyComplete(); + } + + @Test + public void testInvalidReturnType() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("invalidReturnType", CompleteRequest.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Method must return either CompleteResult, CompleteCompletion, List, String, or Mono"); + } + + @Test + public void testInvalidParameters() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("invalidParameters", int.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method parameters must be exchange, CompleteRequest, CompleteArgument, or String"); + } + + @Test + public void testTooManyParameters() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("tooManyParameters", + McpTransportContext.class, CompleteRequest.class, String.class, String.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method can have at most 3 input parameters"); + } + + @Test + public void testInvalidParameterType() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("invalidParameterType", Object.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method parameters must be exchange, CompleteRequest, CompleteArgument, or String"); + } + + @Test + public void testDuplicateContextParameters() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("duplicateContextParameters", + McpTransportContext.class, McpTransportContext.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one exchange parameter"); + } + + @Test + public void testDuplicateRequestParameters() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("duplicateRequestParameters", + CompleteRequest.class, CompleteRequest.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one CompleteRequest parameter"); + } + + @Test + public void testDuplicateArgumentParameters() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("duplicateArgumentParameters", + CompleteRequest.CompleteArgument.class, CompleteRequest.CompleteArgument.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one CompleteArgument parameter"); + } + + @Test + public void testMissingPromptAndUri() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithRequest", + CompleteRequest.class); + + assertThatThrownBy( + () -> AsyncStatelessMcpCompleteMethodCallback.builder().method(method).bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Either prompt or uri must be provided"); + } + + @Test + public void testBothPromptAndUri() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithRequest", + CompleteRequest.class); + + assertThatThrownBy(() -> AsyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .uri("test://resource") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Only one of prompt or uri can be provided"); + } + + @Test + public void testNullRequest() throws Exception { + TestAsyncStatelessCompleteProvider provider = new TestAsyncStatelessCompleteProvider(); + Method method = TestAsyncStatelessCompleteProvider.class.getMethod("getCompletionWithRequest", + CompleteRequest.class); + + BiFunction> callback = AsyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + StepVerifier.create(callback.apply(context, null)) + .expectErrorMatches(throwable -> throwable instanceof IllegalArgumentException + && throwable.getMessage().contains("Request must not be null")) + .verify(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallbackTests.java new file mode 100644 index 0000000..f62b103 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/complete/SyncStatelessMcpCompleteMethodCallbackTests.java @@ -0,0 +1,468 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.complete; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpComplete; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link SyncStatelessMcpCompleteMethodCallback}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpCompleteMethodCallbackTests { + + private static class TestCompleteProvider { + + public CompleteResult getCompletionWithRequest(CompleteRequest request) { + return new CompleteResult( + new CompleteCompletion(List.of("Completion for " + request.argument().value()), 1, false)); + } + + public CompleteResult getCompletionWithContext(McpTransportContext context, CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Completion with context for " + request.argument().value()), 1, false)); + } + + public CompleteResult getCompletionWithArgument(CompleteRequest.CompleteArgument argument) { + return new CompleteResult( + new CompleteCompletion(List.of("Completion from argument: " + argument.value()), 1, false)); + } + + public CompleteResult getCompletionWithValue(String value) { + return new CompleteResult(new CompleteCompletion(List.of("Completion from value: " + value), 1, false)); + } + + @McpComplete(prompt = "test-prompt") + public CompleteResult getCompletionWithPrompt(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Completion for prompt with: " + request.argument().value()), 1, false)); + } + + @McpComplete(uri = "test://{variable}") + public CompleteResult getCompletionWithUri(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Completion for URI with: " + request.argument().value()), 1, false)); + } + + public CompleteCompletion getCompletionObject(CompleteRequest request) { + return new CompleteCompletion(List.of("Completion object for: " + request.argument().value()), 1, false); + } + + public List getCompletionList(CompleteRequest request) { + return List.of("List item 1 for: " + request.argument().value(), + "List item 2 for: " + request.argument().value()); + } + + public String getCompletionString(CompleteRequest request) { + return "String completion for: " + request.argument().value(); + } + + public void invalidReturnType(CompleteRequest request) { + // Invalid return type + } + + public CompleteResult invalidParameters(int value) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + public CompleteResult tooManyParameters(McpTransportContext context, CompleteRequest request, String extraParam, + String extraParam2) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + public CompleteResult invalidParameterType(Object invalidParam) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + public CompleteResult duplicateContextParameters(McpTransportContext context1, McpTransportContext context2) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + public CompleteResult duplicateRequestParameters(CompleteRequest request1, CompleteRequest request2) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + public CompleteResult duplicateArgumentParameters(CompleteRequest.CompleteArgument arg1, + CompleteRequest.CompleteArgument arg2) { + return new CompleteResult(new CompleteCompletion(List.of(), 0, false)); + } + + } + + @Test + public void testCallbackWithRequestParameter() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithRequest", CompleteRequest.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion for value"); + } + + @Test + public void testCallbackWithContextAndRequestParameters() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithContext", McpTransportContext.class, + CompleteRequest.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion with context for value"); + } + + @Test + public void testCallbackWithArgumentParameter() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithArgument", + CompleteRequest.CompleteArgument.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion from argument: value"); + } + + @Test + public void testCallbackWithValueParameter() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithValue", String.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion from value: value"); + } + + @Test + public void testCallbackWithPromptAnnotation() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithPrompt", CompleteRequest.class); + McpComplete completeAnnotation = method.getAnnotation(McpComplete.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .complete(completeAnnotation) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion for prompt with: value"); + } + + @Test + public void testCallbackWithUriAnnotation() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithUri", CompleteRequest.class); + McpComplete completeAnnotation = method.getAnnotation(McpComplete.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .complete(completeAnnotation) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new ResourceReference("test://value"), + new CompleteRequest.CompleteArgument("variable", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion for URI with: value"); + } + + @Test + public void testCallbackWithCompletionObject() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionObject", CompleteRequest.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion object for: value"); + } + + @Test + public void testCallbackWithCompletionList() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionList", CompleteRequest.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(2); + assertThat(result.completion().values().get(0)).isEqualTo("List item 1 for: value"); + assertThat(result.completion().values().get(1)).isEqualTo("List item 2 for: value"); + } + + @Test + public void testCallbackWithCompletionString() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionString", CompleteRequest.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + + CompleteResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("String completion for: value"); + } + + @Test + public void testInvalidReturnType() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("invalidReturnType", CompleteRequest.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining( + "Method must return either CompleteResult, CompleteCompletion, List, or String"); + } + + @Test + public void testInvalidParameters() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("invalidParameters", int.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method parameters must be exchange, CompleteRequest, CompleteArgument, or String"); + } + + @Test + public void testTooManyParameters() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("tooManyParameters", McpTransportContext.class, + CompleteRequest.class, String.class, String.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method can have at most 3 input parameters"); + } + + @Test + public void testInvalidParameterType() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("invalidParameterType", Object.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method parameters must be exchange, CompleteRequest, CompleteArgument, or String"); + } + + @Test + public void testDuplicateContextParameters() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("duplicateContextParameters", McpTransportContext.class, + McpTransportContext.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one exchange parameter"); + } + + @Test + public void testDuplicateRequestParameters() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("duplicateRequestParameters", CompleteRequest.class, + CompleteRequest.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one CompleteRequest parameter"); + } + + @Test + public void testDuplicateArgumentParameters() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("duplicateArgumentParameters", + CompleteRequest.CompleteArgument.class, CompleteRequest.CompleteArgument.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one CompleteArgument parameter"); + } + + @Test + public void testMissingPromptAndUri() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithRequest", CompleteRequest.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder().method(method).bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Either prompt or uri must be provided"); + } + + @Test + public void testBothPromptAndUri() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithRequest", CompleteRequest.class); + + assertThatThrownBy(() -> SyncStatelessMcpCompleteMethodCallback.builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .uri("test://resource") + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Only one of prompt or uri can be provided"); + } + + @Test + public void testNullRequest() throws Exception { + TestCompleteProvider provider = new TestCompleteProvider(); + Method method = TestCompleteProvider.class.getMethod("getCompletionWithRequest", CompleteRequest.class); + + BiFunction callback = SyncStatelessMcpCompleteMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt("test-prompt") + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + assertThatThrownBy(() -> callback.apply(context, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Request must not be null"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallbackTests.java new file mode 100644 index 0000000..6e4b983 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/AsyncStatelessMcpPromptMethodCallbackTests.java @@ -0,0 +1,687 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.prompt; + +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpArg; +import org.springaicommunity.mcp.annotation.McpPrompt; + +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link AsyncStatelessMcpPromptMethodCallback}. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpPromptMethodCallbackTests { + + private static class TestPromptProvider { + + @McpPrompt(name = "greeting", description = "A simple greeting prompt") + public GetPromptResult getPromptWithRequest(GetPromptRequest request) { + return new GetPromptResult("Greeting prompt", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello from " + request.name())))); + } + + @McpPrompt(name = "context-greeting", description = "A greeting prompt with context") + public GetPromptResult getPromptWithContext(McpTransportContext context, GetPromptRequest request) { + return new GetPromptResult("Greeting with context", List + .of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello with context from " + request.name())))); + } + + @McpPrompt(name = "arguments-greeting", description = "A greeting prompt with arguments") + public GetPromptResult getPromptWithArguments(Map arguments) { + String name = arguments.containsKey("name") ? arguments.get("name").toString() : "unknown"; + return new GetPromptResult("Greeting with arguments", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello " + name + " from arguments")))); + } + + @McpPrompt(name = "individual-args", description = "A prompt with individual arguments") + public GetPromptResult getPromptWithIndividualArgs( + @McpArg(name = "name", description = "The user's name", required = true) String name, + @McpArg(name = "age", description = "The user's age", required = true) Integer age) { + return new GetPromptResult("Individual arguments prompt", List.of(new PromptMessage(Role.ASSISTANT, + new TextContent("Hello " + name + ", you are " + age + " years old")))); + } + + @McpPrompt(name = "mixed-args", description = "A prompt with mixed argument types") + public GetPromptResult getPromptWithMixedArgs(McpTransportContext context, + @McpArg(name = "name", description = "The user's name", required = true) String name, + @McpArg(name = "age", description = "The user's age", required = true) Integer age) { + return new GetPromptResult("Mixed arguments prompt", List.of(new PromptMessage(Role.ASSISTANT, + new TextContent("Hello " + name + ", you are " + age + " years old (with context)")))); + } + + @McpPrompt(name = "list-messages", description = "A prompt returning a list of messages") + public List getPromptMessagesList(GetPromptRequest request) { + return List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Message 1 for " + request.name())), + new PromptMessage(Role.ASSISTANT, new TextContent("Message 2 for " + request.name()))); + } + + @McpPrompt(name = "string-prompt", description = "A prompt returning a string") + public String getStringPrompt(GetPromptRequest request) { + return "Simple string response for " + request.name(); + } + + @McpPrompt(name = "single-message", description = "A prompt returning a single message") + public PromptMessage getSingleMessage(GetPromptRequest request) { + return new PromptMessage(Role.ASSISTANT, new TextContent("Single message for " + request.name())); + } + + @McpPrompt(name = "string-list", description = "A prompt returning a list of strings") + public List getStringList(GetPromptRequest request) { + return List.of("String 1 for " + request.name(), "String 2 for " + request.name(), + "String 3 for " + request.name()); + } + + @McpPrompt(name = "mono-prompt", description = "A prompt returning a Mono") + public Mono getMonoPrompt(GetPromptRequest request) { + return Mono.just(new GetPromptResult("Mono prompt", List + .of(new PromptMessage(Role.ASSISTANT, new TextContent("Async response for " + request.name()))))); + } + + @McpPrompt(name = "mono-string", description = "A prompt returning a Mono") + public Mono getMonoString(GetPromptRequest request) { + return Mono.just("Async string response for " + request.name()); + } + + @McpPrompt(name = "mono-message", description = "A prompt returning a Mono") + public Mono getMonoMessage(GetPromptRequest request) { + return Mono + .just(new PromptMessage(Role.ASSISTANT, new TextContent("Async single message for " + request.name()))); + } + + @McpPrompt(name = "mono-message-list", description = "A prompt returning a Mono>") + public Mono> getMonoMessageList(GetPromptRequest request) { + return Mono.just(List.of( + new PromptMessage(Role.ASSISTANT, new TextContent("Async message 1 for " + request.name())), + new PromptMessage(Role.ASSISTANT, new TextContent("Async message 2 for " + request.name())))); + } + + @McpPrompt(name = "mono-string-list", description = "A prompt returning a Mono>") + public Mono> getMonoStringList(GetPromptRequest request) { + return Mono.just(List.of("Async string 1 for " + request.name(), "Async string 2 for " + request.name(), + "Async string 3 for " + request.name())); + } + + public void invalidReturnType(GetPromptRequest request) { + // Invalid return type + } + + public GetPromptResult duplicateContextParameters(McpTransportContext context1, McpTransportContext context2) { + return new GetPromptResult("Invalid", List.of()); + } + + public GetPromptResult duplicateRequestParameters(GetPromptRequest request1, GetPromptRequest request2) { + return new GetPromptResult("Invalid", List.of()); + } + + public GetPromptResult duplicateMapParameters(Map args1, Map args2) { + return new GetPromptResult("Invalid", List.of()); + } + + } + + private Prompt createTestPrompt(String name, String description) { + return new Prompt(name, description, List.of(new PromptArgument("name", "User's name", true), + new PromptArgument("age", "User's age", false))); + } + + @Test + public void testCallbackWithRequestParameter() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithRequest", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("greeting", "A simple greeting prompt"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("greeting", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Greeting prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello from greeting"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithContextAndRequestParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithContext", McpTransportContext.class, + GetPromptRequest.class); + + Prompt prompt = createTestPrompt("context-greeting", "A greeting prompt with context"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("context-greeting", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Greeting with context"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello with context from context-greeting"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithArgumentsMap() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithArguments", Map.class); + + Prompt prompt = createTestPrompt("arguments-greeting", "A greeting prompt with arguments"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("arguments-greeting", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Greeting with arguments"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello John from arguments"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithIndividualArguments() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithIndividualArgs", String.class, Integer.class); + + Prompt prompt = createTestPrompt("individual-args", "A prompt with individual arguments"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + args.put("age", 30); + GetPromptRequest request = new GetPromptRequest("individual-args", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Individual arguments prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello John, you are 30 years old"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMixedArguments() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithMixedArgs", McpTransportContext.class, + String.class, Integer.class); + + Prompt prompt = createTestPrompt("mixed-args", "A prompt with mixed argument types"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + args.put("age", 30); + GetPromptRequest request = new GetPromptRequest("mixed-args", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Mixed arguments prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()) + .isEqualTo("Hello John, you are 30 years old (with context)"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMessagesList() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptMessagesList", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("list-messages", "A prompt returning a list of messages"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("list-messages", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isNull(); + assertThat(result.messages()).hasSize(2); + PromptMessage message1 = result.messages().get(0); + PromptMessage message2 = result.messages().get(1); + assertThat(message1.role()).isEqualTo(Role.ASSISTANT); + assertThat(message2.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message1.content()).text()).isEqualTo("Message 1 for list-messages"); + assertThat(((TextContent) message2.content()).text()).isEqualTo("Message 2 for list-messages"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithStringReturn() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getStringPrompt", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("string-prompt", "A prompt returning a string"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("string-prompt", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Simple string response for string-prompt"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithSingleMessage() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getSingleMessage", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("single-message", "A prompt returning a single message"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("single-message", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isNull(); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Single message for single-message"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithStringList() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getStringList", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("string-list", "A prompt returning a list of strings"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("string-list", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isNull(); + assertThat(result.messages()).hasSize(3); + + PromptMessage message1 = result.messages().get(0); + PromptMessage message2 = result.messages().get(1); + PromptMessage message3 = result.messages().get(2); + + assertThat(message1.role()).isEqualTo(Role.ASSISTANT); + assertThat(message2.role()).isEqualTo(Role.ASSISTANT); + assertThat(message3.role()).isEqualTo(Role.ASSISTANT); + + assertThat(((TextContent) message1.content()).text()).isEqualTo("String 1 for string-list"); + assertThat(((TextContent) message2.content()).text()).isEqualTo("String 2 for string-list"); + assertThat(((TextContent) message3.content()).text()).isEqualTo("String 3 for string-list"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMonoPromptResult() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getMonoPrompt", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("mono-prompt", "A prompt returning a Mono"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("mono-prompt", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Mono prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Async response for mono-prompt"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMonoString() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getMonoString", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("mono-string", "A prompt returning a Mono"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("mono-string", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Async string response for mono-string"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMonoMessage() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getMonoMessage", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("mono-message", "A prompt returning a Mono"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("mono-message", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Async single message for mono-message"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMonoMessageList() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getMonoMessageList", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("mono-message-list", "A prompt returning a Mono>"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("mono-message-list", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(2); + PromptMessage message1 = result.messages().get(0); + PromptMessage message2 = result.messages().get(1); + assertThat(message1.role()).isEqualTo(Role.ASSISTANT); + assertThat(message2.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message1.content()).text()).isEqualTo("Async message 1 for mono-message-list"); + assertThat(((TextContent) message2.content()).text()).isEqualTo("Async message 2 for mono-message-list"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithMonoStringList() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getMonoStringList", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("mono-string-list", "A prompt returning a Mono>"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("mono-string-list", args); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(3); + PromptMessage message1 = result.messages().get(0); + PromptMessage message2 = result.messages().get(1); + PromptMessage message3 = result.messages().get(2); + assertThat(message1.role()).isEqualTo(Role.ASSISTANT); + assertThat(message2.role()).isEqualTo(Role.ASSISTANT); + assertThat(message3.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message1.content()).text()).isEqualTo("Async string 1 for mono-string-list"); + assertThat(((TextContent) message2.content()).text()).isEqualTo("Async string 2 for mono-string-list"); + assertThat(((TextContent) message3.content()).text()).isEqualTo("Async string 3 for mono-string-list"); + }).verifyComplete(); + } + + @Test + public void testInvalidReturnType() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("invalidReturnType", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid return type"); + + assertThatThrownBy(() -> AsyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method must return either GetPromptResult, List"); + } + + @Test + public void testDuplicateContextParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("duplicateContextParameters", McpTransportContext.class, + McpTransportContext.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid parameters"); + + assertThatThrownBy(() -> AsyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one exchange parameter"); + } + + @Test + public void testDuplicateRequestParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("duplicateRequestParameters", GetPromptRequest.class, + GetPromptRequest.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid parameters"); + + assertThatThrownBy(() -> AsyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one GetPromptRequest parameter"); + } + + @Test + public void testDuplicateMapParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("duplicateMapParameters", Map.class, Map.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid parameters"); + + assertThatThrownBy(() -> AsyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one Map parameter"); + } + + @Test + public void testNullRequest() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getMonoPrompt", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("mono-prompt", "A prompt returning a Mono"); + + BiFunction> callback = AsyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + StepVerifier.create(callback.apply(context, null)).expectErrorMessage("Request must not be null").verify(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallbackTests.java new file mode 100644 index 0000000..7a17c0c --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/prompt/SyncStatelessMcpPromptMethodCallbackTests.java @@ -0,0 +1,478 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.prompt; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpArg; +import org.springaicommunity.mcp.annotation.McpPrompt; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.TextContent; + +/** + * Tests for {@link SyncStatelessMcpPromptMethodCallback}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpPromptMethodCallbackTests { + + private static class TestPromptProvider { + + @McpPrompt(name = "greeting", description = "A simple greeting prompt") + public GetPromptResult getPromptWithRequest(GetPromptRequest request) { + return new GetPromptResult("Greeting prompt", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello from " + request.name())))); + } + + @McpPrompt(name = "context-greeting", description = "A greeting prompt with context") + public GetPromptResult getPromptWithContext(McpTransportContext context, GetPromptRequest request) { + return new GetPromptResult("Greeting with context", List + .of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello with context from " + request.name())))); + } + + @McpPrompt(name = "arguments-greeting", description = "A greeting prompt with arguments") + public GetPromptResult getPromptWithArguments(Map arguments) { + String name = arguments.containsKey("name") ? arguments.get("name").toString() : "unknown"; + return new GetPromptResult("Greeting with arguments", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello " + name + " from arguments")))); + } + + @McpPrompt(name = "individual-args", description = "A prompt with individual arguments") + public GetPromptResult getPromptWithIndividualArgs( + @McpArg(name = "name", description = "The user's name", required = true) String name, + @McpArg(name = "age", description = "The user's age", required = true) Integer age) { + return new GetPromptResult("Individual arguments prompt", List.of(new PromptMessage(Role.ASSISTANT, + new TextContent("Hello " + name + ", you are " + age + " years old")))); + } + + @McpPrompt(name = "mixed-args", description = "A prompt with mixed argument types") + public GetPromptResult getPromptWithMixedArgs(McpTransportContext context, + @McpArg(name = "name", description = "The user's name", required = true) String name, + @McpArg(name = "age", description = "The user's age", required = true) Integer age) { + return new GetPromptResult("Mixed arguments prompt", List.of(new PromptMessage(Role.ASSISTANT, + new TextContent("Hello " + name + ", you are " + age + " years old (with context)")))); + } + + @McpPrompt(name = "list-messages", description = "A prompt returning a list of messages") + public List getPromptMessagesList(GetPromptRequest request) { + return List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Message 1 for " + request.name())), + new PromptMessage(Role.ASSISTANT, new TextContent("Message 2 for " + request.name()))); + } + + @McpPrompt(name = "string-prompt", description = "A prompt returning a string") + public String getStringPrompt(GetPromptRequest request) { + return "Simple string response for " + request.name(); + } + + @McpPrompt(name = "single-message", description = "A prompt returning a single message") + public PromptMessage getSingleMessage(GetPromptRequest request) { + return new PromptMessage(Role.ASSISTANT, new TextContent("Single message for " + request.name())); + } + + @McpPrompt(name = "string-list", description = "A prompt returning a list of strings") + public List getStringList(GetPromptRequest request) { + return List.of("String 1 for " + request.name(), "String 2 for " + request.name(), + "String 3 for " + request.name()); + } + + public void invalidReturnType(GetPromptRequest request) { + // Invalid return type + } + + public GetPromptResult duplicateContextParameters(McpTransportContext context1, McpTransportContext context2) { + return new GetPromptResult("Invalid", List.of()); + } + + public GetPromptResult duplicateRequestParameters(GetPromptRequest request1, GetPromptRequest request2) { + return new GetPromptResult("Invalid", List.of()); + } + + public GetPromptResult duplicateMapParameters(Map args1, Map args2) { + return new GetPromptResult("Invalid", List.of()); + } + + } + + private Prompt createTestPrompt(String name, String description) { + return new Prompt(name, description, List.of(new PromptArgument("name", "User's name", true), + new PromptArgument("age", "User's age", false))); + } + + @Test + public void testCallbackWithRequestParameter() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithRequest", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("greeting", "A simple greeting prompt"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("greeting", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Greeting prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello from greeting"); + } + + @Test + public void testCallbackWithContextAndRequestParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithContext", McpTransportContext.class, + GetPromptRequest.class); + + Prompt prompt = createTestPrompt("context-greeting", "A greeting prompt with context"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("context-greeting", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Greeting with context"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello with context from context-greeting"); + } + + @Test + public void testCallbackWithArgumentsMap() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithArguments", Map.class); + + Prompt prompt = createTestPrompt("arguments-greeting", "A greeting prompt with arguments"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("arguments-greeting", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Greeting with arguments"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello John from arguments"); + } + + @Test + public void testCallbackWithIndividualArguments() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithIndividualArgs", String.class, Integer.class); + + Prompt prompt = createTestPrompt("individual-args", "A prompt with individual arguments"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + args.put("age", 30); + GetPromptRequest request = new GetPromptRequest("individual-args", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Individual arguments prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello John, you are 30 years old"); + } + + @Test + public void testCallbackWithMixedArguments() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithMixedArgs", McpTransportContext.class, + String.class, Integer.class); + + Prompt prompt = createTestPrompt("mixed-args", "A prompt with mixed argument types"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + args.put("age", 30); + GetPromptRequest request = new GetPromptRequest("mixed-args", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Mixed arguments prompt"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()) + .isEqualTo("Hello John, you are 30 years old (with context)"); + } + + @Test + public void testCallbackWithMessagesList() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptMessagesList", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("list-messages", "A prompt returning a list of messages"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("list-messages", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isNull(); + assertThat(result.messages()).hasSize(2); + PromptMessage message1 = result.messages().get(0); + PromptMessage message2 = result.messages().get(1); + assertThat(message1.role()).isEqualTo(Role.ASSISTANT); + assertThat(message2.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message1.content()).text()).isEqualTo("Message 1 for list-messages"); + assertThat(((TextContent) message2.content()).text()).isEqualTo("Message 2 for list-messages"); + } + + @Test + public void testCallbackWithStringReturn() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getStringPrompt", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("string-prompt", "A prompt returning a string"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("string-prompt", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Simple string response for string-prompt"); + } + + @Test + public void testCallbackWithSingleMessage() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getSingleMessage", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("single-message", "A prompt returning a single message"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("single-message", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isNull(); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Single message for single-message"); + } + + @Test + public void testCallbackWithStringList() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getStringList", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("string-list", "A prompt returning a list of strings"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("string-list", args); + + GetPromptResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isNull(); + assertThat(result.messages()).hasSize(3); + + PromptMessage message1 = result.messages().get(0); + PromptMessage message2 = result.messages().get(1); + PromptMessage message3 = result.messages().get(2); + + assertThat(message1.role()).isEqualTo(Role.ASSISTANT); + assertThat(message2.role()).isEqualTo(Role.ASSISTANT); + assertThat(message3.role()).isEqualTo(Role.ASSISTANT); + + assertThat(((TextContent) message1.content()).text()).isEqualTo("String 1 for string-list"); + assertThat(((TextContent) message2.content()).text()).isEqualTo("String 2 for string-list"); + assertThat(((TextContent) message3.content()).text()).isEqualTo("String 3 for string-list"); + } + + @Test + public void testInvalidReturnType() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("invalidReturnType", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid return type"); + + assertThatThrownBy(() -> SyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method must return either GetPromptResult, List"); + } + + @Test + public void testDuplicateContextParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("duplicateContextParameters", McpTransportContext.class, + McpTransportContext.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid parameters"); + + assertThatThrownBy(() -> SyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one exchange parameter"); + } + + @Test + public void testDuplicateRequestParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("duplicateRequestParameters", GetPromptRequest.class, + GetPromptRequest.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid parameters"); + + assertThatThrownBy(() -> SyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one GetPromptRequest parameter"); + } + + @Test + public void testDuplicateMapParameters() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("duplicateMapParameters", Map.class, Map.class); + + Prompt prompt = createTestPrompt("invalid", "Invalid parameters"); + + assertThatThrownBy(() -> SyncStatelessMcpPromptMethodCallback.builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method cannot have more than one Map parameter"); + } + + @Test + public void testNullRequest() throws Exception { + TestPromptProvider provider = new TestPromptProvider(); + Method method = TestPromptProvider.class.getMethod("getPromptWithRequest", GetPromptRequest.class); + + Prompt prompt = createTestPrompt("greeting", "A simple greeting prompt"); + + BiFunction callback = SyncStatelessMcpPromptMethodCallback + .builder() + .method(method) + .bean(provider) + .prompt(prompt) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + assertThatThrownBy(() -> callback.apply(context, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Request must not be null"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java index 1af7008..4d46a77 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncMcpResourceMethodCallbackTests.java @@ -20,7 +20,6 @@ import org.junit.jupiter.api.Test; import org.springaicommunity.mcp.annotation.McpResource; import org.springaicommunity.mcp.annotation.ResourceAdaptor; -import org.springaicommunity.mcp.method.resource.AsyncMcpResourceMethodCallback; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallbackTests.java new file mode 100644 index 0000000..97edddd --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/AsyncStatelessMcpResourceMethodCallbackTests.java @@ -0,0 +1,646 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.resource; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.util.McpUriTemplateManager; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpResource; +import org.springaicommunity.mcp.annotation.ResourceAdaptor; + +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link AsyncStatelessMcpResourceMethodCallback}. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpResourceMethodCallbackTests { + + private static class TestAsyncStatelessResourceProvider { + + // Regular return types (will be wrapped in Mono by the callback) + public ReadResourceResult getResourceWithRequest(ReadResourceRequest request) { + return new ReadResourceResult( + List.of(new TextResourceContents(request.uri(), "text/plain", "Content for " + request.uri()))); + } + + public ReadResourceResult getResourceWithContext(McpTransportContext context, ReadResourceRequest request) { + return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", + "Content with context for " + request.uri()))); + } + + public ReadResourceResult getResourceWithUri(String uri) { + return new ReadResourceResult( + List.of(new TextResourceContents(uri, "text/plain", "Content from URI: " + uri))); + } + + @McpResource(uri = "users/{userId}/posts/{postId}") + public ReadResourceResult getResourceWithUriVariables(String userId, String postId) { + return new ReadResourceResult(List.of(new TextResourceContents("users/" + userId + "/posts/" + postId, + "text/plain", "User: " + userId + ", Post: " + postId))); + } + + @McpResource(uri = "users/{userId}/profile") + public ReadResourceResult getResourceWithContextAndUriVariable(McpTransportContext context, String userId) { + return new ReadResourceResult(List.of(new TextResourceContents("users/" + userId + "/profile", "text/plain", + "Profile for user: " + userId))); + } + + // Mono return types + public Mono getResourceWithRequestAsync(ReadResourceRequest request) { + return Mono.just(new ReadResourceResult(List + .of(new TextResourceContents(request.uri(), "text/plain", "Async content for " + request.uri())))); + } + + public Mono getResourceWithContextAsync(McpTransportContext context, + ReadResourceRequest request) { + return Mono.just(new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", + "Async content with context for " + request.uri())))); + } + + @McpResource(uri = "async/users/{userId}/posts/{postId}") + public Mono getResourceWithUriVariablesAsync(String userId, String postId) { + return Mono.just(new ReadResourceResult( + List.of(new TextResourceContents("async/users/" + userId + "/posts/" + postId, "text/plain", + "Async User: " + userId + ", Post: " + postId)))); + } + + public Mono> getResourceContentsListAsync(ReadResourceRequest request) { + return Mono.just(List + .of(new TextResourceContents(request.uri(), "text/plain", "Async content list for " + request.uri()))); + } + + public Mono getSingleStringAsync(ReadResourceRequest request) { + return Mono.just("Async single string for " + request.uri()); + } + + @McpResource(uri = "text-content://async-resource", mimeType = "text/plain") + public Mono getStringWithTextContentTypeAsync(ReadResourceRequest request) { + return Mono.just("Async text content type for " + request.uri()); + } + + @McpResource(uri = "blob-content://async-resource", mimeType = "application/octet-stream") + public Mono getStringWithBlobContentTypeAsync(ReadResourceRequest request) { + return Mono.just("Async blob content type for " + request.uri()); + } + + public void invalidReturnType(ReadResourceRequest request) { + // Invalid return type + } + + public Mono invalidMonoReturnType(ReadResourceRequest request) { + return Mono.empty(); + } + + public Mono invalidParameters(int value) { + return Mono.just(new ReadResourceResult(List.of())); + } + + public Mono tooManyParameters(McpTransportContext context, ReadResourceRequest request, + String extraParam) { + return Mono.just(new ReadResourceResult(List.of())); + } + + public Mono invalidParameterType(Object invalidParam) { + return Mono.just(new ReadResourceResult(List.of())); + } + + public Mono duplicateContextParameters(McpTransportContext context1, + McpTransportContext context2) { + return Mono.just(new ReadResourceResult(List.of())); + } + + public Mono duplicateRequestParameters(ReadResourceRequest request1, + ReadResourceRequest request2) { + return Mono.just(new ReadResourceResult(List.of())); + } + + } + + // Helper method to create a mock McpResource annotation + private McpResource createMockMcpResource() { + return new McpResource() { + @Override + public Class annotationType() { + return McpResource.class; + } + + @Override + public String uri() { + return "test://resource"; + } + + @Override + public String name() { + return ""; + } + + @Override + public String description() { + return ""; + } + + @Override + public String mimeType() { + return "text/plain"; + } + + }; + } + + @Test + public void testCallbackWithRequestParameter() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithRequest", + ReadResourceRequest.class); + + // Provide a mock McpResource annotation since the method doesn't have one + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithContextAndRequestParameters() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithContext", + McpTransportContext.class, ReadResourceRequest.class); + + // Use the builder to provide a mock McpResource annotation + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with context for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithUriVariables() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithUriVariables", String.class, + String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("users/123/posts/456"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("User: 123, Post: 456"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithRequestParameterAsync() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithRequestAsync", + ReadResourceRequest.class); + + // Provide a mock McpResource annotation since the method doesn't have one + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Async content for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithContextAndRequestParametersAsync() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithContextAsync", + McpTransportContext.class, ReadResourceRequest.class); + + // Use the builder to provide a mock McpResource annotation + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Async content with context for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithUriVariablesAsync() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithUriVariablesAsync", + String.class, String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("async/users/123/posts/456"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Async User: 123, Post: 456"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithStringAsync() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getSingleStringAsync", + ReadResourceRequest.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Async single string for test/resource"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithTextContentTypeAsync() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getStringWithTextContentTypeAsync", + ReadResourceRequest.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Async text content type for test/resource"); + assertThat(textContent.mimeType()).isEqualTo("text/plain"); + }).verifyComplete(); + } + + @Test + public void testCallbackWithBlobContentTypeAsync() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getStringWithBlobContentTypeAsync", + ReadResourceRequest.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + Mono resultMono = callback.apply(context, request); + + StepVerifier.create(resultMono).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(BlobResourceContents.class); + BlobResourceContents blobContent = (BlobResourceContents) result.contents().get(0); + assertThat(blobContent.blob()).isEqualTo("Async blob content type for test/resource"); + assertThat(blobContent.mimeType()).isEqualTo("application/octet-stream"); + }).verifyComplete(); + } + + @Test + public void testInvalidReturnType() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("invalidReturnType", + ReadResourceRequest.class); + + assertThatThrownBy( + () -> AsyncStatelessMcpResourceMethodCallback.builder().method(method).bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("URI must not be null or empty"); + } + + @Test + public void testInvalidMonoReturnType() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("invalidMonoReturnType", + ReadResourceRequest.class); + + assertThatThrownBy( + () -> AsyncStatelessMcpResourceMethodCallback.builder().method(method).bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("URI must not be null or empty"); + } + + @Test + public void testInvalidUriVariableParameters() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithUriVariables", String.class, + String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + // Create a mock annotation with a different URI template that has more + // variables + // than the method has parameters + McpResource mockResourceAnnotation = new McpResource() { + @Override + public Class annotationType() { + return McpResource.class; + } + + @Override + public String uri() { + return "users/{userId}/posts/{postId}/comments/{commentId}"; + } + + @Override + public String name() { + return ""; + } + + @Override + public String description() { + return ""; + } + + @Override + public String mimeType() { + return ""; + } + + }; + + assertThatThrownBy(() -> AsyncStatelessMcpResourceMethodCallback.builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(mockResourceAnnotation)) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method must have parameters for all URI variables"); + } + + @Test + public void testNullRequest() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithRequest", + ReadResourceRequest.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + Mono resultMono = callback.apply(context, null); + + StepVerifier.create(resultMono) + .expectErrorMatches(throwable -> throwable instanceof IllegalArgumentException + && throwable.getMessage().contains("Request must not be null")) + .verify(); + } + + @Test + public void testMethodInvocationError() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithRequest", + ReadResourceRequest.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + // Create a request with a URI that will cause the URI template extraction to + // fail + ReadResourceRequest request = new ReadResourceRequest("invalid:uri"); + + // Mock the URI template manager to throw an exception when extracting variables + McpUriTemplateManager mockUriTemplateManager = new McpUriTemplateManager() { + @Override + public List getVariableNames() { + return List.of(); + } + + @Override + public Map extractVariableValues(String uri) { + throw new RuntimeException("Simulated extraction error"); + } + + @Override + public boolean matches(String uri) { + return false; + } + + @Override + public boolean isUriTemplate(String uri) { + return uri != null && uri.contains("{"); + } + }; + + BiFunction> callbackWithMockTemplate = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .uriTemplateManagerFactory(new McpUriTemplateManagerFactory() { + public McpUriTemplateManager create(String uriTemplate) { + return mockUriTemplateManager; + }; + }) + .build(); + + Mono resultMono = callbackWithMockTemplate.apply(context, request); + + StepVerifier.create(resultMono) + .expectErrorMatches( + throwable -> throwable instanceof AsyncStatelessMcpResourceMethodCallback.McpResourceMethodException + && throwable.getMessage().contains("Error invoking resource method")) + .verify(); + } + + @Test + public void testIsExchangeOrContextType() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithRequest", + ReadResourceRequest.class); + AsyncStatelessMcpResourceMethodCallback callback = AsyncStatelessMcpResourceMethodCallback.builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + // Test that McpTransportContext is recognized as context type + // Note: We need to use reflection to access the protected method for testing + java.lang.reflect.Method isContextTypeMethod = AsyncStatelessMcpResourceMethodCallback.class + .getDeclaredMethod("isExchangeOrContextType", Class.class); + isContextTypeMethod.setAccessible(true); + + assertThat((Boolean) isContextTypeMethod.invoke(callback, McpTransportContext.class)).isTrue(); + + // Test that other types are not recognized as context type + assertThat((Boolean) isContextTypeMethod.invoke(callback, String.class)).isFalse(); + assertThat((Boolean) isContextTypeMethod.invoke(callback, Integer.class)).isFalse(); + assertThat((Boolean) isContextTypeMethod.invoke(callback, Object.class)).isFalse(); + } + + @Test + public void testBuilderValidation() { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + + // Test null method + assertThatThrownBy(() -> AsyncStatelessMcpResourceMethodCallback.builder().bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Method must not be null"); + + // Test null bean + assertThatThrownBy(() -> AsyncStatelessMcpResourceMethodCallback.builder() + .method(TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithRequest", + ReadResourceRequest.class)) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Bean must not be null"); + } + + @Test + public void testUriVariableExtraction() throws Exception { + TestAsyncStatelessResourceProvider provider = new TestAsyncStatelessResourceProvider(); + Method method = TestAsyncStatelessResourceProvider.class.getMethod("getResourceWithUriVariables", String.class, + String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction> callback = AsyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + // Test with mismatched URI that doesn't contain expected variables + ReadResourceRequest invalidRequest = new ReadResourceRequest("invalid/uri/format"); + + Mono resultMono = callback.apply(context, invalidRequest); + + StepVerifier.create(resultMono) + .expectErrorMatches(throwable -> throwable instanceof IllegalArgumentException + && throwable.getMessage().contains("Failed to extract all URI variables from request URI")) + .verify(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallbackTests.java new file mode 100644 index 0000000..6d48763 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/resource/SyncStatelessMcpResourceMethodCallbackTests.java @@ -0,0 +1,640 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.resource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.BiFunction; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpResource; +import org.springaicommunity.mcp.annotation.ResourceAdaptor; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.BlobResourceContents; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; + +/** + * Tests for {@link SyncStatelessMcpResourceMethodCallback}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpResourceMethodCallbackTests { + + private static class TestResourceProvider { + + public ReadResourceResult getResourceWithRequest(ReadResourceRequest request) { + return new ReadResourceResult( + List.of(new TextResourceContents(request.uri(), "text/plain", "Content for " + request.uri()))); + } + + public ReadResourceResult getResourceWithContext(McpTransportContext context, ReadResourceRequest request) { + return new ReadResourceResult(List.of(new TextResourceContents(request.uri(), "text/plain", + "Content with context for " + request.uri()))); + } + + public ReadResourceResult getResourceWithUri(String uri) { + return new ReadResourceResult( + List.of(new TextResourceContents(uri, "text/plain", "Content from URI: " + uri))); + } + + @McpResource(uri = "users/{userId}/posts/{postId}") + public ReadResourceResult getResourceWithUriVariables(String userId, String postId) { + return new ReadResourceResult(List.of(new TextResourceContents("users/" + userId + "/posts/" + postId, + "text/plain", "User: " + userId + ", Post: " + postId))); + } + + @McpResource(uri = "users/{userId}/profile") + public ReadResourceResult getResourceWithContextAndUriVariable(McpTransportContext context, String userId) { + return new ReadResourceResult(List.of(new TextResourceContents("users/" + userId + "/profile", "text/plain", + "Profile for user: " + userId))); + } + + public List getResourceContentsList(ReadResourceRequest request) { + return List.of(new TextResourceContents(request.uri(), "text/plain", "Content list for " + request.uri())); + } + + public List getStringList(ReadResourceRequest request) { + return List.of("String 1 for " + request.uri(), "String 2 for " + request.uri()); + } + + public ResourceContents getSingleResourceContents(ReadResourceRequest request) { + return new TextResourceContents(request.uri(), "text/plain", + "Single resource content for " + request.uri()); + } + + public String getSingleString(ReadResourceRequest request) { + return "Single string for " + request.uri(); + } + + @McpResource(uri = "text-content://resource", mimeType = "text/plain") + public String getStringWithTextContentType(ReadResourceRequest request) { + return "Text content type for " + request.uri(); + } + + @McpResource(uri = "blob-content://resource", mimeType = "application/octet-stream") + public String getStringWithBlobContentType(ReadResourceRequest request) { + return "Blob content type for " + request.uri(); + } + + @McpResource(uri = "text-list://resource", mimeType = "text/html") + public List getStringListWithTextContentType(ReadResourceRequest request) { + return List.of("HTML text 1 for " + request.uri(), "HTML text 2 for " + request.uri()); + } + + @McpResource(uri = "blob-list://resource", mimeType = "image/png") + public List getStringListWithBlobContentType(ReadResourceRequest request) { + return List.of("PNG blob 1 for " + request.uri(), "PNG blob 2 for " + request.uri()); + } + + public void invalidReturnType(ReadResourceRequest request) { + // Invalid return type + } + + public ReadResourceResult invalidParameters(int value) { + return new ReadResourceResult(List.of()); + } + + public ReadResourceResult tooManyParameters(McpTransportContext context, ReadResourceRequest request, + String extraParam) { + return new ReadResourceResult(List.of()); + } + + public ReadResourceResult invalidParameterType(Object invalidParam) { + return new ReadResourceResult(List.of()); + } + + public ReadResourceResult duplicateContextParameters(McpTransportContext context1, + McpTransportContext context2) { + return new ReadResourceResult(List.of()); + } + + public ReadResourceResult duplicateRequestParameters(ReadResourceRequest request1, + ReadResourceRequest request2) { + return new ReadResourceResult(List.of()); + } + + } + + // Helper method to create a mock McpResource annotation + private McpResource createMockMcpResource() { + return new McpResource() { + @Override + public Class annotationType() { + return McpResource.class; + } + + @Override + public String uri() { + return "test://resource"; + } + + @Override + public String name() { + return "testResource"; + } + + @Override + public String description() { + return "Test resource description"; + } + + @Override + public String mimeType() { + return "text/plain"; + } + }; + } + + @Test + public void testCallbackWithRequestParameter() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithRequest", ReadResourceRequest.class); + + // Provide a mock McpResource annotation since the method doesn't have one + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content for test/resource"); + } + + @Test + public void testCallbackWithContextAndRequestParameters() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithContext", McpTransportContext.class, + ReadResourceRequest.class); + + // Use the builder to provide a mock McpResource annotation + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content with context for test/resource"); + } + + @Test + public void testCallbackWithUriParameter() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithUri", String.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content from URI: test/resource"); + } + + @Test + public void testCallbackWithUriVariables() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithUriVariables", String.class, String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("users/123/posts/456"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("User: 123, Post: 456"); + } + + @Test + public void testCallbackWithContextAndUriVariable() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithContextAndUriVariable", + McpTransportContext.class, String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("users/789/profile"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Profile for user: 789"); + } + + @Test + public void testCallbackWithResourceContentsList() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceContentsList", ReadResourceRequest.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Content list for test/resource"); + } + + @Test + public void testCallbackWithStringList() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getStringList", ReadResourceRequest.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(2); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent1 = (TextResourceContents) result.contents().get(0); + TextResourceContents textContent2 = (TextResourceContents) result.contents().get(1); + assertThat(textContent1.text()).isEqualTo("String 1 for test/resource"); + assertThat(textContent2.text()).isEqualTo("String 2 for test/resource"); + } + + @Test + public void testCallbackWithSingleResourceContents() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getSingleResourceContents", ReadResourceRequest.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Single resource content for test/resource"); + } + + @Test + public void testCallbackWithSingleString() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getSingleString", ReadResourceRequest.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Single string for test/resource"); + } + + @Test + public void testCallbackWithStringAndTextContentType() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getStringWithTextContentType", ReadResourceRequest.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent = (TextResourceContents) result.contents().get(0); + assertThat(textContent.text()).isEqualTo("Text content type for test/resource"); + assertThat(textContent.mimeType()).isEqualTo("text/plain"); + } + + @Test + public void testCallbackWithStringAndBlobContentType() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getStringWithBlobContentType", ReadResourceRequest.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + assertThat(result.contents().get(0)).isInstanceOf(BlobResourceContents.class); + BlobResourceContents blobContent = (BlobResourceContents) result.contents().get(0); + assertThat(blobContent.blob()).isEqualTo("Blob content type for test/resource"); + assertThat(blobContent.mimeType()).isEqualTo("application/octet-stream"); + } + + @Test + public void testCallbackWithStringListAndTextContentType() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getStringListWithTextContentType", + ReadResourceRequest.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(2); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + TextResourceContents textContent1 = (TextResourceContents) result.contents().get(0); + TextResourceContents textContent2 = (TextResourceContents) result.contents().get(1); + assertThat(textContent1.text()).isEqualTo("HTML text 1 for test/resource"); + assertThat(textContent2.text()).isEqualTo("HTML text 2 for test/resource"); + assertThat(textContent1.mimeType()).isEqualTo("text/html"); + assertThat(textContent2.mimeType()).isEqualTo("text/html"); + } + + @Test + public void testCallbackWithStringListAndBlobContentType() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getStringListWithBlobContentType", + ReadResourceRequest.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test/resource"); + + ReadResourceResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(2); + assertThat(result.contents().get(0)).isInstanceOf(BlobResourceContents.class); + BlobResourceContents blobContent1 = (BlobResourceContents) result.contents().get(0); + BlobResourceContents blobContent2 = (BlobResourceContents) result.contents().get(1); + assertThat(blobContent1.blob()).isEqualTo("PNG blob 1 for test/resource"); + assertThat(blobContent2.blob()).isEqualTo("PNG blob 2 for test/resource"); + assertThat(blobContent1.mimeType()).isEqualTo("image/png"); + assertThat(blobContent2.mimeType()).isEqualTo("image/png"); + } + + @Test + public void testInvalidReturnType() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("invalidReturnType", ReadResourceRequest.class); + + assertThatThrownBy(() -> SyncStatelessMcpResourceMethodCallback.builder().method(method).bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("URI must not be null or empty"); + } + + @Test + public void testInvalidUriVariableParameters() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithUriVariables", String.class, String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + // Create a mock annotation with a different URI template that has more variables + // than the method has parameters + McpResource mockResourceAnnotation = new McpResource() { + @Override + public Class annotationType() { + return McpResource.class; + } + + @Override + public String uri() { + return "users/{userId}/posts/{postId}/comments/{commentId}"; + } + + @Override + public String name() { + return "testResourceWithExtraVariables"; + } + + @Override + public String description() { + return "Test resource with extra URI variables"; + } + + @Override + public String mimeType() { + return "text/plain"; + } + }; + + assertThatThrownBy(() -> SyncStatelessMcpResourceMethodCallback.builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(mockResourceAnnotation)) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method must have parameters for all URI variables"); + } + + @Test + public void testNullRequest() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getSingleString", ReadResourceRequest.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + assertThatThrownBy(() -> callback.apply(context, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request must not be null"); + } + + @Test + public void testIsExchangeOrContextType() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getSingleString", ReadResourceRequest.class); + SyncStatelessMcpResourceMethodCallback callback = SyncStatelessMcpResourceMethodCallback.builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(createMockMcpResource())) + .build(); + + // Test that McpTransportContext is recognized as exchange type + // Note: We need to use reflection to access the protected method for testing + java.lang.reflect.Method isExchangeOrContextTypeMethod = SyncStatelessMcpResourceMethodCallback.class + .getDeclaredMethod("isExchangeOrContextType", Class.class); + isExchangeOrContextTypeMethod.setAccessible(true); + + assertThat((Boolean) isExchangeOrContextTypeMethod.invoke(callback, McpTransportContext.class)).isTrue(); + + // Test that other types are not recognized as exchange type + assertThat((Boolean) isExchangeOrContextTypeMethod.invoke(callback, String.class)).isFalse(); + assertThat((Boolean) isExchangeOrContextTypeMethod.invoke(callback, Integer.class)).isFalse(); + assertThat((Boolean) isExchangeOrContextTypeMethod.invoke(callback, Object.class)).isFalse(); + } + + @Test + public void testMethodWithoutMcpResourceAnnotation() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + // Use a method that doesn't have the McpResource annotation + Method method = TestResourceProvider.class.getMethod("getResourceWithRequest", ReadResourceRequest.class); + + // Create a callback without explicitly providing the annotation + // This should now throw an exception since the method doesn't have the annotation + assertThatThrownBy(() -> SyncStatelessMcpResourceMethodCallback.builder().method(method).bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("URI must not be null or empty"); + } + + @Test + public void testBuilderValidation() { + TestResourceProvider provider = new TestResourceProvider(); + + // Test null method + assertThatThrownBy(() -> SyncStatelessMcpResourceMethodCallback.builder().bean(provider).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Method must not be null"); + + // Test null bean + assertThatThrownBy(() -> SyncStatelessMcpResourceMethodCallback.builder() + .method(TestResourceProvider.class.getMethod("getSingleString", ReadResourceRequest.class)) + .build()).isInstanceOf(IllegalArgumentException.class).hasMessage("Bean must not be null"); + } + + @Test + public void testUriVariableExtraction() throws Exception { + TestResourceProvider provider = new TestResourceProvider(); + Method method = TestResourceProvider.class.getMethod("getResourceWithUriVariables", String.class, String.class); + McpResource resourceAnnotation = method.getAnnotation(McpResource.class); + + BiFunction callback = SyncStatelessMcpResourceMethodCallback + .builder() + .method(method) + .bean(provider) + .resource(ResourceAdaptor.asResource(resourceAnnotation)) + .build(); + + McpTransportContext context = mock(McpTransportContext.class); + + // Test with mismatched URI that doesn't contain expected variables + ReadResourceRequest invalidRequest = new ReadResourceRequest("invalid/uri/format"); + + assertThatThrownBy(() -> callback.apply(context, invalidRequest)) + .isInstanceOf(AbstractMcpResourceMethodCallback.McpResourceMethodException.class) + .hasMessageContaining("Access error invoking resource method"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallbackTests.java new file mode 100644 index 0000000..c1a6c14 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallbackTests.java @@ -0,0 +1,770 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link AsyncStatelessMcpToolMethodCallback}. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpToolMethodCallbackTests { + + private static class TestAsyncStatelessToolProvider { + + @McpTool(name = "simple-mono-tool", description = "A simple mono tool") + public Mono simpleMonoTool(String input) { + return Mono.just("Processed: " + input); + } + + @McpTool(name = "simple-flux-tool", description = "A simple flux tool") + public Flux simpleFluxTool(String input) { + return Flux.just("Processed: " + input); + } + + @McpTool(name = "simple-publisher-tool", description = "A simple publisher tool") + public Publisher simplePublisherTool(String input) { + return Mono.just("Processed: " + input); + } + + @McpTool(name = "math-mono-tool", description = "A math mono tool") + public Mono addNumbersMono(int a, int b) { + return Mono.just(a + b); + } + + @McpTool(name = "complex-mono-tool", description = "A complex mono tool") + public Mono complexMonoTool(String name, int age, boolean active) { + return Mono.just(CallToolResult.builder() + .addTextContent("Name: " + name + ", Age: " + age + ", Active: " + active) + .build()); + } + + @McpTool(name = "complex-flux-tool", description = "A complex flux tool") + public Flux complexFluxTool(String name, int age, boolean active) { + return Flux.just(CallToolResult.builder() + .addTextContent("Name: " + name + ", Age: " + age + ", Active: " + active) + .build()); + } + + @McpTool(name = "context-mono-tool", description = "Mono tool with context parameter") + public Mono monoToolWithContext(McpTransportContext context, String message) { + return Mono.just("Context tool: " + message); + } + + @McpTool(name = "list-mono-tool", description = "Mono tool with list parameter") + public Mono processListMono(List items) { + return Mono.just("Items: " + String.join(", ", items)); + } + + @McpTool(name = "object-mono-tool", description = "Mono tool with object parameter") + public Mono processObjectMono(TestObject obj) { + return Mono.just("Object: " + obj.name + " - " + obj.value); + } + + @McpTool(name = "optional-params-mono-tool", description = "Mono tool with optional parameters") + public Mono monoToolWithOptionalParams(@McpToolParam(required = true) String required, + @McpToolParam(required = false) String optional) { + return Mono.just("Required: " + required + ", Optional: " + (optional != null ? optional : "null")); + } + + @McpTool(name = "no-params-mono-tool", description = "Mono tool with no parameters") + public Mono noParamsMonoTool() { + return Mono.just("No parameters needed"); + } + + @McpTool(name = "exception-mono-tool", description = "Mono tool that throws exception") + public Mono exceptionMonoTool(String input) { + return Mono.error(new RuntimeException("Tool execution failed: " + input)); + } + + @McpTool(name = "null-return-mono-tool", description = "Mono tool that returns null") + public Mono nullReturnMonoTool() { + return Mono.just((String) null); + } + + @McpTool(name = "void-mono-tool", description = "Mono tool") + public Mono voidMonoTool(String input) { + return Mono.empty(); + } + + @McpTool(name = "void-flux-tool", description = "Flux tool") + public Flux voidFluxTool(String input) { + return Flux.empty(); + } + + @McpTool(name = "enum-mono-tool", description = "Mono tool with enum parameter") + public Mono enumMonoTool(TestEnum enumValue) { + return Mono.just("Enum: " + enumValue.name()); + } + + @McpTool(name = "primitive-types-mono-tool", description = "Mono tool with primitive types") + public Mono primitiveTypesMonoTool(boolean flag, byte b, short s, int i, long l, float f, double d) { + return Mono.just(String.format("Primitives: %b, %d, %d, %d, %d, %.1f, %.1f", flag, b, s, i, l, f, d)); + } + + @McpTool(name = "return-object-mono-tool", description = "Mono tool that returns a complex object") + public Mono returnObjectMonoTool(String name, int value) { + return Mono.just(new TestObject(name, value)); + } + + @McpTool(name = "delayed-mono-tool", description = "Mono tool with delay") + public Mono delayedMonoTool(String input) { + return Mono.just("Delayed: " + input); + } + + @McpTool(name = "empty-mono-tool", description = "Mono tool that returns empty") + public Mono emptyMonoTool() { + return Mono.empty(); + } + + @McpTool(name = "multiple-flux-tool", description = "Flux tool that emits multiple values") + public Flux multipleFluxTool(String prefix) { + return Flux.just(prefix + "1", prefix + "2", prefix + "3"); + } + + @McpTool(name = "private-mono-tool", description = "Private mono tool") + private Mono privateMonoTool(String input) { + return Mono.just("Private: " + input); + } + + // Non-reactive method that should cause error in async context + @McpTool(name = "non-reactive-tool", description = "Non-reactive tool") + public String nonReactiveTool(String input) { + return "Non-reactive: " + input; + } + + } + + public static class TestObject { + + public String name; + + public int value; + + public TestObject() { + } + + public TestObject(String name, int value) { + this.name = name; + this.value = value; + } + + } + + public enum TestEnum { + + OPTION_A, OPTION_B, OPTION_C + + } + + @Test + public void testSimpleMonoToolCallback() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("simple-mono-tool", Map.of("input", "test message")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + }).verifyComplete(); + } + + @Test + public void testSimpleFluxToolCallback() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("simpleFluxTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("simple-flux-tool", Map.of("input", "test message")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + }).verifyComplete(); + } + + @Test + public void testSimplePublisherToolCallback() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("simplePublisherTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("simple-publisher-tool", Map.of("input", "test message")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + }).verifyComplete(); + } + + @Test + public void testMathMonoToolCallback() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("addNumbersMono", int.class, int.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("math-mono-tool", Map.of("a", 5, "b", 3)); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("8"); + }).verifyComplete(); + } + + @Test + public void testMonoToolThatThrowsException() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("exceptionMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("exception-mono-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + }).verifyComplete(); + } + + @Test + public void testComplexFluxToolCallback() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("complexFluxTool", String.class, int.class, + boolean.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("complex-flux-tool", + Map.of("name", "Alice", "age", 25, "active", false)); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: Alice, Age: 25, Active: false"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithContextParameter() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("monoToolWithContext", McpTransportContext.class, + String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("context-mono-tool", Map.of("message", "hello")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Context tool: hello"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithListParameter() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("processListMono", List.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("list-mono-tool", + Map.of("items", List.of("item1", "item2", "item3"))); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Items: item1, item2, item3"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithObjectParameter() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("processObjectMono", TestObject.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("object-mono-tool", + Map.of("obj", Map.of("name", "test", "value", 42))); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Object: test - 42"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithNoParameters() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("noParamsMonoTool"); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("no-params-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("No parameters needed"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithEnumParameter() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("enumMonoTool", TestEnum.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("enum-mono-tool", Map.of("enumValue", "OPTION_B")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Enum: OPTION_B"); + }).verifyComplete(); + } + + @Test + public void testComplexMonoToolCallback() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("complexMonoTool", String.class, int.class, + boolean.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("complex-mono-tool", + Map.of("name", "John", "age", 30, "active", true)); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: John, Age: 30, Active: true"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithMissingParameters() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("simple-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithPrimitiveTypes() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("primitiveTypesMonoTool", boolean.class, + byte.class, short.class, int.class, long.class, float.class, double.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("primitive-types-mono-tool", + Map.of("flag", true, "b", 1, "s", 2, "i", 3, "l", 4L, "f", 5.5f, "d", 6.6)); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Primitives: true, 1, 2, 3, 4, 5.5, 6.6"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithNullParameters() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new java.util.HashMap<>(); + args.put("input", null); + CallToolRequest request = new CallToolRequest("simple-mono-tool", args); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + }).verifyComplete(); + } + + @Test + public void testMonoToolThatReturnsNull() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("nullReturnMonoTool"); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("null-return-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Error invoking method: Error invoking method: nullReturnMonoTool"); + }).verifyComplete(); + } + + @Test + public void testVoidMonoTool() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("voidMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.VOID, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("void-mono-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("\"Done\""); + }).verifyComplete(); + } + + @Test + public void testVoidFluxTool() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("voidFluxTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.VOID, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("void-flux-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("\"Done\""); + }).verifyComplete(); + } + + @Test + public void testPrivateMonoToolMethod() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getDeclaredMethod("privateMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("private-mono-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Private: test"); + }).verifyComplete(); + } + + @Test + public void testNullRequest() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + + StepVerifier.create(callback.apply(context, null)).expectError(IllegalArgumentException.class).verify(); + } + + @Test + public void testMonoToolReturningComplexObject() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("returnObjectMonoTool", String.class, int.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.STRUCTURED, + method, provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("return-object-mono-tool", Map.of("name", "test", "value", 42)); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).isEmpty(); + assertThat(result.structuredContent()).isNotNull(); + assertThat(result.structuredContent()).containsEntry("name", "test"); + assertThat(result.structuredContent()).containsEntry("value", 42); + }).verifyComplete(); + } + + @Test + public void testEmptyMonoTool() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("emptyMonoTool"); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("empty-mono-tool", Map.of()); + + StepVerifier.create(callback.apply(context, request)).verifyComplete(); + } + + @Test + public void testMultipleFluxTool() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("multipleFluxTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("multiple-flux-tool", Map.of("prefix", "item")); + + // Flux tools should take the first element + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("item1"); + }).verifyComplete(); + } + + @Test + public void testNonReactiveToolShouldFail() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("nonReactiveTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("non-reactive-tool", Map.of("input", "test")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithInvalidJsonConversion() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("processObjectMono", TestObject.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + // Pass invalid object structure that can't be converted to TestObject + CallToolRequest request = new CallToolRequest("object-mono-tool", Map.of("obj", "invalid-object-string")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + }).verifyComplete(); + } + + @Test + public void testConstructorParameters() { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethods()[0]; // Any + // method + + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + // Verify that the callback was created successfully + assertThat(callback).isNotNull(); + } + + @Test + public void testIsExchangeOrContextType() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("simpleMonoTool", String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + // Test that McpTransportContext is recognized as context type + assertThat(callback.isExchangeOrContextType(McpTransportContext.class)).isTrue(); + + // Test that other types are not recognized as context type + assertThat(callback.isExchangeOrContextType(String.class)).isFalse(); + assertThat(callback.isExchangeOrContextType(Integer.class)).isFalse(); + assertThat(callback.isExchangeOrContextType(Object.class)).isFalse(); + } + + @Test + public void testMonoToolWithOptionalParameters() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("monoToolWithOptionalParams", String.class, + String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("optional-params-mono-tool", + Map.of("required", "test", "optional", "optional-value")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()) + .isEqualTo("Required: test, Optional: optional-value"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithOptionalParametersMissing() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("monoToolWithOptionalParams", String.class, + String.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("optional-params-mono-tool", Map.of("required", "test")); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Required: test, Optional: null"); + }).verifyComplete(); + } + + @Test + public void testMonoToolWithStructuredOutput() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("processObjectMono", TestObject.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("object-mono-tool", + Map.of("obj", Map.of("name", "test", "value", 42))); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Object: test - 42"); + }).verifyComplete(); + } + + @Test + public void testCallbackReturnsCallToolResult() throws Exception { + TestAsyncStatelessToolProvider provider = new TestAsyncStatelessToolProvider(); + Method method = TestAsyncStatelessToolProvider.class.getMethod("complexMonoTool", String.class, int.class, + boolean.class); + AsyncStatelessMcpToolMethodCallback callback = new AsyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("complex-mono-tool", + Map.of("name", "Alice", "age", 25, "active", false)); + + StepVerifier.create(callback.apply(context, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: Alice, Age: 25, Active: false"); + }).verifyComplete(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallbackTests.java new file mode 100644 index 0000000..6996222 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallbackTests.java @@ -0,0 +1,538 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.tool; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link SyncStatelessMcpToolMethodCallback}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpToolMethodCallbackTests { + + private static class TestToolProvider { + + @McpTool(name = "simple-tool", description = "A simple tool") + public String simpleTool(String input) { + return "Processed: " + input; + } + + @McpTool(name = "math-tool", description = "A math tool") + public int addNumbers(int a, int b) { + return a + b; + } + + @McpTool(name = "complex-tool", description = "A complex tool") + public CallToolResult complexTool(String name, int age, boolean active) { + return CallToolResult.builder() + .addTextContent("Name: " + name + ", Age: " + age + ", Active: " + active) + .build(); + } + + @McpTool(name = "context-tool", description = "Tool with context parameter") + public String toolWithContext(McpTransportContext context, String message) { + return "Context tool: " + message; + } + + @McpTool(name = "list-tool", description = "Tool with list parameter") + public String processList(List items) { + return "Items: " + String.join(", ", items); + } + + @McpTool(name = "object-tool", description = "Tool with object parameter") + public String processObject(TestObject obj) { + return "Object: " + obj.name + " - " + obj.value; + } + + @McpTool(name = "optional-params-tool", description = "Tool with optional parameters") + public String toolWithOptionalParams(@McpToolParam(required = true) String required, + @McpToolParam(required = false) String optional) { + return "Required: " + required + ", Optional: " + (optional != null ? optional : "null"); + } + + @McpTool(name = "no-params-tool", description = "Tool with no parameters") + public String noParamsTool() { + return "No parameters needed"; + } + + @McpTool(name = "exception-tool", description = "Tool that throws exception") + public String exceptionTool(String input) { + throw new RuntimeException("Tool execution failed: " + input); + } + + @McpTool(name = "null-return-tool", description = "Tool that returns null") + public String nullReturnTool() { + return null; + } + + public String nonAnnotatedTool(String input) { + return "Non-annotated: " + input; + } + + @McpTool(name = "private-tool", description = "Private tool") + private String privateTool(String input) { + return "Private: " + input; + } + + @McpTool(name = "enum-tool", description = "Tool with enum parameter") + public String enumTool(TestEnum enumValue) { + return "Enum: " + enumValue.name(); + } + + @McpTool(name = "primitive-types-tool", description = "Tool with primitive types") + public String primitiveTypesTool(boolean flag, byte b, short s, int i, long l, float f, double d) { + return String.format("Primitives: %b, %d, %d, %d, %d, %.1f, %.1f", flag, b, s, i, l, f, d); + } + + @McpTool(name = "return-object-tool", description = "Tool that returns a complex object") + public TestObject returnObjectTool(String name, int value) { + return new TestObject(name, value); + } + + @McpTool(name = "void-tool", description = "Tool with void return") + public void voidTool(String input) { + // Do nothing + } + + } + + public static class TestObject { + + public String name; + + public int value; + + public TestObject() { + } + + public TestObject(String name, int value) { + this.name = name; + this.value = value; + } + + } + + public enum TestEnum { + + OPTION_A, OPTION_B, OPTION_C + + } + + @Test + public void testSimpleToolCallback() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("simple-tool", Map.of("input", "test message")); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: test message"); + } + + @Test + public void testMathToolCallback() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("addNumbers", int.class, int.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("math-tool", Map.of("a", 5, "b", 3)); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("8"); + } + + @Test + public void testComplexToolCallback() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("complexTool", String.class, int.class, boolean.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("complex-tool", + Map.of("name", "John", "age", 30, "active", true)); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: John, Age: 30, Active: true"); + } + + @Test + public void testToolWithContextParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("toolWithContext", McpTransportContext.class, String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("context-tool", Map.of("message", "hello")); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Context tool: hello"); + } + + @Test + public void testToolWithListParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("processList", List.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("list-tool", Map.of("items", List.of("item1", "item2", "item3"))); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Items: item1, item2, item3"); + } + + @Test + public void testToolWithObjectParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("processObject", TestObject.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("object-tool", + Map.of("obj", Map.of("name", "test", "value", 42))); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Object: test - 42"); + } + + @Test + public void testToolWithNoParameters() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("noParamsTool"); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("no-params-tool", Map.of()); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("No parameters needed"); + } + + @Test + public void testToolWithEnumParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("enumTool", TestEnum.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("enum-tool", Map.of("enumValue", "OPTION_B")); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Enum: OPTION_B"); + } + + @Test + public void testToolWithPrimitiveTypes() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("primitiveTypesTool", boolean.class, byte.class, short.class, + int.class, long.class, float.class, double.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("primitive-types-tool", + Map.of("flag", true, "b", 1, "s", 2, "i", 3, "l", 4L, "f", 5.5f, "d", 6.6)); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Primitives: true, 1, 2, 3, 4, 5.5, 6.6"); + } + + @Test + public void testToolWithNullParameters() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + Map args = new java.util.HashMap<>(); + args.put("input", null); + CallToolRequest request = new CallToolRequest("simple-tool", args); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + } + + @Test + public void testToolWithMissingParameters() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("simple-tool", Map.of()); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Processed: null"); + } + + @Test + public void testToolThatThrowsException() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("exceptionTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("exception-tool", Map.of("input", "test")); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + assertThat(((TextContent) result.content().get(0)).text()).contains("exceptionTool"); + } + + @Test + public void testToolThatReturnsNull() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("nullReturnTool"); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("null-return-tool", Map.of()); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("null"); + } + + @Test + public void testPrivateToolMethod() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getDeclaredMethod("privateTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("private-tool", Map.of("input", "test")); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Private: test"); + } + + @Test + public void testNullRequest() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + + assertThatThrownBy(() -> callback.apply(context, null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("Request must not be null"); + } + + @Test + public void testCallbackReturnsCallToolResult() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("complexTool", String.class, int.class, boolean.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("complex-tool", + Map.of("name", "Alice", "age", 25, "active", false)); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Name: Alice, Age: 25, Active: false"); + } + + @Test + public void testIsExchangeOrContextType() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("simpleTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + // Test that McpTransportContext is recognized as context type + // Note: We need to use reflection to access the protected method for testing + java.lang.reflect.Method isContextTypeMethod = SyncStatelessMcpToolMethodCallback.class + .getDeclaredMethod("isExchangeOrContextType", Class.class); + isContextTypeMethod.setAccessible(true); + + assertThat((Boolean) isContextTypeMethod.invoke(callback, McpTransportContext.class)).isTrue(); + + // Test that other types are not recognized as context type + assertThat((Boolean) isContextTypeMethod.invoke(callback, String.class)).isFalse(); + assertThat((Boolean) isContextTypeMethod.invoke(callback, Integer.class)).isFalse(); + assertThat((Boolean) isContextTypeMethod.invoke(callback, Object.class)).isFalse(); + } + + @Test + public void testToolWithInvalidJsonConversion() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("processObject", TestObject.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + // Pass invalid object structure that can't be converted to TestObject + CallToolRequest request = new CallToolRequest("object-tool", Map.of("obj", "invalid-object-string")); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).contains("Error invoking method"); + } + + @Test + public void testConstructorParameters() { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethods()[0]; // Any method + + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.TEXT, method, + provider); + + // Verify that the callback was created successfully + assertThat(callback).isNotNull(); + } + + @Test + public void testToolReturningComplexObject() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("returnObjectTool", String.class, int.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.STRUCTURED, + method, provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("return-object-tool", Map.of("name", "test", "value", 42)); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + // For complex return types (non-primitive, non-wrapper, non-CallToolResult), + // the new implementation should return structured content + assertThat(result.content()).isEmpty(); + assertThat(result.structuredContent()).isNotNull(); + assertThat(result.structuredContent()).containsEntry("name", "test"); + assertThat(result.structuredContent()).containsEntry("value", 42); + } + + @Test + public void testVoidReturnMode() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("voidTool", String.class); + SyncStatelessMcpToolMethodCallback callback = new SyncStatelessMcpToolMethodCallback(ReturnMode.VOID, method, + provider); + + McpTransportContext context = mock(McpTransportContext.class); + CallToolRequest request = new CallToolRequest("void-tool", Map.of("input", "test")); + + CallToolResult result = callback.apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("\"Done\""); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProviderTests.java new file mode 100644 index 0000000..a16e489 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpCompleteProviderTests.java @@ -0,0 +1,468 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpComplete; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncCompletionSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link AsyncStatelessMcpCompleteProvider}. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpCompleteProviderTests { + + @Test + void testConstructorWithNullCompleteObjects() { + assertThatThrownBy(() -> new AsyncStatelessMcpCompleteProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("completeObjects cannot be null"); + } + + @Test + void testGetCompleteSpecificationsWithSingleValidComplete() { + // Create a class with only one valid async complete method + class SingleValidComplete { + + @McpComplete(prompt = "test-prompt") + public Mono testComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async completion for " + request.argument().value()), 1, false))); + } + + } + + SingleValidComplete completeObject = new SingleValidComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).isNotNull(); + assertThat(completeSpecs).hasSize(1); + + AsyncCompletionSpecification completeSpec = completeSpecs.get(0); + assertThat(completeSpec.referenceKey()).isInstanceOf(PromptReference.class); + PromptReference promptRef = (PromptReference) completeSpec.referenceKey(); + assertThat(promptRef.name()).isEqualTo("test-prompt"); + assertThat(completeSpec.completionHandler()).isNotNull(); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + Mono result = completeSpec.completionHandler().apply(context, request); + + StepVerifier.create(result).assertNext(completeResult -> { + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion()).isNotNull(); + assertThat(completeResult.completion().values()).hasSize(1); + assertThat(completeResult.completion().values().get(0)).isEqualTo("Async completion for value"); + }).verifyComplete(); + } + + @Test + void testGetCompleteSpecificationsWithUriReference() { + class UriComplete { + + @McpComplete(uri = "test://{variable}") + public Mono uriComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async URI completion for " + request.argument().value()), 1, false))); + } + + } + + UriComplete completeObject = new UriComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + assertThat(completeSpecs.get(0).referenceKey()).isInstanceOf(ResourceReference.class); + ResourceReference resourceRef = (ResourceReference) completeSpecs.get(0).referenceKey(); + assertThat(resourceRef.uri()).isEqualTo("test://{variable}"); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new ResourceReference("test://value"), + new CompleteRequest.CompleteArgument("variable", "value")); + Mono result = completeSpecs.get(0).completionHandler().apply(context, request); + + StepVerifier.create(result).assertNext(completeResult -> { + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion()).isNotNull(); + assertThat(completeResult.completion().values()).hasSize(1); + assertThat(completeResult.completion().values().get(0)).isEqualTo("Async URI completion for value"); + }).verifyComplete(); + } + + @Test + void testGetCompleteSpecificationsFiltersOutNonReactiveReturnTypes() { + class MixedReturnComplete { + + @McpComplete(prompt = "sync-complete") + public CompleteResult syncComplete(CompleteRequest request) { + return new CompleteResult( + new CompleteCompletion(List.of("Sync completion for " + request.argument().value()), 1, false)); + } + + @McpComplete(prompt = "async-complete") + public Mono asyncComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async completion for " + request.argument().value()), 1, false))); + } + + } + + MixedReturnComplete completeObject = new MixedReturnComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("async-complete"); + } + + @Test + void testGetCompleteSpecificationsWithMultipleCompleteMethods() { + class MultipleCompleteMethods { + + @McpComplete(prompt = "complete1") + public Mono firstComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("First completion for " + request.argument().value()), 1, false))); + } + + @McpComplete(prompt = "complete2") + public Mono secondComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Second completion for " + request.argument().value()), 1, false))); + } + + } + + MultipleCompleteMethods completeObject = new MultipleCompleteMethods(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(2); + PromptReference promptRef1 = (PromptReference) completeSpecs.get(0).referenceKey(); + PromptReference promptRef2 = (PromptReference) completeSpecs.get(1).referenceKey(); + assertThat(promptRef1.name()).isIn("complete1", "complete2"); + assertThat(promptRef2.name()).isIn("complete1", "complete2"); + assertThat(promptRef1.name()).isNotEqualTo(promptRef2.name()); + } + + @Test + void testGetCompleteSpecificationsWithMultipleCompleteObjects() { + class FirstCompleteObject { + + @McpComplete(prompt = "first-complete") + public Mono firstComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("First completion for " + request.argument().value()), 1, false))); + } + + } + + class SecondCompleteObject { + + @McpComplete(prompt = "second-complete") + public Mono secondComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Second completion for " + request.argument().value()), 1, false))); + } + + } + + FirstCompleteObject firstObject = new FirstCompleteObject(); + SecondCompleteObject secondObject = new SecondCompleteObject(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider( + List.of(firstObject, secondObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(2); + PromptReference promptRef1 = (PromptReference) completeSpecs.get(0).referenceKey(); + PromptReference promptRef2 = (PromptReference) completeSpecs.get(1).referenceKey(); + assertThat(promptRef1.name()).isIn("first-complete", "second-complete"); + assertThat(promptRef2.name()).isIn("first-complete", "second-complete"); + assertThat(promptRef1.name()).isNotEqualTo(promptRef2.name()); + } + + @Test + void testGetCompleteSpecificationsWithMixedMethods() { + class MixedMethods { + + @McpComplete(prompt = "valid-complete") + public Mono validComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Valid completion for " + request.argument().value()), 1, false))); + } + + public CompleteResult nonAnnotatedMethod(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Non-annotated completion for " + request.argument().value()), 1, false)); + } + + @McpComplete(prompt = "sync-complete") + public CompleteResult syncComplete(CompleteRequest request) { + return new CompleteResult( + new CompleteCompletion(List.of("Sync completion for " + request.argument().value()), 1, false)); + } + + } + + MixedMethods completeObject = new MixedMethods(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("valid-complete"); + } + + @Test + void testGetCompleteSpecificationsWithPrivateMethod() { + class PrivateMethodComplete { + + @McpComplete(prompt = "private-complete") + private Mono privateComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Private completion for " + request.argument().value()), 1, false))); + } + + } + + PrivateMethodComplete completeObject = new PrivateMethodComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("private-complete"); + + // Test that the handler works with private methods + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("private-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + Mono result = completeSpecs.get(0).completionHandler().apply(context, request); + + StepVerifier.create(result).assertNext(completeResult -> { + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion()).isNotNull(); + assertThat(completeResult.completion().values()).hasSize(1); + assertThat(completeResult.completion().values().get(0)).isEqualTo("Private completion for value"); + }).verifyComplete(); + } + + @Test + void testGetCompleteSpecificationsWithMonoStringReturn() { + class MonoStringReturnComplete { + + @McpComplete(prompt = "mono-string-complete") + public Mono monoStringComplete(CompleteRequest request) { + return Mono.just("Simple string completion for " + request.argument().value()); + } + + } + + MonoStringReturnComplete completeObject = new MonoStringReturnComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("mono-string-complete"); + + // Test that the handler works with Mono return type + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("mono-string-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + Mono result = completeSpecs.get(0).completionHandler().apply(context, request); + + StepVerifier.create(result).assertNext(completeResult -> { + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion()).isNotNull(); + assertThat(completeResult.completion().values()).hasSize(1); + assertThat(completeResult.completion().values().get(0)).isEqualTo("Simple string completion for value"); + }).verifyComplete(); + } + + @Test + void testGetCompleteSpecificationsWithContextParameter() { + class ContextParameterComplete { + + @McpComplete(prompt = "context-complete") + public Mono contextComplete(McpTransportContext context, CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion(List.of("Completion with context: " + + (context != null ? "present" : "null") + ", value: " + request.argument().value()), 1, + false))); + } + + } + + ContextParameterComplete completeObject = new ContextParameterComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("context-complete"); + + // Test that the handler works with context parameter + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("context-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + Mono result = completeSpecs.get(0).completionHandler().apply(context, request); + + StepVerifier.create(result).assertNext(completeResult -> { + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion()).isNotNull(); + assertThat(completeResult.completion().values()).hasSize(1); + assertThat(completeResult.completion().values().get(0)) + .isEqualTo("Completion with context: present, value: value"); + }).verifyComplete(); + } + + @Test + void testGetCompleteSpecificationsWithMonoListReturn() { + class MonoListReturnComplete { + + @McpComplete(prompt = "mono-list-complete") + public Mono> monoListComplete(CompleteRequest request) { + return Mono.just(List.of("First completion for " + request.argument().value(), + "Second completion for " + request.argument().value())); + } + + } + + MonoListReturnComplete completeObject = new MonoListReturnComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("mono-list-complete"); + + // Test that the handler works with Mono> return type + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("mono-list-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + Mono result = completeSpecs.get(0).completionHandler().apply(context, request); + + StepVerifier.create(result).assertNext(completeResult -> { + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion()).isNotNull(); + assertThat(completeResult.completion().values()).hasSize(2); + assertThat(completeResult.completion().values().get(0)).isEqualTo("First completion for value"); + assertThat(completeResult.completion().values().get(1)).isEqualTo("Second completion for value"); + }).verifyComplete(); + } + + @Test + void testGetCompleteSpecificationsWithMonoCompletionReturn() { + class MonoCompletionReturnComplete { + + @McpComplete(prompt = "mono-completion-complete") + public Mono monoCompletionComplete(CompleteRequest request) { + return Mono.just(new CompleteCompletion(List.of("Completion object for " + request.argument().value()), + 1, false)); + } + + } + + MonoCompletionReturnComplete completeObject = new MonoCompletionReturnComplete(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("mono-completion-complete"); + + // Test that the handler works with Mono return type + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("mono-completion-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + Mono result = completeSpecs.get(0).completionHandler().apply(context, request); + + StepVerifier.create(result).assertNext(completeResult -> { + assertThat(completeResult).isNotNull(); + assertThat(completeResult.completion()).isNotNull(); + assertThat(completeResult.completion().values()).hasSize(1); + assertThat(completeResult.completion().values().get(0)).isEqualTo("Completion object for value"); + }).verifyComplete(); + } + + @Test + void testGetCompleteSpecificationsWithEmptyList() { + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of()); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).isNotNull(); + assertThat(completeSpecs).isEmpty(); + } + + @Test + void testGetCompleteSpecificationsWithNoValidMethods() { + class NoValidMethods { + + public void voidMethod() { + // No return value + } + + public String nonAnnotatedMethod() { + return "Not annotated"; + } + + } + + NoValidMethods completeObject = new NoValidMethods(); + AsyncStatelessMcpCompleteProvider provider = new AsyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).isNotNull(); + assertThat(completeSpecs).isEmpty(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProviderTests.java new file mode 100644 index 0000000..19c61b5 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpPromptProviderTests.java @@ -0,0 +1,566 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpArg; +import org.springaicommunity.mcp.annotation.McpPrompt; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncPromptSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link AsyncStatelessMcpPromptProvider}. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpPromptProviderTests { + + @Test + void testConstructorWithNullPromptObjects() { + assertThatThrownBy(() -> new AsyncStatelessMcpPromptProvider(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("promptObjects cannot be null"); + } + + @Test + void testGetPromptSpecificationsWithSingleValidPrompt() { + // Create a class with only one valid async prompt method + class SingleValidPrompt { + + @McpPrompt(name = "test-prompt", description = "A test prompt") + public Mono testPrompt(GetPromptRequest request) { + return Mono.just(new GetPromptResult("Test prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello from " + request.name()))))); + } + + } + + SingleValidPrompt promptObject = new SingleValidPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).isNotNull(); + assertThat(promptSpecs).hasSize(1); + + AsyncPromptSpecification promptSpec = promptSpecs.get(0); + assertThat(promptSpec.prompt().name()).isEqualTo("test-prompt"); + assertThat(promptSpec.prompt().description()).isEqualTo("A test prompt"); + assertThat(promptSpec.promptHandler()).isNotNull(); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("test-prompt", args); + Mono result = promptSpec.promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.description()).isEqualTo("Test prompt result"); + assertThat(promptResult.messages()).hasSize(1); + PromptMessage message = promptResult.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello from test-prompt"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithCustomPromptName() { + class CustomNamePrompt { + + @McpPrompt(name = "custom-name", description = "Custom named prompt") + public Mono methodWithDifferentName() { + return Mono.just(new GetPromptResult("Custom prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Custom prompt content"))))); + } + + } + + CustomNamePrompt promptObject = new CustomNamePrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("custom-name"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Custom named prompt"); + } + + @Test + void testGetPromptSpecificationsWithDefaultPromptName() { + class DefaultNamePrompt { + + @McpPrompt(description = "Prompt with default name") + public Mono defaultNameMethod() { + return Mono.just(new GetPromptResult("Default prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Default prompt content"))))); + } + + } + + DefaultNamePrompt promptObject = new DefaultNamePrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("defaultNameMethod"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Prompt with default name"); + } + + @Test + void testGetPromptSpecificationsWithEmptyPromptName() { + class EmptyNamePrompt { + + @McpPrompt(name = "", description = "Prompt with empty name") + public Mono emptyNameMethod() { + return Mono.just(new GetPromptResult("Empty name prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Empty name prompt content"))))); + } + + } + + EmptyNamePrompt promptObject = new EmptyNamePrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("emptyNameMethod"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Prompt with empty name"); + } + + @Test + void testGetPromptSpecificationsFiltersOutNonReactiveReturnTypes() { + class MixedReturnPrompt { + + @McpPrompt(name = "sync-prompt", description = "Synchronous prompt") + public GetPromptResult syncPrompt() { + return new GetPromptResult("Sync prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Sync prompt content")))); + } + + @McpPrompt(name = "async-prompt", description = "Asynchronous prompt") + public Mono asyncPrompt() { + return Mono.just(new GetPromptResult("Async prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Async prompt content"))))); + } + + } + + MixedReturnPrompt promptObject = new MixedReturnPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("async-prompt"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Asynchronous prompt"); + } + + @Test + void testGetPromptSpecificationsWithMultiplePromptMethods() { + class MultiplePromptMethods { + + @McpPrompt(name = "prompt1", description = "First prompt") + public Mono firstPrompt() { + return Mono.just(new GetPromptResult("First prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("First prompt content"))))); + } + + @McpPrompt(name = "prompt2", description = "Second prompt") + public Mono secondPrompt() { + return Mono.just(new GetPromptResult("Second prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Second prompt content"))))); + } + + } + + MultiplePromptMethods promptObject = new MultiplePromptMethods(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(2); + assertThat(promptSpecs.get(0).prompt().name()).isIn("prompt1", "prompt2"); + assertThat(promptSpecs.get(1).prompt().name()).isIn("prompt1", "prompt2"); + assertThat(promptSpecs.get(0).prompt().name()).isNotEqualTo(promptSpecs.get(1).prompt().name()); + } + + @Test + void testGetPromptSpecificationsWithMultiplePromptObjects() { + class FirstPromptObject { + + @McpPrompt(name = "first-prompt", description = "First prompt") + public Mono firstPrompt() { + return Mono.just(new GetPromptResult("First prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("First prompt content"))))); + } + + } + + class SecondPromptObject { + + @McpPrompt(name = "second-prompt", description = "Second prompt") + public Mono secondPrompt() { + return Mono.just(new GetPromptResult("Second prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Second prompt content"))))); + } + + } + + FirstPromptObject firstObject = new FirstPromptObject(); + SecondPromptObject secondObject = new SecondPromptObject(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider( + List.of(firstObject, secondObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(2); + assertThat(promptSpecs.get(0).prompt().name()).isIn("first-prompt", "second-prompt"); + assertThat(promptSpecs.get(1).prompt().name()).isIn("first-prompt", "second-prompt"); + assertThat(promptSpecs.get(0).prompt().name()).isNotEqualTo(promptSpecs.get(1).prompt().name()); + } + + @Test + void testGetPromptSpecificationsWithMixedMethods() { + class MixedMethods { + + @McpPrompt(name = "valid-prompt", description = "Valid prompt") + public Mono validPrompt() { + return Mono.just(new GetPromptResult("Valid prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Valid prompt content"))))); + } + + public GetPromptResult nonAnnotatedMethod() { + return new GetPromptResult("Non-annotated result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Non-annotated content")))); + } + + @McpPrompt(name = "sync-prompt", description = "Sync prompt") + public GetPromptResult syncPrompt() { + return new GetPromptResult("Sync prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Sync prompt content")))); + } + + } + + MixedMethods promptObject = new MixedMethods(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("valid-prompt"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Valid prompt"); + } + + @Test + void testGetPromptSpecificationsWithArguments() { + class ArgumentPrompt { + + @McpPrompt(name = "argument-prompt", description = "Prompt with arguments") + public Mono argumentPrompt( + @McpArg(name = "name", description = "User's name", required = true) String name, + @McpArg(name = "age", description = "User's age", required = false) Integer age) { + return Mono.just(new GetPromptResult("Argument prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent( + "Hello " + name + ", you are " + (age != null ? age : "unknown") + " years old"))))); + } + + } + + ArgumentPrompt promptObject = new ArgumentPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("argument-prompt"); + assertThat(promptSpecs.get(0).prompt().arguments()).hasSize(2); + + // Test that the handler works with arguments + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + args.put("age", 30); + GetPromptRequest request = new GetPromptRequest("argument-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.description()).isEqualTo("Argument prompt result"); + assertThat(promptResult.messages()).hasSize(1); + PromptMessage message = promptResult.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello John, you are 30 years old"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithPrivateMethod() { + class PrivateMethodPrompt { + + @McpPrompt(name = "private-prompt", description = "Private prompt method") + private Mono privatePrompt() { + return Mono.just(new GetPromptResult("Private prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Private prompt content"))))); + } + + } + + PrivateMethodPrompt promptObject = new PrivateMethodPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("private-prompt"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Private prompt method"); + + // Test that the handler works with private methods + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("private-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.description()).isEqualTo("Private prompt result"); + assertThat(promptResult.messages()).hasSize(1); + PromptMessage message = promptResult.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Private prompt content"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithMonoStringReturn() { + class MonoStringReturnPrompt { + + @McpPrompt(name = "mono-string-prompt", description = "Prompt returning Mono") + public Mono monoStringPrompt() { + return Mono.just("Simple string response"); + } + + } + + MonoStringReturnPrompt promptObject = new MonoStringReturnPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("mono-string-prompt"); + + // Test that the handler works with Mono return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("mono-string-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.messages()).hasSize(1); + PromptMessage message = promptResult.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Simple string response"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithContextParameter() { + class ContextParameterPrompt { + + @McpPrompt(name = "context-prompt", description = "Prompt with context parameter") + public Mono contextPrompt(McpTransportContext context, GetPromptRequest request) { + return Mono.just(new GetPromptResult("Context prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Prompt with context: " + + (context != null ? "present" : "null") + ", name: " + request.name()))))); + } + + } + + ContextParameterPrompt promptObject = new ContextParameterPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("context-prompt"); + + // Test that the handler works with context parameter + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("context-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.description()).isEqualTo("Context prompt result"); + assertThat(promptResult.messages()).hasSize(1); + PromptMessage message = promptResult.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()) + .isEqualTo("Prompt with context: present, name: context-prompt"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithRequestParameter() { + class RequestParameterPrompt { + + @McpPrompt(name = "request-prompt", description = "Prompt with request parameter") + public Mono requestPrompt(GetPromptRequest request) { + return Mono.just(new GetPromptResult("Request prompt result", List + .of(new PromptMessage(Role.ASSISTANT, new TextContent("Prompt for name: " + request.name()))))); + } + + } + + RequestParameterPrompt promptObject = new RequestParameterPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("request-prompt"); + + // Test that the handler works with request parameter + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("request-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.description()).isEqualTo("Request prompt result"); + assertThat(promptResult.messages()).hasSize(1); + PromptMessage message = promptResult.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Prompt for name: request-prompt"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithMonoMessagesList() { + class MonoMessagesListPrompt { + + @McpPrompt(name = "mono-messages-list-prompt", description = "Prompt returning Mono>") + public Mono> monoMessagesListPrompt() { + return Mono.just(List.of(new PromptMessage(Role.ASSISTANT, new TextContent("First message")), + new PromptMessage(Role.ASSISTANT, new TextContent("Second message")))); + } + + } + + MonoMessagesListPrompt promptObject = new MonoMessagesListPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("mono-messages-list-prompt"); + + // Test that the handler works with Mono> return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("mono-messages-list-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.messages()).hasSize(2); + assertThat(((TextContent) promptResult.messages().get(0).content()).text()).isEqualTo("First message"); + assertThat(((TextContent) promptResult.messages().get(1).content()).text()).isEqualTo("Second message"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithMonoSingleMessage() { + class MonoSingleMessagePrompt { + + @McpPrompt(name = "mono-single-message-prompt", description = "Prompt returning Mono") + public Mono monoSingleMessagePrompt() { + return Mono.just(new PromptMessage(Role.ASSISTANT, new TextContent("Single message"))); + } + + } + + MonoSingleMessagePrompt promptObject = new MonoSingleMessagePrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("mono-single-message-prompt"); + + // Test that the handler works with Mono return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("mono-single-message-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.messages()).hasSize(1); + assertThat(((TextContent) promptResult.messages().get(0).content()).text()).isEqualTo("Single message"); + }).verifyComplete(); + } + + @Test + void testGetPromptSpecificationsWithMonoStringList() { + class MonoStringListPrompt { + + @McpPrompt(name = "mono-string-list-prompt", description = "Prompt returning Mono>") + public Mono> monoStringListPrompt() { + return Mono.just(List.of("First string", "Second string", "Third string")); + } + + } + + MonoStringListPrompt promptObject = new MonoStringListPrompt(); + AsyncStatelessMcpPromptProvider provider = new AsyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("mono-string-list-prompt"); + + // Test that the handler works with Mono> return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("mono-string-list-prompt", args); + Mono result = promptSpecs.get(0).promptHandler().apply(context, request); + + StepVerifier.create(result).assertNext(promptResult -> { + assertThat(promptResult.messages()).hasSize(3); + assertThat(((TextContent) promptResult.messages().get(0).content()).text()).isEqualTo("First string"); + assertThat(((TextContent) promptResult.messages().get(1).content()).text()).isEqualTo("Second string"); + assertThat(((TextContent) promptResult.messages().get(2).content()).text()).isEqualTo("Third string"); + }).verifyComplete(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProviderTests.java new file mode 100644 index 0000000..311573e --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncStatelessMcpResourceProviderTests.java @@ -0,0 +1,493 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpResource; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.AsyncResourceSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link AsyncStatelessMcpResourceProvider}. + * + * @author Christian Tzolov + */ +public class AsyncStatelessMcpResourceProviderTests { + + @Test + void testConstructorWithNullResourceObjects() { + assertThatThrownBy(() -> new AsyncStatelessMcpResourceProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("resourceObjects cannot be null"); + } + + @Test + void testGetResourceSpecificationsWithSingleValidResource() { + // Create a class with only one valid async resource method + class SingleValidResource { + + @McpResource(uri = "test://resource/{id}", name = "test-resource", description = "A test resource") + public Mono testResource(String id) { + return Mono.just("Resource content for: " + id); + } + + } + + SingleValidResource resourceObject = new SingleValidResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).isNotNull(); + assertThat(resourceSpecs).hasSize(1); + + AsyncResourceSpecification resourceSpec = resourceSpecs.get(0); + assertThat(resourceSpec.resource().uri()).isEqualTo("test://resource/{id}"); + assertThat(resourceSpec.resource().name()).isEqualTo("test-resource"); + assertThat(resourceSpec.resource().description()).isEqualTo("A test resource"); + assertThat(resourceSpec.readHandler()).isNotNull(); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test://resource/123"); + Mono result = resourceSpec.readHandler().apply(context, request); + + StepVerifier.create(result).assertNext(readResult -> { + assertThat(readResult.contents()).hasSize(1); + ResourceContents content = readResult.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Resource content for: 123"); + }).verifyComplete(); + } + + @Test + void testGetResourceSpecificationsWithCustomResourceName() { + class CustomNameResource { + + @McpResource(uri = "custom://resource", name = "custom-name", description = "Custom named resource") + public Mono methodWithDifferentName() { + return Mono.just("Custom resource content"); + } + + } + + CustomNameResource resourceObject = new CustomNameResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("custom-name"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Custom named resource"); + } + + @Test + void testGetResourceSpecificationsWithDefaultResourceName() { + class DefaultNameResource { + + @McpResource(uri = "default://resource", description = "Resource with default name") + public Mono defaultNameMethod() { + return Mono.just("Default resource content"); + } + + } + + DefaultNameResource resourceObject = new DefaultNameResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("defaultNameMethod"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Resource with default name"); + } + + @Test + void testGetResourceSpecificationsWithEmptyResourceName() { + class EmptyNameResource { + + @McpResource(uri = "empty://resource", name = "", description = "Resource with empty name") + public Mono emptyNameMethod() { + return Mono.just("Empty name resource content"); + } + + } + + EmptyNameResource resourceObject = new EmptyNameResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("emptyNameMethod"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Resource with empty name"); + } + + @Test + void testGetResourceSpecificationsFiltersOutNonReactiveReturnTypes() { + class MixedReturnResource { + + @McpResource(uri = "sync://resource", name = "sync-resource", description = "Synchronous resource") + public String syncResource() { + return "Sync resource content"; + } + + @McpResource(uri = "async://resource", name = "async-resource", description = "Asynchronous resource") + public Mono asyncResource() { + return Mono.just("Async resource content"); + } + + } + + MixedReturnResource resourceObject = new MixedReturnResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("async-resource"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Asynchronous resource"); + } + + @Test + void testGetResourceSpecificationsWithMultipleResourceMethods() { + class MultipleResourceMethods { + + @McpResource(uri = "first://resource", name = "resource1", description = "First resource") + public Mono firstResource() { + return Mono.just("First resource content"); + } + + @McpResource(uri = "second://resource", name = "resource2", description = "Second resource") + public Mono secondResource() { + return Mono.just("Second resource content"); + } + + } + + MultipleResourceMethods resourceObject = new MultipleResourceMethods(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(2); + assertThat(resourceSpecs.get(0).resource().name()).isIn("resource1", "resource2"); + assertThat(resourceSpecs.get(1).resource().name()).isIn("resource1", "resource2"); + assertThat(resourceSpecs.get(0).resource().name()).isNotEqualTo(resourceSpecs.get(1).resource().name()); + } + + @Test + void testGetResourceSpecificationsWithMultipleResourceObjects() { + class FirstResourceObject { + + @McpResource(uri = "first://resource", name = "first-resource", description = "First resource") + public Mono firstResource() { + return Mono.just("First resource content"); + } + + } + + class SecondResourceObject { + + @McpResource(uri = "second://resource", name = "second-resource", description = "Second resource") + public Mono secondResource() { + return Mono.just("Second resource content"); + } + + } + + FirstResourceObject firstObject = new FirstResourceObject(); + SecondResourceObject secondObject = new SecondResourceObject(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider( + List.of(firstObject, secondObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(2); + assertThat(resourceSpecs.get(0).resource().name()).isIn("first-resource", "second-resource"); + assertThat(resourceSpecs.get(1).resource().name()).isIn("first-resource", "second-resource"); + assertThat(resourceSpecs.get(0).resource().name()).isNotEqualTo(resourceSpecs.get(1).resource().name()); + } + + @Test + void testGetResourceSpecificationsWithMixedMethods() { + class MixedMethods { + + @McpResource(uri = "valid://resource", name = "valid-resource", description = "Valid resource") + public Mono validResource() { + return Mono.just("Valid resource content"); + } + + public String nonAnnotatedMethod() { + return "Non-annotated resource content"; + } + + @McpResource(uri = "sync://resource", name = "sync-resource", description = "Sync resource") + public String syncResource() { + return "Sync resource content"; + } + + } + + MixedMethods resourceObject = new MixedMethods(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("valid-resource"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Valid resource"); + } + + @Test + void testGetResourceSpecificationsWithUriVariables() { + class UriVariableResource { + + @McpResource(uri = "variable://resource/{id}/{type}", name = "variable-resource", + description = "Resource with URI variables") + public Mono variableResource(String id, String type) { + return Mono.just(String.format("Resource content for id: %s, type: %s", id, type)); + } + + } + + UriVariableResource resourceObject = new UriVariableResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().uri()).isEqualTo("variable://resource/{id}/{type}"); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("variable-resource"); + + // Test that the handler works with URI variables + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("variable://resource/123/document"); + Mono result = resourceSpecs.get(0).readHandler().apply(context, request); + + StepVerifier.create(result).assertNext(readResult -> { + assertThat(readResult.contents()).hasSize(1); + ResourceContents content = readResult.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()) + .isEqualTo("Resource content for id: 123, type: document"); + }).verifyComplete(); + } + + @Test + void testGetResourceSpecificationsWithMimeType() { + class MimeTypeResource { + + @McpResource(uri = "mime://resource", name = "mime-resource", description = "Resource with MIME type", + mimeType = "application/json") + public Mono mimeTypeResource() { + return Mono.just("{\"message\": \"JSON resource content\"}"); + } + + } + + MimeTypeResource resourceObject = new MimeTypeResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().mimeType()).isEqualTo("application/json"); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("mime-resource"); + } + + @Test + void testGetResourceSpecificationsWithPrivateMethod() { + class PrivateMethodResource { + + @McpResource(uri = "private://resource", name = "private-resource", description = "Private resource method") + private Mono privateResource() { + return Mono.just("Private resource content"); + } + + } + + PrivateMethodResource resourceObject = new PrivateMethodResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("private-resource"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Private resource method"); + + // Test that the handler works with private methods + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("private://resource"); + Mono result = resourceSpecs.get(0).readHandler().apply(context, request); + + StepVerifier.create(result).assertNext(readResult -> { + assertThat(readResult.contents()).hasSize(1); + ResourceContents content = readResult.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Private resource content"); + }).verifyComplete(); + } + + @Test + void testGetResourceSpecificationsWithResourceContentsList() { + class ResourceContentsListResource { + + @McpResource(uri = "list://resource", name = "list-resource", description = "Resource returning list") + public Mono> listResource() { + return Mono.just(List.of("First content", "Second content")); + } + + } + + ResourceContentsListResource resourceObject = new ResourceContentsListResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("list-resource"); + + // Test that the handler works with list return type + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("list://resource"); + Mono result = resourceSpecs.get(0).readHandler().apply(context, request); + + StepVerifier.create(result).assertNext(readResult -> { + assertThat(readResult.contents()).hasSize(2); + assertThat(readResult.contents().get(0)).isInstanceOf(TextResourceContents.class); + assertThat(readResult.contents().get(1)).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) readResult.contents().get(0)).text()).isEqualTo("First content"); + assertThat(((TextResourceContents) readResult.contents().get(1)).text()).isEqualTo("Second content"); + }).verifyComplete(); + } + + @Test + void testGetResourceSpecificationsWithContextParameter() { + class ContextParameterResource { + + @McpResource(uri = "context://resource", name = "context-resource", + description = "Resource with context parameter") + public Mono contextResource(McpTransportContext context, ReadResourceRequest request) { + return Mono.just( + "Resource with context: " + (context != null ? "present" : "null") + ", URI: " + request.uri()); + } + + } + + ContextParameterResource resourceObject = new ContextParameterResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("context-resource"); + + // Test that the handler works with context parameter + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("context://resource"); + Mono result = resourceSpecs.get(0).readHandler().apply(context, request); + + StepVerifier.create(result).assertNext(readResult -> { + assertThat(readResult.contents()).hasSize(1); + ResourceContents content = readResult.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()) + .isEqualTo("Resource with context: present, URI: context://resource"); + }).verifyComplete(); + } + + @Test + void testGetResourceSpecificationsWithRequestParameter() { + class RequestParameterResource { + + @McpResource(uri = "request://resource", name = "request-resource", + description = "Resource with request parameter") + public Mono requestResource(ReadResourceRequest request) { + return Mono.just("Resource for URI: " + request.uri()); + } + + } + + RequestParameterResource resourceObject = new RequestParameterResource(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("request-resource"); + + // Test that the handler works with request parameter + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("request://resource"); + Mono result = resourceSpecs.get(0).readHandler().apply(context, request); + + StepVerifier.create(result).assertNext(readResult -> { + assertThat(readResult.contents()).hasSize(1); + ResourceContents content = readResult.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Resource for URI: request://resource"); + }).verifyComplete(); + } + + @Test + void testGetResourceSpecificationsWithSyncMethodReturningMono() { + class SyncMethodReturningMono { + + @McpResource(uri = "sync-mono://resource", name = "sync-mono-resource", + description = "Sync method returning Mono") + public Mono syncMethodReturningMono() { + return Mono.just("Sync method returning Mono content"); + } + + } + + SyncMethodReturningMono resourceObject = new SyncMethodReturningMono(); + AsyncStatelessMcpResourceProvider provider = new AsyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("sync-mono-resource"); + + // Test that the handler works with sync method returning Mono + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("sync-mono://resource"); + Mono result = resourceSpecs.get(0).readHandler().apply(context, request); + + StepVerifier.create(result).assertNext(readResult -> { + assertThat(readResult.contents()).hasSize(1); + ResourceContents content = readResult.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Sync method returning Mono content"); + }).verifyComplete(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProviderTests.java new file mode 100644 index 0000000..4bf80a1 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpCompleteProviderTests.java @@ -0,0 +1,451 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpComplete; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncCompletionSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import reactor.core.publisher.Mono; + +/** + * Tests for {@link SyncStatelessMcpCompleteProvider}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpCompleteProviderTests { + + @Test + void testConstructorWithNullCompleteObjects() { + assertThatThrownBy(() -> new SyncStatelessMcpCompleteProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("completeObjects cannot be null"); + } + + @Test + void testGetCompleteSpecificationsWithSingleValidComplete() { + // Create a class with only one valid sync complete method + class SingleValidComplete { + + @McpComplete(prompt = "test-prompt") + public CompleteResult testComplete(CompleteRequest request) { + return new CompleteResult( + new CompleteCompletion(List.of("Sync completion for " + request.argument().value()), 1, false)); + } + + } + + SingleValidComplete completeObject = new SingleValidComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).isNotNull(); + assertThat(completeSpecs).hasSize(1); + + SyncCompletionSpecification completeSpec = completeSpecs.get(0); + assertThat(completeSpec.referenceKey()).isInstanceOf(PromptReference.class); + PromptReference promptRef = (PromptReference) completeSpec.referenceKey(); + assertThat(promptRef.name()).isEqualTo("test-prompt"); + assertThat(completeSpec.completionHandler()).isNotNull(); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("test-prompt"), + new CompleteRequest.CompleteArgument("test", "value")); + CompleteResult result = completeSpec.completionHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Sync completion for value"); + } + + @Test + void testGetCompleteSpecificationsWithUriReference() { + class UriComplete { + + @McpComplete(uri = "test://{variable}") + public CompleteResult uriComplete(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Sync URI completion for " + request.argument().value()), 1, false)); + } + + } + + UriComplete completeObject = new UriComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + assertThat(completeSpecs.get(0).referenceKey()).isInstanceOf(ResourceReference.class); + ResourceReference resourceRef = (ResourceReference) completeSpecs.get(0).referenceKey(); + assertThat(resourceRef.uri()).isEqualTo("test://{variable}"); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new ResourceReference("test://value"), + new CompleteRequest.CompleteArgument("variable", "value")); + CompleteResult result = completeSpecs.get(0).completionHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Sync URI completion for value"); + } + + @Test + void testGetCompleteSpecificationsFiltersOutReactiveReturnTypes() { + class MixedReturnComplete { + + @McpComplete(prompt = "sync-complete") + public CompleteResult syncComplete(CompleteRequest request) { + return new CompleteResult( + new CompleteCompletion(List.of("Sync completion for " + request.argument().value()), 1, false)); + } + + @McpComplete(prompt = "async-complete") + public Mono asyncComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async completion for " + request.argument().value()), 1, false))); + } + + } + + MixedReturnComplete completeObject = new MixedReturnComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("sync-complete"); + } + + @Test + void testGetCompleteSpecificationsWithMultipleCompleteMethods() { + class MultipleCompleteMethods { + + @McpComplete(prompt = "complete1") + public CompleteResult firstComplete(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("First completion for " + request.argument().value()), 1, false)); + } + + @McpComplete(prompt = "complete2") + public CompleteResult secondComplete(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Second completion for " + request.argument().value()), 1, false)); + } + + } + + MultipleCompleteMethods completeObject = new MultipleCompleteMethods(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(2); + PromptReference promptRef1 = (PromptReference) completeSpecs.get(0).referenceKey(); + PromptReference promptRef2 = (PromptReference) completeSpecs.get(1).referenceKey(); + assertThat(promptRef1.name()).isIn("complete1", "complete2"); + assertThat(promptRef2.name()).isIn("complete1", "complete2"); + assertThat(promptRef1.name()).isNotEqualTo(promptRef2.name()); + } + + @Test + void testGetCompleteSpecificationsWithMultipleCompleteObjects() { + class FirstCompleteObject { + + @McpComplete(prompt = "first-complete") + public CompleteResult firstComplete(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("First completion for " + request.argument().value()), 1, false)); + } + + } + + class SecondCompleteObject { + + @McpComplete(prompt = "second-complete") + public CompleteResult secondComplete(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Second completion for " + request.argument().value()), 1, false)); + } + + } + + FirstCompleteObject firstObject = new FirstCompleteObject(); + SecondCompleteObject secondObject = new SecondCompleteObject(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider( + List.of(firstObject, secondObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(2); + PromptReference promptRef1 = (PromptReference) completeSpecs.get(0).referenceKey(); + PromptReference promptRef2 = (PromptReference) completeSpecs.get(1).referenceKey(); + assertThat(promptRef1.name()).isIn("first-complete", "second-complete"); + assertThat(promptRef2.name()).isIn("first-complete", "second-complete"); + assertThat(promptRef1.name()).isNotEqualTo(promptRef2.name()); + } + + @Test + void testGetCompleteSpecificationsWithMixedMethods() { + class MixedMethods { + + @McpComplete(prompt = "valid-complete") + public CompleteResult validComplete(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Valid completion for " + request.argument().value()), 1, false)); + } + + public CompleteResult nonAnnotatedMethod(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Non-annotated completion for " + request.argument().value()), 1, false)); + } + + @McpComplete(prompt = "async-complete") + public Mono asyncComplete(CompleteRequest request) { + return Mono.just(new CompleteResult(new CompleteCompletion( + List.of("Async completion for " + request.argument().value()), 1, false))); + } + + } + + MixedMethods completeObject = new MixedMethods(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("valid-complete"); + } + + @Test + void testGetCompleteSpecificationsWithPrivateMethod() { + class PrivateMethodComplete { + + @McpComplete(prompt = "private-complete") + private CompleteResult privateComplete(CompleteRequest request) { + return new CompleteResult(new CompleteCompletion( + List.of("Private completion for " + request.argument().value()), 1, false)); + } + + } + + PrivateMethodComplete completeObject = new PrivateMethodComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("private-complete"); + + // Test that the handler works with private methods + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("private-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + CompleteResult result = completeSpecs.get(0).completionHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Private completion for value"); + } + + @Test + void testGetCompleteSpecificationsWithStringReturn() { + class StringReturnComplete { + + @McpComplete(prompt = "string-complete") + public String stringComplete(CompleteRequest request) { + return "Simple string completion for " + request.argument().value(); + } + + } + + StringReturnComplete completeObject = new StringReturnComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("string-complete"); + + // Test that the handler works with String return type + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("string-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + CompleteResult result = completeSpecs.get(0).completionHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Simple string completion for value"); + } + + @Test + void testGetCompleteSpecificationsWithContextParameter() { + class ContextParameterComplete { + + @McpComplete(prompt = "context-complete") + public CompleteResult contextComplete(McpTransportContext context, CompleteRequest request) { + return new CompleteResult(new CompleteCompletion(List.of("Completion with context: " + + (context != null ? "present" : "null") + ", value: " + request.argument().value()), 1, + false)); + } + + } + + ContextParameterComplete completeObject = new ContextParameterComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("context-complete"); + + // Test that the handler works with context parameter + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("context-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + CompleteResult result = completeSpecs.get(0).completionHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion with context: present, value: value"); + } + + @Test + void testGetCompleteSpecificationsWithListReturn() { + class ListReturnComplete { + + @McpComplete(prompt = "list-complete") + public List listComplete(CompleteRequest request) { + return List.of("First completion for " + request.argument().value(), + "Second completion for " + request.argument().value()); + } + + } + + ListReturnComplete completeObject = new ListReturnComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("list-complete"); + + // Test that the handler works with List return type + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("list-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + CompleteResult result = completeSpecs.get(0).completionHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(2); + assertThat(result.completion().values().get(0)).isEqualTo("First completion for value"); + assertThat(result.completion().values().get(1)).isEqualTo("Second completion for value"); + } + + @Test + void testGetCompleteSpecificationsWithCompletionReturn() { + class CompletionReturnComplete { + + @McpComplete(prompt = "completion-complete") + public CompleteCompletion completionComplete(CompleteRequest request) { + return new CompleteCompletion(List.of("Completion object for " + request.argument().value()), 1, false); + } + + } + + CompletionReturnComplete completeObject = new CompletionReturnComplete(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).hasSize(1); + PromptReference promptRef = (PromptReference) completeSpecs.get(0).referenceKey(); + assertThat(promptRef.name()).isEqualTo("completion-complete"); + + // Test that the handler works with CompleteCompletion return type + McpTransportContext context = mock(McpTransportContext.class); + CompleteRequest request = new CompleteRequest(new PromptReference("completion-complete"), + new CompleteRequest.CompleteArgument("test", "value")); + CompleteResult result = completeSpecs.get(0).completionHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.completion()).isNotNull(); + assertThat(result.completion().values()).hasSize(1); + assertThat(result.completion().values().get(0)).isEqualTo("Completion object for value"); + } + + @Test + void testGetCompleteSpecificationsWithEmptyList() { + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of()); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).isNotNull(); + assertThat(completeSpecs).isEmpty(); + } + + @Test + void testGetCompleteSpecificationsWithNoValidMethods() { + class NoValidMethods { + + public void voidMethod() { + // No return value + } + + public String nonAnnotatedMethod() { + return "Not annotated"; + } + + } + + NoValidMethods completeObject = new NoValidMethods(); + SyncStatelessMcpCompleteProvider provider = new SyncStatelessMcpCompleteProvider(List.of(completeObject)); + + List completeSpecs = provider.getCompleteSpecifications(); + + assertThat(completeSpecs).isNotNull(); + assertThat(completeSpecs).isEmpty(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProviderTests.java new file mode 100644 index 0000000..a0c1129 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpPromptProviderTests.java @@ -0,0 +1,556 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpArg; +import org.springaicommunity.mcp.annotation.McpPrompt; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import reactor.core.publisher.Mono; + +/** + * Tests for {@link SyncStatelessMcpPromptProvider}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpPromptProviderTests { + + @Test + void testConstructorWithNullPromptObjects() { + assertThatThrownBy(() -> new SyncStatelessMcpPromptProvider(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("promptObjects cannot be null"); + } + + @Test + void testGetPromptSpecificationsWithSingleValidPrompt() { + // Create a class with only one valid prompt method + class SingleValidPrompt { + + @McpPrompt(name = "test-prompt", description = "A test prompt") + public GetPromptResult testPrompt(GetPromptRequest request) { + return new GetPromptResult("Test prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Hello from " + request.name())))); + } + + } + + SingleValidPrompt promptObject = new SingleValidPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).isNotNull(); + assertThat(promptSpecs).hasSize(1); + + SyncPromptSpecification promptSpec = promptSpecs.get(0); + assertThat(promptSpec.prompt().name()).isEqualTo("test-prompt"); + assertThat(promptSpec.prompt().description()).isEqualTo("A test prompt"); + assertThat(promptSpec.promptHandler()).isNotNull(); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + GetPromptRequest request = new GetPromptRequest("test-prompt", args); + GetPromptResult result = promptSpec.promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Test prompt result"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello from test-prompt"); + } + + @Test + void testGetPromptSpecificationsWithCustomPromptName() { + class CustomNamePrompt { + + @McpPrompt(name = "custom-name", description = "Custom named prompt") + public GetPromptResult methodWithDifferentName() { + return new GetPromptResult("Custom prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Custom prompt content")))); + } + + } + + CustomNamePrompt promptObject = new CustomNamePrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("custom-name"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Custom named prompt"); + } + + @Test + void testGetPromptSpecificationsWithDefaultPromptName() { + class DefaultNamePrompt { + + @McpPrompt(description = "Prompt with default name") + public GetPromptResult defaultNameMethod() { + return new GetPromptResult("Default prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Default prompt content")))); + } + + } + + DefaultNamePrompt promptObject = new DefaultNamePrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("defaultNameMethod"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Prompt with default name"); + } + + @Test + void testGetPromptSpecificationsWithEmptyPromptName() { + class EmptyNamePrompt { + + @McpPrompt(name = "", description = "Prompt with empty name") + public GetPromptResult emptyNameMethod() { + return new GetPromptResult("Empty name prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Empty name prompt content")))); + } + + } + + EmptyNamePrompt promptObject = new EmptyNamePrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("emptyNameMethod"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Prompt with empty name"); + } + + @Test + void testGetPromptSpecificationsFiltersOutMonoReturnTypes() { + class MonoReturnPrompt { + + @McpPrompt(name = "mono-prompt", description = "Prompt returning Mono") + public Mono monoPrompt() { + return Mono.just(new GetPromptResult("Mono prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Mono prompt content"))))); + } + + @McpPrompt(name = "sync-prompt", description = "Synchronous prompt") + public GetPromptResult syncPrompt() { + return new GetPromptResult("Sync prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Sync prompt content")))); + } + + } + + MonoReturnPrompt promptObject = new MonoReturnPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("sync-prompt"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Synchronous prompt"); + } + + @Test + void testGetPromptSpecificationsWithMultiplePromptMethods() { + class MultiplePromptMethods { + + @McpPrompt(name = "prompt1", description = "First prompt") + public GetPromptResult firstPrompt() { + return new GetPromptResult("First prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("First prompt content")))); + } + + @McpPrompt(name = "prompt2", description = "Second prompt") + public GetPromptResult secondPrompt() { + return new GetPromptResult("Second prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Second prompt content")))); + } + + } + + MultiplePromptMethods promptObject = new MultiplePromptMethods(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(2); + assertThat(promptSpecs.get(0).prompt().name()).isIn("prompt1", "prompt2"); + assertThat(promptSpecs.get(1).prompt().name()).isIn("prompt1", "prompt2"); + assertThat(promptSpecs.get(0).prompt().name()).isNotEqualTo(promptSpecs.get(1).prompt().name()); + } + + @Test + void testGetPromptSpecificationsWithMultiplePromptObjects() { + class FirstPromptObject { + + @McpPrompt(name = "first-prompt", description = "First prompt") + public GetPromptResult firstPrompt() { + return new GetPromptResult("First prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("First prompt content")))); + } + + } + + class SecondPromptObject { + + @McpPrompt(name = "second-prompt", description = "Second prompt") + public GetPromptResult secondPrompt() { + return new GetPromptResult("Second prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Second prompt content")))); + } + + } + + FirstPromptObject firstObject = new FirstPromptObject(); + SecondPromptObject secondObject = new SecondPromptObject(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider( + List.of(firstObject, secondObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(2); + assertThat(promptSpecs.get(0).prompt().name()).isIn("first-prompt", "second-prompt"); + assertThat(promptSpecs.get(1).prompt().name()).isIn("first-prompt", "second-prompt"); + assertThat(promptSpecs.get(0).prompt().name()).isNotEqualTo(promptSpecs.get(1).prompt().name()); + } + + @Test + void testGetPromptSpecificationsWithMixedMethods() { + class MixedMethods { + + @McpPrompt(name = "valid-prompt", description = "Valid prompt") + public GetPromptResult validPrompt() { + return new GetPromptResult("Valid prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Valid prompt content")))); + } + + public GetPromptResult nonAnnotatedMethod() { + return new GetPromptResult("Non-annotated result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Non-annotated content")))); + } + + @McpPrompt(name = "mono-prompt", description = "Mono prompt") + public Mono monoPrompt() { + return Mono.just(new GetPromptResult("Mono prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Mono prompt content"))))); + } + + } + + MixedMethods promptObject = new MixedMethods(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("valid-prompt"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Valid prompt"); + } + + @Test + void testGetPromptSpecificationsWithArguments() { + class ArgumentPrompt { + + @McpPrompt(name = "argument-prompt", description = "Prompt with arguments") + public GetPromptResult argumentPrompt( + @McpArg(name = "name", description = "User's name", required = true) String name, + @McpArg(name = "age", description = "User's age", required = false) Integer age) { + return new GetPromptResult("Argument prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent( + "Hello " + name + ", you are " + (age != null ? age : "unknown") + " years old")))); + } + + } + + ArgumentPrompt promptObject = new ArgumentPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("argument-prompt"); + assertThat(promptSpecs.get(0).prompt().arguments()).hasSize(2); + + // Test that the handler works with arguments + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + args.put("name", "John"); + args.put("age", 30); + GetPromptRequest request = new GetPromptRequest("argument-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Argument prompt result"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Hello John, you are 30 years old"); + } + + @Test + void testGetPromptSpecificationsWithPrivateMethod() { + class PrivateMethodPrompt { + + @McpPrompt(name = "private-prompt", description = "Private prompt method") + private GetPromptResult privatePrompt() { + return new GetPromptResult("Private prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Private prompt content")))); + } + + } + + PrivateMethodPrompt promptObject = new PrivateMethodPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("private-prompt"); + assertThat(promptSpecs.get(0).prompt().description()).isEqualTo("Private prompt method"); + + // Test that the handler works with private methods + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("private-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Private prompt result"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Private prompt content"); + } + + @Test + void testGetPromptSpecificationsWithStringReturn() { + class StringReturnPrompt { + + @McpPrompt(name = "string-prompt", description = "Prompt returning string") + public String stringPrompt() { + return "Simple string response"; + } + + } + + StringReturnPrompt promptObject = new StringReturnPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("string-prompt"); + + // Test that the handler works with string return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("string-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Simple string response"); + } + + @Test + void testGetPromptSpecificationsWithContextParameter() { + class ContextParameterPrompt { + + @McpPrompt(name = "context-prompt", description = "Prompt with context parameter") + public GetPromptResult contextPrompt(McpTransportContext context, GetPromptRequest request) { + return new GetPromptResult("Context prompt result", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent("Prompt with context: " + + (context != null ? "present" : "null") + ", name: " + request.name())))); + } + + } + + ContextParameterPrompt promptObject = new ContextParameterPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("context-prompt"); + + // Test that the handler works with context parameter + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("context-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Context prompt result"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()) + .isEqualTo("Prompt with context: present, name: context-prompt"); + } + + @Test + void testGetPromptSpecificationsWithRequestParameter() { + class RequestParameterPrompt { + + @McpPrompt(name = "request-prompt", description = "Prompt with request parameter") + public GetPromptResult requestPrompt(GetPromptRequest request) { + return new GetPromptResult("Request prompt result", List + .of(new PromptMessage(Role.ASSISTANT, new TextContent("Prompt for name: " + request.name())))); + } + + } + + RequestParameterPrompt promptObject = new RequestParameterPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("request-prompt"); + + // Test that the handler works with request parameter + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("request-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.description()).isEqualTo("Request prompt result"); + assertThat(result.messages()).hasSize(1); + PromptMessage message = result.messages().get(0); + assertThat(message.role()).isEqualTo(Role.ASSISTANT); + assertThat(((TextContent) message.content()).text()).isEqualTo("Prompt for name: request-prompt"); + } + + @Test + void testGetPromptSpecificationsWithMessagesList() { + class MessagesListPrompt { + + @McpPrompt(name = "messages-list-prompt", description = "Prompt returning messages list") + public List messagesListPrompt() { + return List.of(new PromptMessage(Role.ASSISTANT, new TextContent("First message")), + new PromptMessage(Role.ASSISTANT, new TextContent("Second message"))); + } + + } + + MessagesListPrompt promptObject = new MessagesListPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("messages-list-prompt"); + + // Test that the handler works with messages list return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("messages-list-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(2); + assertThat(((TextContent) result.messages().get(0).content()).text()).isEqualTo("First message"); + assertThat(((TextContent) result.messages().get(1).content()).text()).isEqualTo("Second message"); + } + + @Test + void testGetPromptSpecificationsWithSingleMessage() { + class SingleMessagePrompt { + + @McpPrompt(name = "single-message-prompt", description = "Prompt returning single message") + public PromptMessage singleMessagePrompt() { + return new PromptMessage(Role.ASSISTANT, new TextContent("Single message")); + } + + } + + SingleMessagePrompt promptObject = new SingleMessagePrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("single-message-prompt"); + + // Test that the handler works with single message return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("single-message-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(1); + assertThat(((TextContent) result.messages().get(0).content()).text()).isEqualTo("Single message"); + } + + @Test + void testGetPromptSpecificationsWithStringList() { + class StringListPrompt { + + @McpPrompt(name = "string-list-prompt", description = "Prompt returning string list") + public List stringListPrompt() { + return List.of("First string", "Second string", "Third string"); + } + + } + + StringListPrompt promptObject = new StringListPrompt(); + SyncStatelessMcpPromptProvider provider = new SyncStatelessMcpPromptProvider(List.of(promptObject)); + + List promptSpecs = provider.getPromptSpecifications(); + + assertThat(promptSpecs).hasSize(1); + assertThat(promptSpecs.get(0).prompt().name()).isEqualTo("string-list-prompt"); + + // Test that the handler works with string list return type + McpTransportContext context = mock(McpTransportContext.class); + Map args = new HashMap<>(); + GetPromptRequest request = new GetPromptRequest("string-list-prompt", args); + GetPromptResult result = promptSpecs.get(0).promptHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.messages()).hasSize(3); + assertThat(((TextContent) result.messages().get(0).content()).text()).isEqualTo("First string"); + assertThat(((TextContent) result.messages().get(1).content()).text()).isEqualTo("Second string"); + assertThat(((TextContent) result.messages().get(2).content()).text()).isEqualTo("Third string"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProviderTests.java new file mode 100644 index 0000000..6a21817 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncStatelessMcpResourceProviderTests.java @@ -0,0 +1,451 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpResource; + +import io.modelcontextprotocol.server.McpStatelessServerFeatures.SyncResourceSpecification; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceRequest; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceContents; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import reactor.core.publisher.Mono; + +/** + * Tests for {@link SyncStatelessMcpResourceProvider}. + * + * @author Christian Tzolov + */ +public class SyncStatelessMcpResourceProviderTests { + + @Test + void testConstructorWithNullResourceObjects() { + assertThatThrownBy(() -> new SyncStatelessMcpResourceProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("resourceObjects cannot be null"); + } + + @Test + void testGetResourceSpecificationsWithSingleValidResource() { + // Create a class with only one valid resource method + class SingleValidResource { + + @McpResource(uri = "test://resource/{id}", name = "test-resource", description = "A test resource") + public String testResource(String id) { + return "Resource content for: " + id; + } + + } + + SingleValidResource resourceObject = new SingleValidResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).isNotNull(); + assertThat(resourceSpecs).hasSize(1); + + SyncResourceSpecification resourceSpec = resourceSpecs.get(0); + assertThat(resourceSpec.resource().uri()).isEqualTo("test://resource/{id}"); + assertThat(resourceSpec.resource().name()).isEqualTo("test-resource"); + assertThat(resourceSpec.resource().description()).isEqualTo("A test resource"); + assertThat(resourceSpec.readHandler()).isNotNull(); + + // Test that the handler works + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("test://resource/123"); + ReadResourceResult result = resourceSpec.readHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + ResourceContents content = result.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Resource content for: 123"); + } + + @Test + void testGetResourceSpecificationsWithCustomResourceName() { + class CustomNameResource { + + @McpResource(uri = "custom://resource", name = "custom-name", description = "Custom named resource") + public String methodWithDifferentName() { + return "Custom resource content"; + } + + } + + CustomNameResource resourceObject = new CustomNameResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("custom-name"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Custom named resource"); + } + + @Test + void testGetResourceSpecificationsWithDefaultResourceName() { + class DefaultNameResource { + + @McpResource(uri = "default://resource", description = "Resource with default name") + public String defaultNameMethod() { + return "Default resource content"; + } + + } + + DefaultNameResource resourceObject = new DefaultNameResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("defaultNameMethod"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Resource with default name"); + } + + @Test + void testGetResourceSpecificationsWithEmptyResourceName() { + class EmptyNameResource { + + @McpResource(uri = "empty://resource", name = "", description = "Resource with empty name") + public String emptyNameMethod() { + return "Empty name resource content"; + } + + } + + EmptyNameResource resourceObject = new EmptyNameResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("emptyNameMethod"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Resource with empty name"); + } + + @Test + void testGetResourceSpecificationsFiltersOutMonoReturnTypes() { + class MonoReturnResource { + + @McpResource(uri = "mono://resource", name = "mono-resource", description = "Resource returning Mono") + public Mono monoResource() { + return Mono.just("Mono resource content"); + } + + @McpResource(uri = "sync://resource", name = "sync-resource", description = "Synchronous resource") + public String syncResource() { + return "Sync resource content"; + } + + } + + MonoReturnResource resourceObject = new MonoReturnResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("sync-resource"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Synchronous resource"); + } + + @Test + void testGetResourceSpecificationsWithMultipleResourceMethods() { + class MultipleResourceMethods { + + @McpResource(uri = "first://resource", name = "resource1", description = "First resource") + public String firstResource() { + return "First resource content"; + } + + @McpResource(uri = "second://resource", name = "resource2", description = "Second resource") + public String secondResource() { + return "Second resource content"; + } + + } + + MultipleResourceMethods resourceObject = new MultipleResourceMethods(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(2); + assertThat(resourceSpecs.get(0).resource().name()).isIn("resource1", "resource2"); + assertThat(resourceSpecs.get(1).resource().name()).isIn("resource1", "resource2"); + assertThat(resourceSpecs.get(0).resource().name()).isNotEqualTo(resourceSpecs.get(1).resource().name()); + } + + @Test + void testGetResourceSpecificationsWithMultipleResourceObjects() { + class FirstResourceObject { + + @McpResource(uri = "first://resource", name = "first-resource", description = "First resource") + public String firstResource() { + return "First resource content"; + } + + } + + class SecondResourceObject { + + @McpResource(uri = "second://resource", name = "second-resource", description = "Second resource") + public String secondResource() { + return "Second resource content"; + } + + } + + FirstResourceObject firstObject = new FirstResourceObject(); + SecondResourceObject secondObject = new SecondResourceObject(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider( + List.of(firstObject, secondObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(2); + assertThat(resourceSpecs.get(0).resource().name()).isIn("first-resource", "second-resource"); + assertThat(resourceSpecs.get(1).resource().name()).isIn("first-resource", "second-resource"); + assertThat(resourceSpecs.get(0).resource().name()).isNotEqualTo(resourceSpecs.get(1).resource().name()); + } + + @Test + void testGetResourceSpecificationsWithMixedMethods() { + class MixedMethods { + + @McpResource(uri = "valid://resource", name = "valid-resource", description = "Valid resource") + public String validResource() { + return "Valid resource content"; + } + + public String nonAnnotatedMethod() { + return "Non-annotated resource content"; + } + + @McpResource(uri = "mono://resource", name = "mono-resource", description = "Mono resource") + public Mono monoResource() { + return Mono.just("Mono resource content"); + } + + } + + MixedMethods resourceObject = new MixedMethods(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("valid-resource"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Valid resource"); + } + + @Test + void testGetResourceSpecificationsWithUriVariables() { + class UriVariableResource { + + @McpResource(uri = "variable://resource/{id}/{type}", name = "variable-resource", + description = "Resource with URI variables") + public String variableResource(String id, String type) { + return String.format("Resource content for id: %s, type: %s", id, type); + } + + } + + UriVariableResource resourceObject = new UriVariableResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().uri()).isEqualTo("variable://resource/{id}/{type}"); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("variable-resource"); + + // Test that the handler works with URI variables + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("variable://resource/123/document"); + ReadResourceResult result = resourceSpecs.get(0).readHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + ResourceContents content = result.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Resource content for id: 123, type: document"); + } + + @Test + void testGetResourceSpecificationsWithMimeType() { + class MimeTypeResource { + + @McpResource(uri = "mime://resource", name = "mime-resource", description = "Resource with MIME type", + mimeType = "application/json") + public String mimeTypeResource() { + return "{\"message\": \"JSON resource content\"}"; + } + + } + + MimeTypeResource resourceObject = new MimeTypeResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().mimeType()).isEqualTo("application/json"); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("mime-resource"); + } + + @Test + void testGetResourceSpecificationsWithPrivateMethod() { + class PrivateMethodResource { + + @McpResource(uri = "private://resource", name = "private-resource", description = "Private resource method") + private String privateResource() { + return "Private resource content"; + } + + } + + PrivateMethodResource resourceObject = new PrivateMethodResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("private-resource"); + assertThat(resourceSpecs.get(0).resource().description()).isEqualTo("Private resource method"); + + // Test that the handler works with private methods + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("private://resource"); + ReadResourceResult result = resourceSpecs.get(0).readHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + ResourceContents content = result.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Private resource content"); + } + + @Test + void testGetResourceSpecificationsWithResourceContentsList() { + class ResourceContentsListResource { + + @McpResource(uri = "list://resource", name = "list-resource", description = "Resource returning list") + public List listResource() { + return List.of("First content", "Second content"); + } + + } + + ResourceContentsListResource resourceObject = new ResourceContentsListResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("list-resource"); + + // Test that the handler works with list return type + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("list://resource"); + ReadResourceResult result = resourceSpecs.get(0).readHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(2); + assertThat(result.contents().get(0)).isInstanceOf(TextResourceContents.class); + assertThat(result.contents().get(1)).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) result.contents().get(0)).text()).isEqualTo("First content"); + assertThat(((TextResourceContents) result.contents().get(1)).text()).isEqualTo("Second content"); + } + + @Test + void testGetResourceSpecificationsWithContextParameter() { + class ContextParameterResource { + + @McpResource(uri = "context://resource", name = "context-resource", + description = "Resource with context parameter") + public String contextResource(McpTransportContext context, ReadResourceRequest request) { + return "Resource with context: " + (context != null ? "present" : "null") + ", URI: " + request.uri(); + } + + } + + ContextParameterResource resourceObject = new ContextParameterResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("context-resource"); + + // Test that the handler works with context parameter + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("context://resource"); + ReadResourceResult result = resourceSpecs.get(0).readHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + ResourceContents content = result.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()) + .isEqualTo("Resource with context: present, URI: context://resource"); + } + + @Test + void testGetResourceSpecificationsWithRequestParameter() { + class RequestParameterResource { + + @McpResource(uri = "request://resource", name = "request-resource", + description = "Resource with request parameter") + public String requestResource(ReadResourceRequest request) { + return "Resource for URI: " + request.uri(); + } + + } + + RequestParameterResource resourceObject = new RequestParameterResource(); + SyncStatelessMcpResourceProvider provider = new SyncStatelessMcpResourceProvider(List.of(resourceObject)); + + List resourceSpecs = provider.getResourceSpecifications(); + + assertThat(resourceSpecs).hasSize(1); + assertThat(resourceSpecs.get(0).resource().name()).isEqualTo("request-resource"); + + // Test that the handler works with request parameter + McpTransportContext context = mock(McpTransportContext.class); + ReadResourceRequest request = new ReadResourceRequest("request://resource"); + ReadResourceResult result = resourceSpecs.get(0).readHandler().apply(context, request); + + assertThat(result).isNotNull(); + assertThat(result.contents()).hasSize(1); + ResourceContents content = result.contents().get(0); + assertThat(content).isInstanceOf(TextResourceContents.class); + assertThat(((TextResourceContents) content).text()).isEqualTo("Resource for URI: request://resource"); + } + +} diff --git a/pom.xml b/pom.xml index e20ff1d..e8c6797 100644 --- a/pom.xml +++ b/pom.xml @@ -55,7 +55,7 @@ 17 17 - 0.12.0-SNAPSHOT + 0.11.2 1.1.0-SNAPSHOT 2.0.16