Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 234 additions & 9 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@
import org.springaicommunity.mcp.provider.AsyncMcpLoggingConsumerProvider;
import org.springaicommunity.mcp.provider.AsyncMcpSamplingProvider;
import org.springaicommunity.mcp.provider.AsyncMcpToolProvider;
import org.springaicommunity.mcp.provider.AsyncStatelessMcpPromptProvider;
import org.springaicommunity.mcp.provider.AsyncStatelessMcpResourceProvider;
import org.springaicommunity.mcp.provider.AsyncStatelessMcpToolProvider;

import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification;
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
Expand Down Expand Up @@ -89,6 +93,45 @@ protected Method[] doGetClassMethods(Object bean) {

}

private static class SpringAiAsyncStatelessMcpToolProvider extends AsyncStatelessMcpToolProvider {

public SpringAiAsyncStatelessMcpToolProvider(List<Object> toolObjects) {
super(toolObjects);
}

@Override
protected Method[] doGetClassMethods(Object bean) {
return AnnotationProviderUtil.beanMethods(bean);
}

}

private static class SpringAiAsyncStatelessPromptProvider extends AsyncStatelessMcpPromptProvider {

public SpringAiAsyncStatelessPromptProvider(List<Object> promptObjects) {
super(promptObjects);
}

@Override
protected Method[] doGetClassMethods(Object bean) {
return AnnotationProviderUtil.beanMethods(bean);
}

}

private static class SpringAiAsyncStatelessResourceProvider extends AsyncStatelessMcpResourceProvider {

public SpringAiAsyncStatelessResourceProvider(List<Object> resourceObjects) {
super(resourceObjects);
}

@Override
protected Method[] doGetClassMethods(Object bean) {
return AnnotationProviderUtil.beanMethods(bean);
}

}

public static List<Function<LoggingMessageNotification, Mono<Void>>> createAsyncLoggingConsumers(
List<Object> loggingObjects) {
return new SpringAiAsyncMcpLoggingConsumerProvider(loggingObjects).getLoggingConsumers();
Expand All @@ -108,4 +151,19 @@ public static List<AsyncToolSpecification> createAsyncToolSpecifications(List<Ob
return new SpringAiAsyncMcpToolProvider(toolObjects).getToolSpecifications();
}

public static List<McpStatelessServerFeatures.AsyncToolSpecification> createAsyncStatelessToolSpecifications(
List<Object> toolObjects) {
return new SpringAiAsyncStatelessMcpToolProvider(toolObjects).getToolSpecifications();
}

public static List<McpStatelessServerFeatures.AsyncPromptSpecification> createAsyncStatelessPromptSpecifications(
List<Object> promptObjects) {
return new SpringAiAsyncStatelessPromptProvider(promptObjects).getPromptSpecifications();
}

public static List<McpStatelessServerFeatures.AsyncResourceSpecification> createAsyncStatelessResourceSpecifications(
List<Object> resourceObjects) {
return new SpringAiAsyncStatelessResourceProvider(resourceObjects).getResourceSpecifications();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
import org.springaicommunity.mcp.provider.SyncMcpResourceProvider;
import org.springaicommunity.mcp.provider.SyncMcpSamplingProvider;
import org.springaicommunity.mcp.provider.SyncMcpToolProvider;
import org.springaicommunity.mcp.provider.SyncStatelessMcpPromptProvider;
import org.springaicommunity.mcp.provider.SyncStatelessMcpResourceProvider;
import org.springaicommunity.mcp.provider.SyncStatelessMcpToolProvider;

import io.modelcontextprotocol.server.McpServerFeatures.SyncCompletionSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.SyncResourceSpecification;
import io.modelcontextprotocol.server.McpServerFeatures.SyncToolSpecification;
import io.modelcontextprotocol.server.McpStatelessServerFeatures;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
Expand Down Expand Up @@ -69,6 +73,19 @@ protected Method[] doGetClassMethods(Object bean) {

}

private static class SpringAiSyncStatelessToolProvider extends SyncStatelessMcpToolProvider {

public SpringAiSyncStatelessToolProvider(List<Object> toolObjects) {
super(toolObjects);
}

@Override
protected Method[] doGetClassMethods(Object bean) {
return AnnotationProviderUtil.beanMethods(bean);
}

}

private static class SpringAiSyncMcpPromptProvider extends SyncMcpPromptProvider {

public SpringAiSyncMcpPromptProvider(List<Object> promptObjects) {
Expand All @@ -82,6 +99,19 @@ protected Method[] doGetClassMethods(Object bean) {

};

private static class SpringAiSyncStatelessPromptProvider extends SyncStatelessMcpPromptProvider {

public SpringAiSyncStatelessPromptProvider(List<Object> promptObjects) {
super(promptObjects);
}

@Override
protected Method[] doGetClassMethods(Object bean) {
return AnnotationProviderUtil.beanMethods(bean);
}

}

private static class SpringAiSyncMcpResourceProvider extends SyncMcpResourceProvider {

public SpringAiSyncMcpResourceProvider(List<Object> resourceObjects) {
Expand All @@ -95,6 +125,19 @@ protected Method[] doGetClassMethods(Object bean) {

}

private static class SpringAiSyncStatelessResourceProvider extends SyncStatelessMcpResourceProvider {

public SpringAiSyncStatelessResourceProvider(List<Object> resourceObjects) {
super(resourceObjects);
}

@Override
protected Method[] doGetClassMethods(Object bean) {
return AnnotationProviderUtil.beanMethods(bean);
}

}

private static class SpringAiSyncMcpLoggingConsumerProvider extends SyncMcpLoggingConsumerProvider {

public SpringAiSyncMcpLoggingConsumerProvider(List<Object> loggingObjects) {
Expand Down Expand Up @@ -138,6 +181,11 @@ public static List<SyncToolSpecification> createSyncToolSpecifications(List<Obje
return new SpringAiSyncToolProvider(toolObjects).getToolSpecifications();
}

public static List<McpStatelessServerFeatures.SyncToolSpecification> createSyncStatelessToolSpecifications(
List<Object> toolObjects) {
return new SpringAiSyncStatelessToolProvider(toolObjects).getToolSpecifications();
}

public static List<SyncCompletionSpecification> createSyncCompleteSpecifications(List<Object> completeObjects) {
return new SpringAiSyncMcpCompletionProvider(completeObjects).getCompleteSpecifications();
}
Expand All @@ -146,10 +194,20 @@ public static List<SyncPromptSpecification> createSyncPromptSpecifications(List<
return new SpringAiSyncMcpPromptProvider(promptObjects).getPromptSpecifications();
}

public static List<McpStatelessServerFeatures.SyncPromptSpecification> createSyncStatelessPromptSpecifications(
List<Object> promptObjects) {
return new SpringAiSyncStatelessPromptProvider(promptObjects).getPromptSpecifications();
}

public static List<SyncResourceSpecification> createSyncResourceSpecifications(List<Object> resourceObjects) {
return new SpringAiSyncMcpResourceProvider(resourceObjects).getResourceSpecifications();
}

public static List<McpStatelessServerFeatures.SyncResourceSpecification> createSyncStatelessResourceSpecifications(
List<Object> resourceObjects) {
return new SpringAiSyncStatelessResourceProvider(resourceObjects).getResourceSpecifications();
}

public static List<Consumer<LoggingMessageNotification>> createSyncLoggingConsumers(List<Object> loggingObjects) {
return new SpringAiSyncMcpLoggingConsumerProvider(loggingObjects).getLoggingConsumers();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public static McpSchema.Prompt asPrompt(McpPrompt mcpPrompt, Method method) {

private static String getName(McpPrompt promptAnnotation, Method method) {
Assert.notNull(method, "method cannot be null");
if (promptAnnotation == null || (promptAnnotation.name() == null)) {
if (promptAnnotation == null || (promptAnnotation.name() == null) || promptAnnotation.name().isEmpty()) {
return method.getName();
}
return promptAnnotation.name();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Copyright 2025-2025 the original author or authors.
*/

package org.springaicommunity.mcp.method.complete;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

import org.springaicommunity.mcp.annotation.McpComplete;

import io.modelcontextprotocol.server.McpTransportContext;
import io.modelcontextprotocol.spec.McpSchema.CompleteRequest;
import io.modelcontextprotocol.spec.McpSchema.CompleteResult;
import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion;
import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory;
import reactor.core.publisher.Mono;

/**
* Class for creating BiFunction callbacks around complete methods with asynchronous
* processing for stateless contexts.
*
* This class provides a way to convert methods annotated with {@link McpComplete} into
* callback functions that can be used to handle completion requests asynchronously in
* stateless environments. It supports various method signatures and return types, and
* handles both prompt and URI template completions.
*
* @author Christian Tzolov
*/
public final class AsyncStatelessMcpCompleteMethodCallback extends AbstractMcpCompleteMethodCallback
implements BiFunction<McpTransportContext, CompleteRequest, Mono<CompleteResult>> {

private AsyncStatelessMcpCompleteMethodCallback(Builder builder) {
super(builder.method, builder.bean, builder.prompt, builder.uri, builder.uriTemplateManagerFactory);
this.validateMethod(this.method);
}

/**
* Apply the callback to the given context and request.
* <p>
* This method builds the arguments for the method call, invokes the method, and
* converts the result to a CompleteResult.
* @param context The transport context, may be null if the method doesn't require it
* @param request The complete request, must not be null
* @return A Mono that emits the complete result
* @throws McpCompleteMethodException if there is an error invoking the complete
* method
* @throws IllegalArgumentException if the request is null
*/
@Override
public Mono<CompleteResult> apply(McpTransportContext context, CompleteRequest request) {
if (request == null) {
return Mono.error(new IllegalArgumentException("Request must not be null"));
}

return Mono.defer(() -> {
try {
// Build arguments for the method call
Object[] args = this.buildArgs(this.method, context, request);

// Invoke the method
this.method.setAccessible(true);
Object result = this.method.invoke(this.bean, args);

// Handle the result based on its type
if (result instanceof Mono<?>) {
// If the result is already a Mono, map it to a CompleteResult
return ((Mono<?>) result).map(r -> convertToCompleteResult(r));
}
else {
// Otherwise, convert the result to a CompleteResult and wrap in a
// Mono
return Mono.just(convertToCompleteResult(result));
}
}
catch (Exception e) {
return Mono.error(
new McpCompleteMethodException("Error invoking complete method: " + this.method.getName(), e));
}
});
}

/**
* Converts a result object to a CompleteResult.
* @param result The result object
* @return The CompleteResult
*/
private CompleteResult convertToCompleteResult(Object result) {
if (result == null) {
return new CompleteResult(new CompleteCompletion(List.of(), 0, false));
}

if (result instanceof CompleteResult) {
return (CompleteResult) result;
}

if (result instanceof CompleteCompletion) {
return new CompleteResult((CompleteCompletion) result);
}

if (result instanceof List) {
List<?> list = (List<?>) result;
List<String> values = new ArrayList<>();

for (Object item : list) {
if (item instanceof String) {
values.add((String) item);
}
else {
throw new IllegalArgumentException("List items must be of type String");
}
}

return new CompleteResult(new CompleteCompletion(values, values.size(), false));
}

if (result instanceof String) {
return new CompleteResult(new CompleteCompletion(List.of((String) result), 1, false));
}

throw new IllegalArgumentException("Unsupported return type: " + result.getClass().getName());
}

/**
* Builder for creating AsyncStatelessMcpCompleteMethodCallback instances.
* <p>
* This builder provides a fluent API for constructing
* AsyncStatelessMcpCompleteMethodCallback instances with the required parameters.
*/
public static class Builder extends AbstractBuilder<Builder, AsyncStatelessMcpCompleteMethodCallback> {

/**
* Constructor for Builder.
*/
public Builder() {
this.uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory();
}

/**
* Build the callback.
* @return A new AsyncStatelessMcpCompleteMethodCallback instance
*/
@Override
public AsyncStatelessMcpCompleteMethodCallback build() {
validate();
return new AsyncStatelessMcpCompleteMethodCallback(this);
}

}

/**
* Create a new builder.
* @return A new builder instance
*/
public static Builder builder() {
return new Builder();
}

/**
* Validates that the method return type is compatible with the complete callback.
* @param method The method to validate
* @throws IllegalArgumentException if the return type is not compatible
*/
@Override
protected void validateReturnType(Method method) {
Class<?> returnType = method.getReturnType();

boolean validReturnType = CompleteResult.class.isAssignableFrom(returnType)
|| CompleteCompletion.class.isAssignableFrom(returnType) || List.class.isAssignableFrom(returnType)
|| String.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType);

if (!validReturnType) {
throw new IllegalArgumentException(
"Method must return either CompleteResult, CompleteCompletion, List<String>, "
+ "String, or Mono<T>: " + method.getName() + " in " + method.getDeclaringClass().getName()
+ " returns " + returnType.getName());
}
}

/**
* Checks if a parameter type is compatible with the exchange type.
* @param paramType The parameter type to check
* @return true if the parameter type is compatible with the exchange type, false
* otherwise
*/
@Override
protected boolean isExchangeType(Class<?> paramType) {
return McpTransportContext.class.isAssignableFrom(paramType);
}

}
Loading