Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2025-2025 the original author or authors.
*/

package org.springaicommunity.mcp;

import java.util.Objects;

public class ErrorUtils {

public static Throwable findCauseUsingPlainJava(Throwable throwable) {
Objects.requireNonNull(throwable);
Throwable rootCause = throwable;
while (rootCause.getCause() != null && rootCause.getCause() != rootCause) {
rootCause = rootCause.getCause();
}
return rootCause;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import org.springaicommunity.mcp.annotation.McpArg;
import org.springaicommunity.mcp.annotation.McpMeta;
import org.springaicommunity.mcp.annotation.McpProgressToken;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
import io.modelcontextprotocol.spec.McpSchema.Prompt;
Expand Down Expand Up @@ -75,7 +77,10 @@ protected void validateMethod(Method method) {
* @return true if the parameter type is compatible with the exchange type, false
* otherwise
*/
protected abstract boolean isExchangeOrContextType(Class<?> paramType);
protected abstract boolean isSupportedExchangeOrContextType(Class<?> paramType);

protected void validateParamType(Class<?> paramType) {
}

/**
* Validates method parameters.
Expand All @@ -95,6 +100,8 @@ protected void validateParameters(Method method) {
for (java.lang.reflect.Parameter param : parameters) {
Class<?> paramType = param.getType();

this.validateParamType(paramType);

// Skip @McpProgressToken annotated parameters from validation
if (param.isAnnotationPresent(McpProgressToken.class)) {
if (hasProgressTokenParam) {
Expand All @@ -115,7 +122,7 @@ protected void validateParameters(Method method) {
continue;
}

if (isExchangeOrContextType(paramType)) {
if (isSupportedExchangeOrContextType(paramType)) {
if (hasExchangeParam) {
throw new IllegalArgumentException("Method cannot have more than one exchange parameter: "
+ method.getName() + " in " + method.getDeclaringClass().getName());
Expand All @@ -140,6 +147,8 @@ else if (Map.class.isAssignableFrom(paramType)) {
}
}

protected abstract Object assignExchangeType(Class<?> paramType, Object exchange);

/**
* Builds the arguments array for invoking the method.
* <p>
Expand Down Expand Up @@ -182,8 +191,11 @@ protected Object[] buildArgs(Method method, Object exchange, GetPromptRequest re
java.lang.reflect.Parameter param = parameters[i];
Class<?> paramType = param.getType();

if (isExchangeOrContextType(paramType)) {
args[i] = exchange;
if (McpTransportContext.class.isAssignableFrom(paramType)
|| McpSyncServerExchange.class.isAssignableFrom(paramType)
|| McpAsyncServerExchange.class.isAssignableFrom(paramType)) {

args[i] = this.assignExchangeType(paramType, exchange);
}
else if (GetPromptRequest.class.isAssignableFrom(paramType)) {
args[i] = request;
Expand Down Expand Up @@ -367,30 +379,4 @@ protected void validate() {

}

/**
* Exception thrown when there is an error invoking a prompt method.
*/
public static class McpPromptMethodException extends RuntimeException {

private static final long serialVersionUID = 1L;

/**
* Constructs a new exception with the specified detail message and cause.
* @param message The detail message
* @param cause The cause
*/
public McpPromptMethodException(String message, Throwable cause) {
super(message, cause);
}

/**
* Constructs a new exception with the specified detail message.
* @param message The detail message
*/
public McpPromptMethodException(String message) {
super(message);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
import java.lang.reflect.Method;
import java.util.function.BiFunction;

import org.springaicommunity.mcp.ErrorUtils;
import org.springaicommunity.mcp.annotation.McpPrompt;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
import reactor.core.publisher.Mono;
Expand All @@ -31,6 +35,48 @@ private AsyncMcpPromptMethodCallback(Builder builder) {
super(builder.method, builder.bean, builder.prompt);
}

@Override
protected void validateParamType(Class<?> paramType) {

if (McpSyncServerExchange.class.isAssignableFrom(paramType)) {
throw new IllegalArgumentException("Async prompt method must not declare parameter of type: "
+ paramType.getName() + ". Use McpAsyncServerExchange instead." + " Method: "
+ this.method.getName() + " in " + this.method.getDeclaringClass().getName());
}
}

@Override
protected Object assignExchangeType(Class<?> paramType, Object exchange) {

if (McpTransportContext.class.isAssignableFrom(paramType)) {
if (exchange instanceof McpTransportContext transportContext) {
return transportContext;
}
else if (exchange instanceof McpSyncServerExchange syncServerExchange) {
throw new IllegalArgumentException("Unsupported Async exchange type: "
+ syncServerExchange.getClass().getName() + " for Async method: " + method.getName() + " in "
+ method.getDeclaringClass().getName());

}
else if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
return asyncServerExchange.transportContext();
}
}
else if (McpAsyncServerExchange.class.isAssignableFrom(paramType)) {
if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
return asyncServerExchange;
}

throw new IllegalArgumentException(
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
+ " for Async method: " + method.getName() + " in " + method.getDeclaringClass().getName());
}

throw new IllegalArgumentException(
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
+ " for method: " + method.getName() + " in " + method.getDeclaringClass().getName());
}

/**
* Apply the callback to the given exchange and request.
* <p>
Expand Down Expand Up @@ -69,15 +115,24 @@ public Mono<GetPromptResult> apply(McpAsyncServerExchange exchange, GetPromptReq
}
}
catch (Exception e) {
return Mono
.error(new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e));
if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
return Mono.error(mcpError);
}

return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS)
.message("Error invoking prompt method: " + this.method.getName() + " in "
+ this.bean.getClass().getName() + ". /nCause: "
+ ErrorUtils.findCauseUsingPlainJava(e).getMessage())
.data(ErrorUtils.findCauseUsingPlainJava(e).getMessage())
.build());
}
});
}

@Override
protected boolean isExchangeOrContextType(Class<?> paramType) {
return McpAsyncServerExchange.class.isAssignableFrom(paramType);
protected boolean isSupportedExchangeOrContextType(Class<?> paramType) {
return (McpAsyncServerExchange.class.isAssignableFrom(paramType)
|| McpTransportContext.class.isAssignableFrom(paramType));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
import java.util.List;
import java.util.function.BiFunction;

import org.springaicommunity.mcp.ErrorUtils;
import org.springaicommunity.mcp.annotation.McpPrompt;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
import io.modelcontextprotocol.spec.McpSchema.PromptMessage;
Expand All @@ -32,6 +37,42 @@ private AsyncStatelessMcpPromptMethodCallback(Builder builder) {
super(builder.method, builder.bean, builder.prompt);
}

@Override
protected void validateParamType(Class<?> paramType) {

if (McpSyncServerExchange.class.isAssignableFrom(paramType)
|| McpAsyncServerExchange.class.isAssignableFrom(paramType)) {

throw new IllegalArgumentException(
"Stateless Streamable-Http prompt method must not declare parameter of type: " + paramType.getName()
+ ". Use McpTransportContext instead." + " Method: " + this.method.getName() + " in "
+ this.method.getDeclaringClass().getName());
}
}

@Override
protected Object assignExchangeType(Class<?> paramType, Object exchange) {

if (McpTransportContext.class.isAssignableFrom(paramType)) {
if (exchange instanceof McpTransportContext transportContext) {
return transportContext;
}
else if (exchange instanceof McpSyncServerExchange syncServerExchange) {
throw new IllegalArgumentException("Unsupported Sync exchange type: "
+ syncServerExchange.getClass().getName() + " for Sync method: " + method.getName() + " in "
+ method.getDeclaringClass().getName());

}
else if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
return asyncServerExchange.transportContext();
}
}

throw new IllegalArgumentException(
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
+ " for method: " + method.getName() + " in " + method.getDeclaringClass().getName());
}

/**
* Apply the callback to the given context and request.
* <p>
Expand Down Expand Up @@ -70,14 +111,23 @@ public Mono<GetPromptResult> apply(McpTransportContext context, GetPromptRequest
}
}
catch (Exception e) {
return Mono
.error(new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e));

if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
return Mono.error(mcpError);
}

return Mono.error(McpError.builder(ErrorCodes.INVALID_PARAMS)
.message("Error invoking prompt method: " + this.method.getName() + " in "
+ this.bean.getClass().getName() + ". /nCause: "
+ ErrorUtils.findCauseUsingPlainJava(e).getMessage())
.data(ErrorUtils.findCauseUsingPlainJava(e).getMessage())
.build());
}
});
}

@Override
protected boolean isExchangeOrContextType(Class<?> paramType) {
protected boolean isSupportedExchangeOrContextType(Class<?> paramType) {
return McpTransportContext.class.isAssignableFrom(paramType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
import java.util.List;
import java.util.function.BiFunction;

import org.springaicommunity.mcp.ErrorUtils;
import org.springaicommunity.mcp.annotation.McpPrompt;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema.ErrorCodes;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
import io.modelcontextprotocol.spec.McpSchema.PromptMessage;
Expand All @@ -31,6 +35,47 @@ private SyncMcpPromptMethodCallback(Builder builder) {
super(builder.method, builder.bean, builder.prompt);
}

@Override
protected void validateParamType(Class<?> paramType) {

if (McpAsyncServerExchange.class.isAssignableFrom(paramType)) {
throw new IllegalArgumentException("Sync prompt method must not declare parameter of type: "
+ paramType.getName() + ". Use McpSyncServerExchange instead." + " Method: " + this.method.getName()
+ " in " + this.method.getDeclaringClass().getName());
}
}

@Override
protected Object assignExchangeType(Class<?> paramType, Object exchange) {

if (McpTransportContext.class.isAssignableFrom(paramType)) {
if (exchange instanceof McpTransportContext transportContext) {
return transportContext;
}
else if (exchange instanceof McpSyncServerExchange syncServerExchange) {
return syncServerExchange.transportContext();
}
else if (exchange instanceof McpAsyncServerExchange asyncServerExchange) {
throw new IllegalArgumentException("Unsupported Async exchange type: "
+ asyncServerExchange.getClass().getName() + " for Sync method: " + method.getName() + " in "
+ method.getDeclaringClass().getName());
}
}
else if (McpSyncServerExchange.class.isAssignableFrom(paramType)) {
if (exchange instanceof McpSyncServerExchange syncServerExchange) {
return syncServerExchange;
}

throw new IllegalArgumentException(
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
+ " for Sync method: " + method.getName() + " in " + method.getDeclaringClass().getName());
}

throw new IllegalArgumentException(
"Unsupported exchange type: " + (exchange != null ? exchange.getClass().getName() : "null")
+ " for method: " + method.getName() + " in " + method.getDeclaringClass().getName());
}

/**
* Apply the callback to the given exchange and request.
* <p>
Expand Down Expand Up @@ -62,13 +107,23 @@ public GetPromptResult apply(McpSyncServerExchange exchange, GetPromptRequest re
return promptResult;
}
catch (Exception e) {
throw new McpPromptMethodException("Error invoking prompt method: " + this.method.getName(), e);
if (e instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
throw mcpError;
}

throw McpError.builder(ErrorCodes.INVALID_PARAMS)
.message("Error invoking prompt method: " + this.method.getName() + " in "
+ this.bean.getClass().getName() + "./nCause: "
+ ErrorUtils.findCauseUsingPlainJava(e).getMessage())
.data(ErrorUtils.findCauseUsingPlainJava(e).getMessage())
.build();
}
}

@Override
protected boolean isExchangeOrContextType(Class<?> paramType) {
return McpSyncServerExchange.class.isAssignableFrom(paramType);
protected boolean isSupportedExchangeOrContextType(Class<?> paramType) {
return (McpSyncServerExchange.class.isAssignableFrom(paramType)
|| McpTransportContext.class.isAssignableFrom(paramType));
}

@Override
Expand Down
Loading