diff --git a/README.md b/README.md index 6e4ca58..6df7053 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ The core module provides a set of annotations and callback implementations for p 5. **Logging Consumer** - For handling logging message notifications 6. **Sampling** - For handling sampling requests 7. **Elicitation** - For handling elicitation requests to gather additional information from users +8. **Progress** - For handling progress notifications during long-running operations Each operation type has both synchronous and asynchronous implementations, allowing for flexible integration with different application architectures. @@ -110,6 +111,7 @@ The Spring integration module provides seamless integration with Spring AI and S - **`@McpLoggingConsumer`** - Annotates methods that handle logging message notifications from MCP servers - **`@McpSampling`** - Annotates methods that handle sampling requests from MCP servers - **`@McpElicitation`** - Annotates methods that handle elicitation requests to gather additional information from users +- **`@McpProgress`** - Annotates methods that handle progress notifications for long-running operations - **`@McpArg`** - Annotates method parameters as MCP arguments ### Method Callbacks @@ -160,6 +162,11 @@ The modules provide callback implementations for each operation type: - `SyncMcpElicitationMethodCallback` - Synchronous implementation - `AsyncMcpElicitationMethodCallback` - Asynchronous implementation using Reactor's Mono +#### Progress +- `AbstractMcpProgressMethodCallback` - Base class for progress method callbacks +- `SyncMcpProgressMethodCallback` - Synchronous implementation +- `AsyncMcpProgressMethodCallback` - Asynchronous implementation using Reactor's Mono + ### Providers The project includes provider classes that scan for annotated methods and create appropriate callbacks: @@ -176,6 +183,8 @@ The project includes provider classes that scan for annotated methods and create - `AsyncMcpSamplingProvider` - Processes `@McpSampling` annotations for asynchronous operations - `SyncMcpElicitationProvider` - Processes `@McpElicitation` annotations for synchronous operations - `AsyncMcpElicitationProvider` - Processes `@McpElicitation` annotations for asynchronous operations +- `SyncMcpProgressProvider` - Processes `@McpProgress` annotations for synchronous operations +- `AsyncMcpProgressProvider` - Processes `@McpProgress` annotations for asynchronous operations #### Stateless Providers (using McpTransportContext) - `SyncStatelessMcpCompleteProvider` - Processes `@McpComplete` annotations for synchronous stateless operations @@ -807,6 +816,126 @@ public class MyMcpClient { } ``` +### Mcp Client Progress Example + +```java +public class ProgressHandler { + + /** + * Handle progress notifications with a single parameter. + * @param notification The progress notification + */ + @McpProgress + public void handleProgressNotification(ProgressNotification notification) { + System.out.println(String.format("Progress: %.2f%% - %s", + notification.progress() * 100, + notification.message())); + } + + /** + * Handle progress notifications with individual parameters. + * @param progressToken The progress token identifying the operation + * @param progress The current progress (0.0 to 1.0) + * @param total Optional total value for the operation + * @param message Optional progress message + */ + @McpProgress + public void handleProgressWithParams(String progressToken, double progress, Double total, String message) { + if (total != null) { + System.out.println(String.format("Progress [%s]: %.0f/%.0f - %s", + progressToken, progress, total, message)); + } else { + System.out.println(String.format("Progress [%s]: %.2f%% - %s", + progressToken, progress * 100, message)); + } + } + + /** + * Handle progress notifications for a specific client. + * @param notification The progress notification + */ + @McpProgress(clientId = "client-1") + public void handleClient1Progress(ProgressNotification notification) { + System.out.println(String.format("Client-1 Progress: %.2f%% - %s", + notification.progress() * 100, + notification.message())); + } +} + +public class AsyncProgressHandler { + + /** + * Handle progress notifications asynchronously. + * @param notification The progress notification + * @return A Mono that completes when the notification is handled + */ + @McpProgress + public Mono handleAsyncProgress(ProgressNotification notification) { + return Mono.fromRunnable(() -> { + System.out.println(String.format("Async Progress: %.2f%% - %s", + notification.progress() * 100, + notification.message())); + }); + } + + /** + * Handle progress notifications for a specific client asynchronously. + * @param progressToken The progress token + * @param progress The current progress + * @param total Optional total value + * @param message Optional message + * @return A Mono that completes when the notification is handled + */ + @McpProgress(clientId = "client-2") + public Mono handleClient2AsyncProgress( + String progressToken, + double progress, + Double total, + String message) { + + return Mono.fromRunnable(() -> { + String progressText = total != null ? + String.format("%.0f/%.0f", progress, total) : + String.format("%.2f%%", progress * 100); + + System.out.println(String.format("Client-2 Progress [%s]: %s - %s", + progressToken, progressText, message)); + }).then(); + } +} + +public class MyMcpClient { + + public static McpSyncClient createSyncClientWithProgress(ProgressHandler progressHandler) { + List> progressConsumers = + new SyncMcpProgressProvider(List.of(progressHandler)).getProgressConsumers(); + + McpSyncClient client = McpClient.sync(transport) + .capabilities(ClientCapabilities.builder() + // Enable capabilities... + .build()) + .progressConsumers(progressConsumers) + .build(); + + return client; + } + + public static McpAsyncClient createAsyncClientWithProgress(AsyncProgressHandler asyncProgressHandler) { + List>> progressHandlers = + new AsyncMcpProgressProvider(List.of(asyncProgressHandler)).getProgressHandlers(); + + McpAsyncClient client = McpClient.async(transport) + .capabilities(ClientCapabilities.builder() + // Enable capabilities... + .build()) + .progressHandlers(progressHandlers) + .build(); + + return client; + } +} +``` + ### Mcp Client Elicitation Example ```java @@ -1213,6 +1342,18 @@ public class McpConfig { return SpringAiMcpAnnotationProvider.createAsyncElicitationSpecifications(asyncElicitationHandlers); } + @Bean + public List syncProgressSpecifications( + List progressHandlers) { + return SpringAiMcpAnnotationProvider.createSyncProgressSpecifications(progressHandlers); + } + + @Bean + public List asyncProgressSpecifications( + List asyncProgressHandlers) { + return SpringAiMcpAnnotationProvider.createAsyncProgressSpecifications(asyncProgressHandlers); + } + // Stateless Spring Integration Examples @Bean @@ -1248,6 +1389,7 @@ public class McpConfig { - **Dynamic schema support via CallToolRequest** - Tools can accept `CallToolRequest` parameters to handle dynamic schemas at runtime - **Logging consumer support** - Handle logging message notifications from MCP servers - **Sampling support** - Handle sampling requests from MCP servers +- **Progress notification support** - Handle progress notifications for long-running operations - **Spring integration** - Seamless integration with Spring Framework and Spring AI, including support for both stateful and stateless operations - **AOP proxy support** - Proper handling of Spring AOP proxies when processing annotations diff --git a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java index b6a3262..434c4b2 100644 --- a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java +++ b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/AsyncMcpAnnotationProvider.java @@ -20,9 +20,11 @@ import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; +import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; import org.springaicommunity.mcp.provider.AsyncMcpElicitationProvider; import org.springaicommunity.mcp.provider.AsyncMcpLoggingConsumerProvider; +import org.springaicommunity.mcp.provider.AsyncMcpProgressProvider; import org.springaicommunity.mcp.provider.AsyncMcpSamplingProvider; import org.springaicommunity.mcp.provider.AsyncMcpToolProvider; import org.springaicommunity.mcp.provider.AsyncStatelessMcpPromptProvider; @@ -128,6 +130,19 @@ protected Method[] doGetClassMethods(Object bean) { } + private static class SpringAiAsyncMcpProgressProvider extends AsyncMcpProgressProvider { + + public SpringAiAsyncMcpProgressProvider(List progressObjects) { + super(progressObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + public static List createAsyncLoggingSpecifications(List loggingObjects) { return new SpringAiAsyncMcpLoggingConsumerProvider(loggingObjects).getLoggingSpecifications(); } @@ -160,4 +175,8 @@ public static List create return new SpringAiAsyncStatelessResourceProvider(resourceObjects).getResourceSpecifications(); } + public static List createAsyncProgressSpecifications(List progressObjects) { + return new SpringAiAsyncMcpProgressProvider(progressObjects).getProgressSpecifications(); + } + } diff --git a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java index 8e4ac30..e702f71 100644 --- a/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java +++ b/mcp-annotations-spring/src/main/java/org/springaicommunity/mcp/spring/SyncMcpAnnotationProvider.java @@ -20,10 +20,12 @@ import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; +import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; import org.springaicommunity.mcp.provider.SyncMcpCompletionProvider; import org.springaicommunity.mcp.provider.SyncMcpElicitationProvider; import org.springaicommunity.mcp.provider.SyncMcpLoggingConsumerProvider; +import org.springaicommunity.mcp.provider.SyncMcpProgressProvider; import org.springaicommunity.mcp.provider.SyncMcpPromptProvider; import org.springaicommunity.mcp.provider.SyncMcpResourceProvider; import org.springaicommunity.mcp.provider.SyncMcpSamplingProvider; @@ -173,6 +175,19 @@ protected Method[] doGetClassMethods(Object bean) { } + private static class SpringAiSyncMcpProgressProvider extends SyncMcpProgressProvider { + + public SpringAiSyncMcpProgressProvider(List progressObjects) { + super(progressObjects); + } + + @Override + protected Method[] doGetClassMethods(Object bean) { + return AnnotationProviderUtil.beanMethods(bean); + } + + } + public static List createSyncToolSpecifications(List toolObjects) { return new SpringAiSyncToolProvider(toolObjects).getToolSpecifications(); } @@ -217,4 +232,8 @@ public static List createSyncElicitationSpecificat return new SpringAiSyncMcpElicitationProvider(elicitationObjects).getElicitationSpecifications(); } + public static List createSyncProgressSpecifications(List progressObjects) { + return new SpringAiSyncMcpProgressProvider(progressObjects).getProgressSpecifications(); + } + } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpElicitation.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpElicitation.java index 139a110..343bca9 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpElicitation.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpElicitation.java @@ -11,7 +11,8 @@ import java.lang.annotation.Target; /** - * Annotation for methods that handle elicitation requests from MCP servers. + * Annotation for methods that handle elicitation requests from MCP servers. This + * annotation is applicable only for MCP clients. * *

* Methods annotated with this annotation can be used to process elicitation requests from diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpLoggingConsumer.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpLoggingConsumer.java index 9c75aef..9354d99 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpLoggingConsumer.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpLoggingConsumer.java @@ -11,7 +11,8 @@ import java.lang.annotation.Target; /** - * Annotation for methods that handle logging message notifications from MCP servers. + * Annotation for methods that handle logging message notifications from MCP servers. This + * annotation is applicable only for MCP clients. * *

* Methods annotated with this annotation can be used to consume logging messages from MCP diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgress.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgress.java new file mode 100644 index 0000000..4240c7d --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpProgress.java @@ -0,0 +1,45 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation for methods that handle progress notifications from MCP servers. This + * annotation is applicable only for MCP clients. + * + *

+ * Methods annotated with this annotation can be used to consume progress messages from + * MCP servers. The methods takes a single parameter of type + * {@code ProgressMessageNotification} + * + * + *

+ * Example usage:

{@code
+ * @McpProgress
+ * public void handleProgressMessage(ProgressMessageNotification notification) {
+ *     // Handle the notification *
+ * }
+ * + * @author Christian Tzolov + * + * @see io.modelcontextprotocol.spec.McpSchema.ProgressMessageNotification + */ +@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface McpProgress { + + /** + * Used as connection or client identifier to select the MCP client, the logging + * consumer is associated with. If not specified, is applied to all clients. + */ + String clientId() default ""; + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java index ac97378..7a084db 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/annotation/McpSampling.java @@ -11,7 +11,8 @@ import java.lang.annotation.Target; /** - * Annotation for methods that handle sampling requests from MCP servers. + * Annotation for methods that handle sampling requests from MCP servers. This annotation + * is applicable only for MCP clients. * *

* Methods annotated with this annotation can be used to process sampling requests from diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AbstractMcpLoggingConsumerMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AbstractMcpLoggingConsumerMethodCallback.java index 1970ce5..301f467 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AbstractMcpLoggingConsumerMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AbstractMcpLoggingConsumerMethodCallback.java @@ -9,7 +9,6 @@ import org.springaicommunity.mcp.annotation.McpLoggingConsumer; -import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; @@ -144,15 +143,6 @@ protected Object[] buildArgs(Method method, Object exchange, LoggingMessageNotif return args; } - /** - * Checks if a parameter type is compatible with the exchange type. This method should - * be implemented by subclasses to handle specific exchange type checking. - * @param paramType The parameter type to check - * @return true if the parameter type is compatible with the exchange type, false - * otherwise - */ - protected abstract boolean isExchangeType(Class paramType); - /** * Exception thrown when there is an error invoking a logging consumer method. */ diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AsyncMcpLoggingConsumerMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AsyncMcpLoggingConsumerMethodCallback.java index 7295383..4801e6e 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AsyncMcpLoggingConsumerMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/AsyncMcpLoggingConsumerMethodCallback.java @@ -96,18 +96,6 @@ protected void validateReturnType(Method method) { } } - /** - * Checks if a parameter type is compatible with the exchange type. - * @param paramType The parameter type to check - * @return true if the parameter type is compatible with the exchange type, false - * otherwise - */ - @Override - protected boolean isExchangeType(Class paramType) { - // No exchange type for logging consumer methods - return false; - } - /** * Builder for creating AsyncMcpLoggingConsumerMethodCallback instances. *

diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/SyncMcpLoggingConsumerMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/SyncMcpLoggingConsumerMethodCallback.java index 7c4712b..7229c67 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/SyncMcpLoggingConsumerMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/logging/SyncMcpLoggingConsumerMethodCallback.java @@ -73,18 +73,6 @@ protected void validateReturnType(Method method) { } } - /** - * Checks if a parameter type is compatible with the exchange type. - * @param paramType The parameter type to check - * @return true if the parameter type is compatible with the exchange type, false - * otherwise - */ - @Override - protected boolean isExchangeType(Class paramType) { - // No exchange type for logging consumer methods - return false; - } - /** * Builder for creating SyncMcpLoggingConsumerMethodCallback instances. *

diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AbstractMcpProgressMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AbstractMcpProgressMethodCallback.java new file mode 100644 index 0000000..09dc5cc --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AbstractMcpProgressMethodCallback.java @@ -0,0 +1,240 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; + +import org.springaicommunity.mcp.annotation.McpProgress; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.util.Assert; + +/** + * Abstract base class for creating callbacks around progress methods. + * + * This class provides common functionality for both synchronous and asynchronous progress + * method callbacks. It contains shared logic for method validation, argument building, + * and other common operations. + * + * @author Christian Tzolov + */ +public abstract class AbstractMcpProgressMethodCallback { + + protected final Method method; + + protected final Object bean; + + /** + * Constructor for AbstractMcpProgressMethodCallback. + * @param method The method to create a callback for + * @param bean The bean instance that contains the method + */ + protected AbstractMcpProgressMethodCallback(Method method, Object bean) { + Assert.notNull(method, "Method can't be null!"); + Assert.notNull(bean, "Bean can't be null!"); + + this.method = method; + this.bean = bean; + this.validateMethod(this.method); + } + + /** + * Validates that the method signature is compatible with the progress callback. + *

+ * This method checks that the return type is valid and that the parameters match the + * expected pattern. + * @param method The method to validate + * @throws IllegalArgumentException if the method signature is not compatible + */ + protected void validateMethod(Method method) { + if (method == null) { + throw new IllegalArgumentException("Method must not be null"); + } + + this.validateReturnType(method); + this.validateParameters(method); + } + + /** + * Validates that the method return type is compatible with the progress callback. + * This method should be implemented by subclasses to handle specific return type + * validation. + * @param method The method to validate + * @throws IllegalArgumentException if the return type is not compatible + */ + protected abstract void validateReturnType(Method method); + + /** + * Validates method parameters. This method provides common validation logic and + * delegates exchange type checking to subclasses. + * @param method The method to validate + * @throws IllegalArgumentException if the parameters are not compatible + */ + protected void validateParameters(Method method) { + Parameter[] parameters = method.getParameters(); + + // Check parameter count - must have either 1 or 3 parameters + if (parameters.length != 1 && parameters.length != 3) { + throw new IllegalArgumentException( + "Method must have either 1 parameter (ProgressNotification) or 3 parameters (Double, String, String): " + + method.getName() + " in " + method.getDeclaringClass().getName() + " has " + + parameters.length + " parameters"); + } + + // Check parameter types + if (parameters.length == 1) { + // Single parameter must be ProgressNotification + if (!ProgressNotification.class.isAssignableFrom(parameters[0].getType())) { + throw new IllegalArgumentException("Single parameter must be of type ProgressNotification: " + + method.getName() + " in " + method.getDeclaringClass().getName() + " has parameter of type " + + parameters[0].getType().getName()); + } + } + else { + // Three parameters must be Double, String, String + if (!Double.class.isAssignableFrom(parameters[0].getType()) + && !double.class.isAssignableFrom(parameters[0].getType())) { + throw new IllegalArgumentException("First parameter must be of type Double or double: " + + method.getName() + " in " + method.getDeclaringClass().getName() + " has parameter of type " + + parameters[0].getType().getName()); + } + if (!String.class.isAssignableFrom(parameters[1].getType())) { + throw new IllegalArgumentException("Second parameter must be of type String: " + method.getName() + + " in " + method.getDeclaringClass().getName() + " has parameter of type " + + parameters[1].getType().getName()); + } + if (!String.class.isAssignableFrom(parameters[2].getType())) { + throw new IllegalArgumentException("Third parameter must be of type String: " + method.getName() + + " in " + method.getDeclaringClass().getName() + " has parameter of type " + + parameters[2].getType().getName()); + } + } + } + + /** + * Builds the arguments array for invoking the method. + *

+ * This method constructs an array of arguments based on the method's parameter types + * and the available values (exchange, notification). + * @param method The method to build arguments for + * @param exchange The server exchange + * @param notification The progress notification + * @return An array of arguments for the method invocation + */ + protected Object[] buildArgs(Method method, Object exchange, ProgressNotification notification) { + Parameter[] parameters = method.getParameters(); + Object[] args = new Object[parameters.length]; + + if (parameters.length == 1) { + // Single parameter (ProgressNotification) + args[0] = notification; + } + else { + // Three parameters (Double, String, String) + args[0] = notification.progress(); + args[1] = notification.progressToken(); + args[2] = notification.total() != null ? String.valueOf(notification.total()) : null; + } + + return args; + } + + /** + * Exception thrown when there is an error invoking a progress method. + */ + public static class McpProgressMethodException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs a new exception with the specified detail message and cause. + * @param message The detail message + * @param cause The cause + */ + public McpProgressMethodException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new exception with the specified detail message. + * @param message The detail message + */ + public McpProgressMethodException(String message) { + super(message); + } + + } + + /** + * Abstract builder for creating McpProgressMethodCallback instances. + *

+ * This builder provides a base for constructing callback instances with the required + * parameters. + * + * @param The type of the builder + * @param The type of the callback + */ + protected abstract static class AbstractBuilder, R> { + + protected Method method; + + protected Object bean; + + /** + * Set the method to create a callback for. + * @param method The method to create a callback for + * @return This builder + */ + @SuppressWarnings("unchecked") + public T method(Method method) { + this.method = method; + return (T) this; + } + + /** + * Set the bean instance that contains the method. + * @param bean The bean instance + * @return This builder + */ + @SuppressWarnings("unchecked") + public T bean(Object bean) { + this.bean = bean; + return (T) this; + } + + /** + * Set the progress annotation. + * @param progress The progress annotation + * @return This builder + */ + @SuppressWarnings("unchecked") + public T progress(McpProgress progress) { + // No additional configuration needed from the annotation at this time + return (T) this; + } + + /** + * Validate the builder state. + * @throws IllegalArgumentException if the builder state is invalid + */ + protected void validate() { + if (method == null) { + throw new IllegalArgumentException("Method must not be null"); + } + if (bean == null) { + throw new IllegalArgumentException("Bean must not be null"); + } + } + + /** + * Build the callback. + * @return A new callback instance + */ + public abstract R build(); + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallback.java new file mode 100644 index 0000000..8557f0c --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallback.java @@ -0,0 +1,130 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; + +/** + * Asynchronous implementation of a progress method callback. + * + * This class creates a Function that invokes a method annotated with @McpProgress + * asynchronously when a progress notification is received, returning a Mono. + * + * @author Christian Tzolov + */ +public final class AsyncMcpProgressMethodCallback extends AbstractMcpProgressMethodCallback + implements Function> { + + private AsyncMcpProgressMethodCallback(Builder builder) { + super(builder.method, builder.bean); + } + + @Override + protected void validateReturnType(Method method) { + Class returnType = method.getReturnType(); + + // Check if return type is void or Mono + if (returnType == void.class) { + // void is acceptable - we'll wrap it in Mono + return; + } + + if (Mono.class.isAssignableFrom(returnType)) { + // Check if it's Mono + Type genericReturnType = method.getGenericReturnType(); + if (genericReturnType instanceof ParameterizedType paramType) { + Type[] typeArguments = paramType.getActualTypeArguments(); + if (typeArguments.length == 1 && typeArguments[0] == Void.class) { + // Mono is acceptable + return; + } + else { + throw new IllegalArgumentException("Mono return type must be Mono: " + method.getName() + + " in " + method.getDeclaringClass().getName() + " returns " + returnType.getName()); + } + } + } + + throw new IllegalArgumentException( + "Asynchronous progress methods must return void or Mono: " + method.getName() + " in " + + method.getDeclaringClass().getName() + " returns " + returnType.getName()); + } + + /** + * Apply the progress notification and process it asynchronously. + *

+ * This method builds the arguments for the method call and invokes the method, + * returning a Mono. + * @param notification The progress notification, must not be null + * @return A Mono representing the asynchronous operation + * @throws McpProgressMethodException if there is an error invoking the progress + * method + * @throws IllegalArgumentException if the notification is null + */ + @Override + public Mono apply(ProgressNotification notification) { + if (notification == null) { + return Mono.error(new IllegalArgumentException("Notification must not be null")); + } + + return Mono.fromCallable(() -> { + try { + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, null, notification); + + // Invoke the method + this.method.setAccessible(true); + Object result = this.method.invoke(this.bean, args); + + // Handle return type + if (result instanceof Mono) { + return (Mono) result; + } + else { + // void return type + return Mono.empty(); + } + } + catch (Exception e) { + throw new McpProgressMethodException("Error invoking progress method: " + this.method.getName(), e); + } + }).flatMap(mono -> mono.then()); + } + + /** + * Builder for creating AsyncMcpProgressMethodCallback instances. + *

+ * This builder provides a fluent API for constructing AsyncMcpProgressMethodCallback + * instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Build the callback. + * @return A new AsyncMcpProgressMethodCallback instance + */ + @Override + public AsyncMcpProgressMethodCallback build() { + validate(); + return new AsyncMcpProgressMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AsyncProgressSpecification.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AsyncProgressSpecification.java new file mode 100644 index 0000000..9e38290 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/AsyncProgressSpecification.java @@ -0,0 +1,20 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import java.util.function.Function; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; + +/** + * Specification for asynchronous progress handlers. + * + * @param clientId The client ID for the progress handler + * @param progressHandler The function that handles progress notifications asynchronously + * @author Christian Tzolov + */ +public record AsyncProgressSpecification(String clientId, Function> progressHandler) { +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallback.java new file mode 100644 index 0000000..8f30a10 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallback.java @@ -0,0 +1,92 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import java.lang.reflect.Method; +import java.util.function.Consumer; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; + +/** + * Synchronous implementation of a progress method callback. + * + * This class creates a Consumer that invokes a method annotated with @McpProgress + * synchronously when a progress notification is received. + * + * @author Christian Tzolov + */ +public final class SyncMcpProgressMethodCallback extends AbstractMcpProgressMethodCallback + implements Consumer { + + private SyncMcpProgressMethodCallback(Builder builder) { + super(builder.method, builder.bean); + } + + @Override + protected void validateReturnType(Method method) { + // Synchronous methods must return void + if (!void.class.equals(method.getReturnType())) { + throw new IllegalArgumentException("Synchronous progress methods must return void: " + method.getName() + + " in " + method.getDeclaringClass().getName() + " returns " + method.getReturnType().getName()); + } + } + + /** + * Accept the progress notification and process it. + *

+ * This method builds the arguments for the method call and invokes the method. + * @param notification The progress notification, must not be null + * @throws McpProgressMethodException if there is an error invoking the progress + * method + * @throws IllegalArgumentException if the notification is null + */ + @Override + public void accept(ProgressNotification notification) { + if (notification == null) { + throw new IllegalArgumentException("Notification must not be null"); + } + + try { + // Build arguments for the method call + Object[] args = this.buildArgs(this.method, null, notification); + + // Invoke the method + this.method.setAccessible(true); + this.method.invoke(this.bean, args); + } + catch (Exception e) { + throw new McpProgressMethodException("Error invoking progress method: " + this.method.getName(), e); + } + } + + /** + * Builder for creating SyncMcpProgressMethodCallback instances. + *

+ * This builder provides a fluent API for constructing SyncMcpProgressMethodCallback + * instances with the required parameters. + */ + public static class Builder extends AbstractBuilder { + + /** + * Build the callback. + * @return A new SyncMcpProgressMethodCallback instance + */ + @Override + public SyncMcpProgressMethodCallback build() { + validate(); + return new SyncMcpProgressMethodCallback(this); + } + + } + + /** + * Create a new builder. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/SyncProgressSpecification.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/SyncProgressSpecification.java new file mode 100644 index 0000000..3a14b67 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/progress/SyncProgressSpecification.java @@ -0,0 +1,19 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import java.util.function.Consumer; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; + +/** + * Specification for synchronous progress handlers. + * + * @param clientId The client ID for the progress handler + * @param progressHandler The consumer that handles progress notifications + * @author Christian Tzolov + */ +public record SyncProgressSpecification(String clientId, Consumer progressHandler) { +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncMcpProgressProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncMcpProgressProvider.java new file mode 100644 index 0000000..8378380 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/AsyncMcpProgressProvider.java @@ -0,0 +1,127 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Stream; + +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; +import org.springaicommunity.mcp.method.progress.AsyncMcpProgressMethodCallback; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; + +/** + * Provider for asynchronous progress callbacks. + * + *

+ * This class scans a list of objects for methods annotated with {@link McpProgress} and + * creates {@link Function} callbacks for them. These callbacks can be used to handle + * progress notifications from MCP servers asynchronously. + * + *

+ * Example usage:

{@code
+ * // Create a provider with a list of objects containing @McpProgress methods
+ * AsyncMcpProgressProvider provider = new AsyncMcpProgressProvider(List.of(progressHandler));
+ *
+ * // Get the list of progress callbacks
+ * List progressSpecs = provider.getProgressSpecifications();
+ *
+ * // Add the functions to the client features
+ * McpClientFeatures.Async clientFeatures = new McpClientFeatures.Async(
+ *     clientInfo, clientCapabilities, roots,
+ *     toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers,
+ *     loggingConsumers, progressHandlers, samplingHandler);
+ * }
+ * + * @author Christian Tzolov + * @see McpProgress + * @see AsyncMcpProgressMethodCallback + * @see ProgressNotification + */ +public class AsyncMcpProgressProvider { + + private final List progressObjects; + + /** + * Create a new AsyncMcpProgressProvider. + * @param progressObjects the objects containing methods annotated with + * {@link McpProgress} + */ + public AsyncMcpProgressProvider(List progressObjects) { + this.progressObjects = progressObjects != null ? progressObjects : List.of(); + } + + /** + * Get the list of progress specifications. + * @return the list of progress specifications + */ + public List getProgressSpecifications() { + + List progressHandlers = this.progressObjects.stream() + .map(progressObject -> Stream.of(doGetClassMethods(progressObject)) + .filter(method -> method.isAnnotationPresent(McpProgress.class)) + .filter(method -> { + // For async callbacks, only Mono is valid + Class returnType = method.getReturnType(); + if (!Mono.class.isAssignableFrom(returnType)) { + return false; + } + // Check if it's specifically Mono + Type genericReturnType = method.getGenericReturnType(); + if (genericReturnType instanceof ParameterizedType) { + ParameterizedType paramType = (ParameterizedType) genericReturnType; + Type[] typeArguments = paramType.getActualTypeArguments(); + if (typeArguments.length == 1) { + return typeArguments[0] == Void.class; + } + } + return false; + }) + .map(mcpProgressMethod -> { + var progressAnnotation = mcpProgressMethod.getAnnotation(McpProgress.class); + + Function> methodCallback = AsyncMcpProgressMethodCallback.builder() + .method(mcpProgressMethod) + .bean(progressObject) + .progress(progressAnnotation) + .build(); + + return new AsyncProgressSpecification(progressAnnotation.clientId(), methodCallback); + }) + .toList()) + .flatMap(List::stream) + .toList(); + + return progressHandlers; + } + + /** + * Returns the methods of the given bean class. + * @param bean the bean instance + * @return the methods of the bean class + */ + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpProgressProvider.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpProgressProvider.java new file mode 100644 index 0000000..dace4f2 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/provider/SyncMcpProgressProvider.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springaicommunity.mcp.provider; + +import java.lang.reflect.Method; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; +import org.springaicommunity.mcp.method.progress.SyncMcpProgressMethodCallback; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; + +/** + * Provider for synchronous progress callbacks. + * + *

+ * This class scans a list of objects for methods annotated with {@link McpProgress} and + * creates {@link Consumer} callbacks for them. These callbacks can be used to handle + * progress notifications from MCP servers. + * + *

+ * Example usage:

{@code
+ * // Create a provider with a list of objects containing @McpProgress methods
+ * SyncMcpProgressProvider provider = new SyncMcpProgressProvider(List.of(progressHandler));
+ *
+ * // Get the list of progress callbacks
+ * List progressSpecs = provider.getProgressSpecifications();
+ *
+ * // Add the consumers to the client features
+ * McpClientFeatures.Sync clientFeatures = new McpClientFeatures.Sync(
+ *     clientInfo, clientCapabilities, roots,
+ *     toolsChangeConsumers, resourcesChangeConsumers, promptsChangeConsumers,
+ *     loggingConsumers, progressConsumers, samplingHandler);
+ * }
+ * + * @author Christian Tzolov + * @see McpProgress + * @see SyncMcpProgressMethodCallback + * @see ProgressNotification + */ +public class SyncMcpProgressProvider { + + private final List progressObjects; + + /** + * Create a new SyncMcpProgressProvider. + * @param progressObjects the objects containing methods annotated with + * {@link McpProgress} + */ + public SyncMcpProgressProvider(List progressObjects) { + this.progressObjects = progressObjects != null ? progressObjects : List.of(); + } + + /** + * Get the list of progress specifications. + * @return the list of progress specifications + */ + public List getProgressSpecifications() { + + List progressConsumers = this.progressObjects.stream() + .map(progressObject -> Stream.of(doGetClassMethods(progressObject)) + .filter(method -> method.isAnnotationPresent(McpProgress.class)) + .filter(method -> !Mono.class.isAssignableFrom(method.getReturnType())) + .filter(method -> method.getReturnType() == void.class) // Only void + // return type is + // valid for sync + .map(mcpProgressMethod -> { + var progressAnnotation = mcpProgressMethod.getAnnotation(McpProgress.class); + + Consumer methodCallback = SyncMcpProgressMethodCallback.builder() + .method(mcpProgressMethod) + .bean(progressObject) + .progress(progressAnnotation) + .build(); + + return new SyncProgressSpecification(progressAnnotation.clientId(), methodCallback); + }) + .toList()) + .flatMap(List::stream) + .toList(); + + return progressConsumers; + } + + /** + * Returns the methods of the given bean class. + * @param bean the bean instance + * @return the methods of the bean class + */ + protected Method[] doGetClassMethods(Object bean) { + return bean.getClass().getDeclaredMethods(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallbackExample.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallbackExample.java new file mode 100644 index 0000000..2d7ec6f --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallbackExample.java @@ -0,0 +1,177 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import org.springaicommunity.mcp.annotation.McpProgress; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Example demonstrating the usage of {@link AsyncMcpProgressMethodCallback}. + * + * @author Christian Tzolov + */ +public class AsyncMcpProgressMethodCallbackExample { + + /** + * Example async service that handles progress notifications. + */ + public static class AsyncProgressService { + + private final AtomicInteger notificationCount = new AtomicInteger(0); + + /** + * Handle progress notification asynchronously with the full notification object. + * @param notification the progress notification + * @return Mono completing when processing is done + */ + @McpProgress + public Mono handleProgressNotificationAsync(ProgressNotification notification) { + return Mono.fromRunnable(() -> { + int count = notificationCount.incrementAndGet(); + System.out.printf("[Async] Progress Update #%d: Token=%s, Progress=%.2f%%, Total=%.0f, Message=%s%n", + count, notification.progressToken(), notification.progress() * 100, notification.total(), + notification.message()); + }) + .delayElement(Duration.ofMillis(100)) // Simulate async processing + .then(); + } + + /** + * Handle progress notification with individual parameters returning void. + * @param progress the progress value (0.0 to 1.0) + * @param progressToken the progress token identifier + * @param total the total value as string + */ + @McpProgress + public void handleProgressWithParams(Double progress, String progressToken, String total) { + System.out.printf("[Sync in Async] Progress: %.2f%% for token %s (Total: %s)%n", progress * 100, + progressToken, total); + } + + /** + * Handle progress asynchronously with individual parameters. + * @param progress the progress value (0.0 to 1.0) + * @param progressToken the progress token identifier + * @param total the total value as string + * @return Mono completing when processing is done + */ + @McpProgress + public Mono handleProgressWithParamsAsync(Double progress, String progressToken, String total) { + return Mono.fromRunnable(() -> { + System.out.printf("[Async Params] Progress: %.2f%% for token %s (Total: %s)%n", progress * 100, + progressToken, total); + }).delayElement(Duration.ofMillis(50)).then(); + } + + /** + * Handle progress with primitive double. + * @param progress the progress value (0.0 to 1.0) + * @param progressToken the progress token identifier + * @param total the total value as string + */ + @McpProgress + public void handleProgressPrimitive(double progress, String progressToken, String total) { + System.out.printf("[Primitive] Processing: %.1f%% complete (Token: %s)%n", progress * 100, progressToken); + } + + public int getNotificationCount() { + return notificationCount.get(); + } + + } + + public static void main(String[] args) throws Exception { + // Create the service instance + AsyncProgressService service = new AsyncProgressService(); + + // Build the async callback for the notification method + Function> asyncNotificationCallback = AsyncMcpProgressMethodCallback.builder() + .method(AsyncProgressService.class.getMethod("handleProgressNotificationAsync", ProgressNotification.class)) + .bean(service) + .build(); + + // Build the callback for the sync params method + Function> syncParamsCallback = AsyncMcpProgressMethodCallback.builder() + .method(AsyncProgressService.class.getMethod("handleProgressWithParams", Double.class, String.class, + String.class)) + .bean(service) + .build(); + + // Build the async callback for the params method + Function> asyncParamsCallback = AsyncMcpProgressMethodCallback.builder() + .method(AsyncProgressService.class.getMethod("handleProgressWithParamsAsync", Double.class, String.class, + String.class)) + .bean(service) + .build(); + + // Build the callback for the primitive method + Function> primitiveCallback = AsyncMcpProgressMethodCallback.builder() + .method(AsyncProgressService.class.getMethod("handleProgressPrimitive", double.class, String.class, + String.class)) + .bean(service) + .build(); + + System.out.println("=== Async Progress Notification Example ==="); + + // Create a flux of progress notifications + Flux progressFlux = Flux.just( + new ProgressNotification("async-task-001", 0.0, 100.0, "Starting async operation..."), + new ProgressNotification("async-task-001", 0.25, 100.0, "Processing batch 1..."), + new ProgressNotification("async-task-001", 0.5, 100.0, "Halfway through..."), + new ProgressNotification("async-task-001", 0.75, 100.0, "Processing batch 3..."), + new ProgressNotification("async-task-001", 1.0, 100.0, "Operation completed successfully!")); + + // Process notifications with different callbacks + Mono processing = progressFlux.index().flatMap(indexed -> { + Long index = indexed.getT1(); + ProgressNotification notification = indexed.getT2(); + + // Use different callbacks based on index + if (index == 0) { + return asyncNotificationCallback.apply(notification); + } + else if (index == 1) { + return syncParamsCallback.apply(notification); + } + else if (index == 2) { + return asyncParamsCallback.apply(notification); + } + else if (index == 3) { + return primitiveCallback.apply(notification); + } + else { + return asyncNotificationCallback.apply(notification); + } + }).then(); + + // Block and wait for all processing to complete + System.out.println("Processing notifications asynchronously..."); + processing.block(); + + System.out.printf("%nTotal async notifications handled: %d%n", service.getNotificationCount()); + + // Demonstrate concurrent processing + System.out.println("\n=== Concurrent Progress Processing ==="); + + Flux concurrentNotifications = Flux.range(1, 5) + .map(i -> new ProgressNotification("concurrent-task-" + i, i * 0.2, 100.0, "Processing task " + i)); + + concurrentNotifications + .flatMap(notification -> asyncNotificationCallback.apply(notification) + .doOnSubscribe(s -> System.out.println("Starting: " + notification.progressToken())) + .doOnSuccess(v -> System.out.println("Completed: " + notification.progressToken()))) + .blockLast(); + + System.out.println("\nAll async operations completed!"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallbackTests.java new file mode 100644 index 0000000..a07c933 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/AsyncMcpProgressMethodCallbackTests.java @@ -0,0 +1,278 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.lang.reflect.Method; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpProgress; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link AsyncMcpProgressMethodCallback}. + * + * @author Christian Tzolov + */ +public class AsyncMcpProgressMethodCallbackTests { + + // ProgressNotification constructor: (String progressToken, double progress, Double + // total, String message) + private static final ProgressNotification TEST_NOTIFICATION = new ProgressNotification("progress-token-123", // progressToken + 0.5, // progress + 100.0, // total + "Processing..." // message + ); + + /** + * Test class with valid methods. + */ + static class ValidMethods { + + private ProgressNotification lastNotification; + + private Double lastProgress; + + private String lastProgressToken; + + private String lastTotal; + + @McpProgress + public void handleProgressVoid(ProgressNotification notification) { + this.lastNotification = notification; + } + + @McpProgress + public Mono handleProgressMono(ProgressNotification notification) { + this.lastNotification = notification; + return Mono.empty(); + } + + @McpProgress + public void handleProgressWithParams(Double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + @McpProgress + public Mono handleProgressWithParamsMono(Double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + return Mono.empty(); + } + + @McpProgress + public void handleProgressWithPrimitiveDouble(double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + } + + /** + * Test class with invalid methods. + */ + static class InvalidMethods { + + @McpProgress + public String invalidReturnType(ProgressNotification notification) { + return "Invalid"; + } + + @McpProgress + public Mono invalidMonoReturnType(ProgressNotification notification) { + return Mono.just("Invalid"); + } + + @McpProgress + public void invalidParameterCount(ProgressNotification notification, String extra) { + // Invalid parameter count + } + + @McpProgress + public void invalidParameterType(String invalidType) { + // Invalid parameter type + } + + @McpProgress + public void invalidParameterTypes(String progress, int progressToken, boolean total) { + // Invalid parameter types + } + + @McpProgress + public void invalidFirstParameterType(String progress, String progressToken, String total) { + // Invalid first parameter type + } + + @McpProgress + public void invalidSecondParameterType(Double progress, int progressToken, String total) { + // Invalid second parameter type + } + + @McpProgress + public void invalidThirdParameterType(Double progress, String progressToken, int total) { + // Invalid third parameter type + } + + } + + @Test + void testValidVoidMethod() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressVoid", ProgressNotification.class); + + Function> callback = AsyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + StepVerifier.create(callback.apply(TEST_NOTIFICATION)).verifyComplete(); + + assertThat(bean.lastNotification).isEqualTo(TEST_NOTIFICATION); + } + + @Test + void testValidMethodWithNotification() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressMono", ProgressNotification.class); + + Function> callback = AsyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + StepVerifier.create(callback.apply(TEST_NOTIFICATION)).verifyComplete(); + + assertThat(bean.lastNotification).isEqualTo(TEST_NOTIFICATION); + } + + @Test + void testValidMethodWithParams() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressWithParams", Double.class, String.class, + String.class); + + Function> callback = AsyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + StepVerifier.create(callback.apply(TEST_NOTIFICATION)).verifyComplete(); + + assertThat(bean.lastProgress).isEqualTo(TEST_NOTIFICATION.progress()); + assertThat(bean.lastProgressToken).isEqualTo(TEST_NOTIFICATION.progressToken()); + assertThat(bean.lastTotal).isEqualTo(String.valueOf(TEST_NOTIFICATION.total())); + } + + @Test + void testValidMethodWithParamsMono() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressWithParamsMono", Double.class, String.class, + String.class); + + Function> callback = AsyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + StepVerifier.create(callback.apply(TEST_NOTIFICATION)).verifyComplete(); + + assertThat(bean.lastProgress).isEqualTo(TEST_NOTIFICATION.progress()); + assertThat(bean.lastProgressToken).isEqualTo(TEST_NOTIFICATION.progressToken()); + assertThat(bean.lastTotal).isEqualTo(String.valueOf(TEST_NOTIFICATION.total())); + } + + @Test + void testValidMethodWithPrimitiveDouble() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressWithPrimitiveDouble", double.class, String.class, + String.class); + + Function> callback = AsyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + StepVerifier.create(callback.apply(TEST_NOTIFICATION)).verifyComplete(); + + assertThat(bean.lastProgress).isEqualTo(TEST_NOTIFICATION.progress()); + assertThat(bean.lastProgressToken).isEqualTo(TEST_NOTIFICATION.progressToken()); + assertThat(bean.lastTotal).isEqualTo(String.valueOf(TEST_NOTIFICATION.total())); + } + + @Test + void testInvalidReturnType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidReturnType", ProgressNotification.class); + + assertThatThrownBy(() -> AsyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Asynchronous progress methods must return void or Mono"); + } + + @Test + void testInvalidMonoReturnType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidMonoReturnType", ProgressNotification.class); + + assertThatThrownBy(() -> AsyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Mono return type must be Mono"); + } + + @Test + void testInvalidParameterCount() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidParameterCount", ProgressNotification.class, + String.class); + + assertThatThrownBy(() -> AsyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method must have either 1 parameter (ProgressNotification) or 3 parameters"); + } + + @Test + void testInvalidParameterType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidParameterType", String.class); + + assertThatThrownBy(() -> AsyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Single parameter must be of type ProgressNotification"); + } + + @Test + void testInvalidParameterTypes() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidParameterTypes", String.class, int.class, boolean.class); + + assertThatThrownBy(() -> AsyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("First parameter must be of type Double or double"); + } + + @Test + void testNullNotification() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressMono", ProgressNotification.class); + + Function> callback = AsyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + StepVerifier.create(callback.apply(null)).expectError(IllegalArgumentException.class).verify(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallbackExample.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallbackExample.java new file mode 100644 index 0000000..fb7aa99 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallbackExample.java @@ -0,0 +1,120 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import java.util.function.Consumer; + +import org.springaicommunity.mcp.annotation.McpProgress; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; + +/** + * Example demonstrating the usage of {@link SyncMcpProgressMethodCallback}. + * + * @author Christian Tzolov + */ +public class SyncMcpProgressMethodCallbackExample { + + /** + * Example service that handles progress notifications. + */ + public static class ProgressService { + + private int notificationCount = 0; + + /** + * Handle progress notification with the full notification object. + * @param notification the progress notification + */ + @McpProgress + public void handleProgressNotification(ProgressNotification notification) { + notificationCount++; + System.out.printf("Progress Update #%d: Token=%s, Progress=%.2f%%, Total=%.0f, Message=%s%n", + notificationCount, notification.progressToken(), notification.progress() * 100, + notification.total(), notification.message()); + } + + /** + * Handle progress notification with individual parameters. + * @param progress the progress value (0.0 to 1.0) + * @param progressToken the progress token identifier + * @param total the total value as string + */ + @McpProgress + public void handleProgressWithParams(Double progress, String progressToken, String total) { + System.out.printf("Progress: %.2f%% for token %s (Total: %s)%n", progress * 100, progressToken, total); + } + + /** + * Handle progress with primitive double. + * @param progress the progress value (0.0 to 1.0) + * @param progressToken the progress token identifier + * @param total the total value as string + */ + @McpProgress + public void handleProgressPrimitive(double progress, String progressToken, String total) { + System.out.printf("Processing: %.1f%% complete (Token: %s)%n", progress * 100, progressToken); + } + + public int getNotificationCount() { + return notificationCount; + } + + } + + public static void main(String[] args) throws Exception { + // Create the service instance + ProgressService service = new ProgressService(); + + // Build the callback for the notification method + Consumer notificationCallback = SyncMcpProgressMethodCallback.builder() + .method(ProgressService.class.getMethod("handleProgressNotification", ProgressNotification.class)) + .bean(service) + .build(); + + // Build the callback for the params method + Consumer paramsCallback = SyncMcpProgressMethodCallback.builder() + .method(ProgressService.class.getMethod("handleProgressWithParams", Double.class, String.class, + String.class)) + .bean(service) + .build(); + + // Build the callback for the primitive method + Consumer primitiveCallback = SyncMcpProgressMethodCallback.builder() + .method(ProgressService.class.getMethod("handleProgressPrimitive", double.class, String.class, + String.class)) + .bean(service) + .build(); + + // Simulate progress notifications + System.out.println("=== Progress Notification Example ==="); + + // Start of operation + ProgressNotification startNotification = new ProgressNotification("task-001", 0.0, 100.0, + "Starting operation..."); + notificationCallback.accept(startNotification); + + // Progress updates + ProgressNotification progressNotification1 = new ProgressNotification("task-001", 0.25, 100.0, + "Processing batch 1..."); + paramsCallback.accept(progressNotification1); + + ProgressNotification progressNotification2 = new ProgressNotification("task-001", 0.5, 100.0, + "Halfway through..."); + primitiveCallback.accept(progressNotification2); + + ProgressNotification progressNotification3 = new ProgressNotification("task-001", 0.75, 100.0, + "Processing batch 3..."); + notificationCallback.accept(progressNotification3); + + // Completion + ProgressNotification completeNotification = new ProgressNotification("task-001", 1.0, 100.0, + "Operation completed successfully!"); + notificationCallback.accept(completeNotification); + + System.out.printf("%nTotal notifications handled: %d%n", service.getNotificationCount()); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallbackTests.java new file mode 100644 index 0000000..7909ad3 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/progress/SyncMcpProgressMethodCallbackTests.java @@ -0,0 +1,248 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.method.progress; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.lang.reflect.Method; +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpProgress; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; + +/** + * Tests for {@link SyncMcpProgressMethodCallback}. + * + * @author Christian Tzolov + */ +public class SyncMcpProgressMethodCallbackTests { + + // ProgressNotification constructor: (String progressToken, double progress, Double + // total, String message) + private static final ProgressNotification TEST_NOTIFICATION = new ProgressNotification("progress-token-123", // progressToken + 0.5, // progress + 100.0, // total + "Processing..." // message + ); + + /** + * Test class with valid methods. + */ + static class ValidMethods { + + private ProgressNotification lastNotification; + + private Double lastProgress; + + private String lastProgressToken; + + private String lastTotal; + + @McpProgress + public void handleProgressNotification(ProgressNotification notification) { + this.lastNotification = notification; + } + + @McpProgress + public void handleProgressWithParams(Double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + @McpProgress + public void handleProgressWithPrimitiveDouble(double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + } + + /** + * Test class with invalid methods. + */ + static class InvalidMethods { + + @McpProgress + public String invalidReturnType(ProgressNotification notification) { + return "Invalid"; + } + + @McpProgress + public void invalidParameterCount(ProgressNotification notification, String extra) { + // Invalid parameter count + } + + @McpProgress + public void invalidParameterType(String invalidType) { + // Invalid parameter type + } + + @McpProgress + public void invalidParameterTypes(String progress, int progressToken, boolean total) { + // Invalid parameter types + } + + @McpProgress + public void invalidFirstParameterType(String progress, String progressToken, String total) { + // Invalid first parameter type + } + + @McpProgress + public void invalidSecondParameterType(Double progress, int progressToken, String total) { + // Invalid second parameter type + } + + @McpProgress + public void invalidThirdParameterType(Double progress, String progressToken, int total) { + // Invalid third parameter type + } + + } + + @Test + void testValidMethodWithNotification() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressNotification", ProgressNotification.class); + + Consumer callback = SyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + callback.accept(TEST_NOTIFICATION); + + assertThat(bean.lastNotification).isEqualTo(TEST_NOTIFICATION); + } + + @Test + void testValidMethodWithParams() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressWithParams", Double.class, String.class, + String.class); + + Consumer callback = SyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + callback.accept(TEST_NOTIFICATION); + + assertThat(bean.lastProgress).isEqualTo(TEST_NOTIFICATION.progress()); + assertThat(bean.lastProgressToken).isEqualTo(TEST_NOTIFICATION.progressToken()); + assertThat(bean.lastTotal).isEqualTo(String.valueOf(TEST_NOTIFICATION.total())); + } + + @Test + void testValidMethodWithPrimitiveDouble() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressWithPrimitiveDouble", double.class, String.class, + String.class); + + Consumer callback = SyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + callback.accept(TEST_NOTIFICATION); + + assertThat(bean.lastProgress).isEqualTo(TEST_NOTIFICATION.progress()); + assertThat(bean.lastProgressToken).isEqualTo(TEST_NOTIFICATION.progressToken()); + assertThat(bean.lastTotal).isEqualTo(String.valueOf(TEST_NOTIFICATION.total())); + } + + @Test + void testInvalidReturnType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidReturnType", ProgressNotification.class); + + assertThatThrownBy(() -> SyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Synchronous progress methods must return void"); + } + + @Test + void testInvalidParameterCount() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidParameterCount", ProgressNotification.class, + String.class); + + assertThatThrownBy(() -> SyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Method must have either 1 parameter (ProgressNotification) or 3 parameters"); + } + + @Test + void testInvalidParameterType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidParameterType", String.class); + + assertThatThrownBy(() -> SyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Single parameter must be of type ProgressNotification"); + } + + @Test + void testInvalidParameterTypes() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidParameterTypes", String.class, int.class, boolean.class); + + assertThatThrownBy(() -> SyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("First parameter must be of type Double or double"); + } + + @Test + void testInvalidFirstParameterType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidFirstParameterType", String.class, String.class, + String.class); + + assertThatThrownBy(() -> SyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("First parameter must be of type Double or double"); + } + + @Test + void testInvalidSecondParameterType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidSecondParameterType", Double.class, int.class, + String.class); + + assertThatThrownBy(() -> SyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Second parameter must be of type String"); + } + + @Test + void testInvalidThirdParameterType() throws Exception { + InvalidMethods bean = new InvalidMethods(); + Method method = InvalidMethods.class.getMethod("invalidThirdParameterType", Double.class, String.class, + int.class); + + assertThatThrownBy(() -> SyncMcpProgressMethodCallback.builder().method(method).bean(bean).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Third parameter must be of type String"); + } + + @Test + void testNullNotification() throws Exception { + ValidMethods bean = new ValidMethods(); + Method method = ValidMethods.class.getMethod("handleProgressNotification", ProgressNotification.class); + + Consumer callback = SyncMcpProgressMethodCallback.builder() + .method(method) + .bean(bean) + .build(); + + assertThatThrownBy(() -> callback.accept(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Notification must not be null"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncMcpProgressProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncMcpProgressProviderTests.java new file mode 100644 index 0000000..22409a4 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/AsyncMcpProgressProviderTests.java @@ -0,0 +1,206 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +/** + * Tests for {@link AsyncMcpProgressProvider}. + * + * @author Christian Tzolov + */ +public class AsyncMcpProgressProviderTests { + + /** + * Test class with async progress handler methods. + */ + static class AsyncProgressHandler { + + private ProgressNotification lastNotification; + + private Double lastProgress; + + private String lastProgressToken; + + private String lastTotal; + + @McpProgress + public void handleProgressVoid(ProgressNotification notification) { + this.lastNotification = notification; + } + + @McpProgress + public Mono handleProgressMono(ProgressNotification notification) { + this.lastNotification = notification; + return Mono.empty(); + } + + @McpProgress + public void handleProgressWithParams(Double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + @McpProgress + public Mono handleProgressWithParamsMono(Double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + return Mono.empty(); + } + + @McpProgress + public void handleProgressWithPrimitiveDouble(double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + // This method is not annotated and should be ignored + public Mono notAnnotatedMethod(ProgressNotification notification) { + // This method should be ignored + return Mono.empty(); + } + + // This method has invalid return type and should be ignored + @McpProgress + public String invalidReturnType(ProgressNotification notification) { + return "Invalid"; + } + + // This method has invalid Mono return type and should be ignored + @McpProgress + public Mono invalidMonoReturnType(ProgressNotification notification) { + return Mono.just("Invalid"); + } + + } + + @Test + void testGetProgressSpecifications() { + AsyncProgressHandler progressHandler = new AsyncProgressHandler(); + AsyncMcpProgressProvider provider = new AsyncMcpProgressProvider(List.of(progressHandler)); + + List specifications = provider.getProgressSpecifications(); + List>> handlers = specifications.stream() + .map(AsyncProgressSpecification::progressHandler) + .toList(); + + // Should find 2 valid annotated methods (only Mono methods are valid for + // async) + assertThat(handlers).hasSize(2); + + // Test the first handler (Mono method) + ProgressNotification notification = new ProgressNotification("test-token-123", 0.5, 100.0, + "Test progress message"); + + StepVerifier.create(handlers.get(0).apply(notification)).verifyComplete(); + assertThat(progressHandler.lastNotification).isEqualTo(notification); + + // Reset + progressHandler.lastNotification = null; + progressHandler.lastProgress = null; + progressHandler.lastProgressToken = null; + progressHandler.lastTotal = null; + + // Test the second handler (Mono with params) + StepVerifier.create(handlers.get(1).apply(notification)).verifyComplete(); + assertThat(progressHandler.lastProgress).isEqualTo(notification.progress()); + assertThat(progressHandler.lastProgressToken).isEqualTo(notification.progressToken()); + assertThat(progressHandler.lastTotal).isEqualTo(String.valueOf(notification.total())); + } + + @Test + void testEmptyList() { + AsyncMcpProgressProvider provider = new AsyncMcpProgressProvider(List.of()); + + List>> handlers = provider.getProgressSpecifications() + .stream() + .map(AsyncProgressSpecification::progressHandler) + .toList(); + + assertThat(handlers).isEmpty(); + } + + @Test + void testMultipleObjects() { + AsyncProgressHandler handler1 = new AsyncProgressHandler(); + AsyncProgressHandler handler2 = new AsyncProgressHandler(); + AsyncMcpProgressProvider provider = new AsyncMcpProgressProvider(List.of(handler1, handler2)); + + List>> handlers = provider.getProgressSpecifications() + .stream() + .map(AsyncProgressSpecification::progressHandler) + .toList(); + + // Should find 4 valid annotated methods (2 from each handler - only Mono + // methods) + assertThat(handlers).hasSize(4); + } + + @Test + void testNullProgressObjects() { + AsyncMcpProgressProvider provider = new AsyncMcpProgressProvider(null); + + List>> handlers = provider.getProgressSpecifications() + .stream() + .map(AsyncProgressSpecification::progressHandler) + .toList(); + + assertThat(handlers).isEmpty(); + } + + @Test + void testClientIdExtraction() { + AsyncProgressHandler handler = new AsyncProgressHandler(); + AsyncMcpProgressProvider provider = new AsyncMcpProgressProvider(List.of(handler)); + + List specifications = provider.getProgressSpecifications(); + + // All specifications should have empty clientId (default value from annotation) + assertThat(specifications).allMatch(spec -> spec.clientId().equals("")); + } + + @Test + void testErrorHandling() { + // Test class with method that throws an exception + class ErrorHandler { + + @McpProgress + public Mono handleProgressWithError(ProgressNotification notification) { + return Mono.error(new RuntimeException("Test error")); + } + + } + + ErrorHandler errorHandler = new ErrorHandler(); + AsyncMcpProgressProvider provider = new AsyncMcpProgressProvider(List.of(errorHandler)); + + List>> handlers = provider.getProgressSpecifications() + .stream() + .map(AsyncProgressSpecification::progressHandler) + .toList(); + + assertThat(handlers).hasSize(1); + + ProgressNotification notification = new ProgressNotification("error-token", 0.5, 100.0, "Error test"); + + // Verify that the error is propagated correctly + StepVerifier.create(handlers.get(0).apply(notification)).expectError(RuntimeException.class).verify(); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncMcpProgressProviderTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncMcpProgressProviderTests.java new file mode 100644 index 0000000..0b7356e --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/provider/SyncMcpProgressProviderTests.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.provider; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; + +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; + +/** + * Tests for {@link SyncMcpProgressProvider}. + * + * @author Christian Tzolov + */ +public class SyncMcpProgressProviderTests { + + /** + * Test class with progress handler methods. + */ + static class ProgressHandler { + + private ProgressNotification lastNotification; + + private Double lastProgress; + + private String lastProgressToken; + + private String lastTotal; + + @McpProgress + public void handleProgressNotification(ProgressNotification notification) { + this.lastNotification = notification; + } + + @McpProgress + public void handleProgressWithParams(Double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + @McpProgress + public void handleProgressWithPrimitiveDouble(double progress, String progressToken, String total) { + this.lastProgress = progress; + this.lastProgressToken = progressToken; + this.lastTotal = total; + } + + // This method is not annotated and should be ignored + public void notAnnotatedMethod(ProgressNotification notification) { + // This method should be ignored + } + + // This method has invalid return type and should be ignored + @McpProgress + public String invalidReturnType(ProgressNotification notification) { + return "Invalid"; + } + + } + + @Test + void testGetProgressSpecifications() { + ProgressHandler progressHandler = new ProgressHandler(); + SyncMcpProgressProvider provider = new SyncMcpProgressProvider(List.of(progressHandler)); + + List specifications = provider.getProgressSpecifications(); + List> consumers = specifications.stream() + .map(SyncProgressSpecification::progressHandler) + .toList(); + + // Should find 3 valid annotated methods (invalid return type method is filtered + // out) + assertThat(consumers).hasSize(3); + + // Test all consumers and verify at least one sets each expected field + ProgressNotification notification = new ProgressNotification("test-token-123", 0.5, 100.0, + "Test progress message"); + + // Call all consumers + for (Consumer consumer : consumers) { + consumer.accept(notification); + } + + // Verify that at least one method set the notification + assertThat(progressHandler.lastNotification).isEqualTo(notification); + + // Verify that at least one method set the individual parameters + assertThat(progressHandler.lastProgress).isEqualTo(notification.progress()); + assertThat(progressHandler.lastProgressToken).isEqualTo(notification.progressToken()); + assertThat(progressHandler.lastTotal).isEqualTo(String.valueOf(notification.total())); + } + + @Test + void testEmptyList() { + SyncMcpProgressProvider provider = new SyncMcpProgressProvider(List.of()); + + List> consumers = provider.getProgressSpecifications() + .stream() + .map(SyncProgressSpecification::progressHandler) + .toList(); + + assertThat(consumers).isEmpty(); + } + + @Test + void testMultipleObjects() { + ProgressHandler handler1 = new ProgressHandler(); + ProgressHandler handler2 = new ProgressHandler(); + SyncMcpProgressProvider provider = new SyncMcpProgressProvider(List.of(handler1, handler2)); + + List> consumers = provider.getProgressSpecifications() + .stream() + .map(SyncProgressSpecification::progressHandler) + .toList(); + + // Should find 6 valid annotated methods (3 from each handler) + assertThat(consumers).hasSize(6); + } + + @Test + void testNullProgressObjects() { + SyncMcpProgressProvider provider = new SyncMcpProgressProvider(null); + + List> consumers = provider.getProgressSpecifications() + .stream() + .map(SyncProgressSpecification::progressHandler) + .toList(); + + assertThat(consumers).isEmpty(); + } + + @Test + void testClientIdExtraction() { + ProgressHandler handler = new ProgressHandler(); + SyncMcpProgressProvider provider = new SyncMcpProgressProvider(List.of(handler)); + + List specifications = provider.getProgressSpecifications(); + + // All specifications should have empty clientId (default value from annotation) + assertThat(specifications).allMatch(spec -> spec.clientId().equals("")); + } + +}