From b1c3f2c02ad2655998b04aa620dcd8ebcaf0ee77 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 5 Oct 2025 18:52:04 +0200 Subject: [PATCH 1/4] feat: add unified request context interfaces for MCP operations Introduces McpSyncRequestContext and McpAsyncRequestContext as unified interfaces for accessing MCP request context across both stateful and stateless operations. These new context types provide: - Unified API that works for both stateful and stateless operations - Convenient methods for logging, progress updates, sampling, and elicitation - Type-safe access to request data and context - Automatic injection by the framework Key changes: - Added McpSyncRequestContext and McpAsyncRequestContext interfaces - Added default implementations (DefaultMcpSyncRequestContext, DefaultMcpAsyncRequestContext) - Updated all tool method callbacks to support new context types - Updated JSON schema generator to exclude context parameters - Added tests for new context functionality - Updated README with detailed documentation and examples - Deprecated @McpProgressToken in favor of internal handling by context - Marked McpSyncServerExchange and McpAsyncServerExchange as deprecated in favor of McpSyncRequestContext and McpAsyncRequestContext Resolves #69 Releated to https://github.com/spring-projects/spring-ai/issues/4471 Signed-off-by: Christian Tzolov --- README.md | 191 +++++- .../mcp/context/DefaultElicitationSpec.java | 57 ++ .../mcp/context/DefaultLoggingSpec.java | 60 ++ .../DefaultMcpAsyncRequestContext.java | 474 +++++++++++++++ .../context/DefaultMcpSyncRequestContext.java | 458 ++++++++++++++ .../mcp/context/DefaultProgressSpec.java | 59 ++ .../mcp/context/DefaultSamplingSpec.java | 199 ++++++ .../mcp/context/McpAsyncRequestContext.java | 76 +++ .../mcp/context/McpRequestContextTypes.java | 164 +++++ .../mcp/context/McpSyncRequestContext.java | 74 +++ .../AbstractAsyncMcpToolMethodCallback.java | 4 +- .../tool/AbstractMcpToolMethodCallback.java | 15 +- .../AbstractSyncMcpToolMethodCallback.java | 4 +- .../tool/AsyncMcpToolMethodCallback.java | 17 +- .../AsyncStatelessMcpToolMethodCallback.java | 20 +- .../tool/SyncMcpToolMethodCallback.java | 15 +- .../SyncStatelessMcpToolMethodCallback.java | 20 +- .../tool/utils/JsonSchemaGenerator.java | 11 +- .../context/DefaultElicitationSpecTests.java | 134 +++++ .../mcp/context/DefaultLoggingSpecTests.java | 146 +++++ .../DefaultMcpAsyncRequestContextTests.java | 564 ++++++++++++++++++ .../DefaultMcpSyncRequestContextTests.java | 547 +++++++++++++++++ .../mcp/context/DefaultProgressSpecTests.java | 167 ++++++ .../mcp/context/DefaultSamplingSpecTests.java | 215 +++++++ .../tool/AsyncMcpToolMethodCallbackTests.java | 28 + .../tool/SyncMcpToolMethodCallbackTests.java | 27 + 26 files changed, 3723 insertions(+), 23 deletions(-) create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultLoggingSpec.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultProgressSpec.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultSamplingSpec.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultLoggingSpecTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultProgressSpecTests.java create mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultSamplingSpecTests.java diff --git a/README.md b/README.md index c69cdaa..de87f65 100644 --- a/README.md +++ b/README.md @@ -120,11 +120,15 @@ Each operation type has both synchronous and asynchronous implementations, allow - **`@McpToolParam`** - Annotates tool method parameters with descriptions and requirement specifications #### Special Parameters and Annotations -- **`@McpProgressToken`** - Marks a method parameter to receive the progress token from the request. This parameter is automatically injected and excluded from the generated JSON schema -- **`McpMeta`** - Special parameter type that provides access to metadata from MCP requests, notifications, and results. This parameter is automatically injected and excluded from parameter count limits and JSON schema generation -- **`McpSyncServerExchange`** - Special parameter type for stateful synchronous operations that provides access to server exchange functionality including logging notifications, progress updates, and other server-side operations. This parameter is automatically injected and excluded from JSON schema generation -- **`McpAsyncServerExchange`** - Special parameter type for stateful asynchronous operations that provides access to server exchange functionality with reactive support. This parameter is automatically injected and excluded from JSON schema generation +- **`McpSyncRequestContext`** - Special parameter type for synchronous operations that provides a unified interface for accessing MCP request context, including the original request, server exchange (for stateful operations), transport context (for stateless operations), and convenient methods for logging, progress, sampling, and elicitation. This parameter is automatically injected and excluded from JSON schema generation +- **`McpAsyncRequestContext`** - Special parameter type for asynchronous operations that provides the same unified interface as `McpSyncRequestContext` but with reactive (Mono-based) return types. This parameter is automatically injected and excluded from JSON schema generation +- **(Deprecated and replaced by `McpSyncRequestContext`) `McpSyncServerExchange`** - Special parameter type for stateful synchronous operations that provides access to server exchange functionality including logging notifications, progress updates, and other server-side operations. This parameter is automatically injected and excluded from JSON schema generation. +- **(Deprecated and replaced by `McpAsyncRequestContext`) `McpAsyncServerExchange`** - Special parameter type for stateful asynchronous operations that provides access to server exchange functionality with reactive support. This parameter is automatically injected and excluded from JSON schema generation - **`McpTransportContext`** - Special parameter type for stateless operations that provides lightweight access to transport-level context without full server exchange functionality. This parameter is automatically injected and excluded from JSON schema generation +- **(Deprecated. Handled internally by `McpSyncRequestContext` and `McpAsyncRequestContext`)`@McpProgressToken`** - Marks a method parameter to receive the progress token from the request. This parameter is automatically injected and excluded from the generated JSON schema +**Note:** if using the `McpSyncRequestContext` or `McpAsyncRequestContext` the progress token is handled internally. +- **`McpMeta`** - Special parameter type that provides access to metadata from MCP requests, notifications, and results. This parameter is automatically injected and excluded from parameter count limits and JSON schema generation. +**Note:** if using the McpSyncRequestContext or McpAsyncRequestContext the meta can be obatined via `requestMeta()` instead. ### Method Callbacks @@ -870,6 +874,185 @@ public List smartComplete( This feature enables context-aware MCP operations where the behavior can be customized based on client-provided metadata such as user identity, preferences, session information, or any other contextual data. +#### McpRequestContext Support + +The library provides unified request context interfaces (`McpSyncRequestContext` and `McpAsyncRequestContext`) that offer a higher-level abstraction over the underlying MCP infrastructure. These context objects provide convenient access to: + +- The original request (CallToolRequest, ReadResourceRequest, etc.) +- Server exchange (for stateful operations) or transport context (for stateless operations) +- Convenient methods for logging, progress updates, sampling, elicitation, and more + +**Key Benefits:** +- **Unified API**: Single parameter type works for both stateful and stateless operations +- **Convenience Methods**: Built-in helpers for common operations like logging and progress tracking +- **Type Safety**: Strongly-typed access to request data and context +- **Automatic Injection**: Context is automatically created and injected by the framework + +When a method parameter is of type `McpSyncRequestContext` or `McpAsyncRequestContext`: +- The parameter is automatically injected with the appropriate context implementation +- The parameter is excluded from JSON schema generation +- For stateful operations, the context provides access to `McpSyncServerExchange` or `McpAsyncServerExchange` +- For stateless operations, the context provides access to `McpTransportContext` + +**Synchronous Context Example:** + +```java +public record ElicitReturnType(String message) {} + +@McpTool(name = "process-with-context", description = "Process data with unified context") +public String processWithContext( + McpSyncRequestContext context, + @McpToolParam(description = "Data to process", required = true) String data) { + + // Access the original request + CallToolRequest request = (CallToolRequest) context.request(); + + // Log information + context.info("Processing data: " + data); + + // Send progress updates + context.progress(50); // 50% complete + + // Check if running in stateful mode + if (!context.isStateless()) { + // Access server exchange for stateful operations + McpSyncServerExchange exchange = context.exchange().orElseThrow(); + // Use exchange for additional operations... + } + + // Perform elicitation if needed + Optional userInput = context.elicitation(spec -> { + spec.message("Please provide additional information"); + spec.regurnType(ElicitReturnType.class); + }); + + return "Processed: " + data; +} + +@McpResource(uri = "data://{id}", name = "Data Resource", description = "Resource with context") +public ReadResourceResult getDataWithContext( + McpSyncRequestContext context, + String id) { + + // Log the resource access + context.debug("Accessing resource: " + id); + + // Access metadata from the request + Map metadata = context.request()._meta(); + + String content = "Data for " + id; + return new ReadResourceResult(List.of( + new TextResourceContents("data://" + id, "text/plain", content) + )); +} + +@McpPrompt(name = "generate-with-context", description = "Generate prompt with context") +public GetPromptResult generateWithContext( + McpSyncRequestContext context, + @McpArg(name = "topic", required = true) String topic) { + + // Log prompt generation + context.info("Generating prompt for topic: " + topic); + + // Perform sampling if needed + Optional samplingResult = context.sampling( + "What are the key points about " + topic + "?" + ); + + String message = "Let's discuss " + topic; + return new GetPromptResult("Generated Prompt", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent(message)))); +} +``` + +**Asynchronous Context Example:** + +```java + +public record ElicitReturnType(String message) {} + +@McpTool(name = "async-process-with-context", description = "Async process with unified context") +public Mono asyncProcessWithContext( + McpAsyncRequestContext context, + @McpToolParam(description = "Data to process", required = true) String data) { + + return Mono.fromCallable(() -> { + // Access the original request + CallToolRequest request = (CallToolRequest) context.request(); + return data; + }) + .flatMap(processedData -> { + // Log information (returns Mono) + return context.info("Processing data: " + processedData) + .thenReturn(processedData); + }) + .flatMap(processedData -> { + // Send progress updates (returns Mono) + return context.progress(50) + .thenReturn(processedData); + }) + .flatMap(processedData -> { + // Perform elicitation if needed (returns Mono) + return context.elicitation(spec -> { + spec.message("Please provide additional information"); + spec.returnType(ElicitReturnType.class); + }) + .map(result -> "Processed: " + processedData + " with user input"); + }); +} + +@McpResource(uri = "async-data://{id}", name = "Async Data Resource", + description = "Async resource with context") +public Mono getAsyncDataWithContext( + McpAsyncRequestContext context, + String id) { + + // Log the resource access (returns Mono) + return context.debug("Accessing async resource: " + id) + .then(Mono.fromCallable(() -> { + String content = "Async data for " + id; + return new ReadResourceResult(List.of( + new TextResourceContents("async-data://" + id, "text/plain", content) + )); + })); +} + +@McpPrompt(name = "async-generate-with-context", + description = "Async generate prompt with context") +public Mono asyncGenerateWithContext( + McpAsyncRequestContext context, + @McpArg(name = "topic", required = true) String topic) { + + // Log prompt generation and perform sampling + return context.info("Generating async prompt for topic: " + topic) + .then(context.sampling("What are the key points about " + topic + "?")) + .map(samplingResult -> { + String message = "Let's discuss " + topic; + return new GetPromptResult("Generated Async Prompt", + List.of(new PromptMessage(Role.ASSISTANT, new TextContent(message)))); + }); +} +``` + +**Available Context Methods:** + +`McpSyncRequestContext` provides: +- `request()` - Access the original request object +- `exchange()` - Access the server exchange (for stateful operations) +- `transportContext()` - Access the transport context (for stateless operations) +- `isStateless()` - Check if running in stateless mode +- `log(Consumer)` - Send log messages with custom configuration +- `debug(String)`, `info(String)`, `warn(String)`, `error(String)` - Convenience logging methods +- `progress(int)`, `progress(Consumer)` - Send progress updates +- `elicitation(...)` - Request user input with various configuration options +- `sampling(...)` - Request LLM sampling with various configuration options +- `roots()` - Access root directories (returns `Optional`) +- `ping()` - Send ping to check connection + +`McpAsyncRequestContext` provides the same methods but with reactive return types (`Mono` instead of `T` or `Optional`). + +This unified context approach simplifies method signatures and provides a consistent API across different operation types and execution modes (stateful vs stateless, sync vs async). + ### Async Tool Example ```java diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java new file mode 100644 index 0000000..9d9f74a --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.HashMap; +import java.util.Map; + +import io.modelcontextprotocol.util.Assert; +import org.springaicommunity.mcp.context.McpRequestContextTypes.ElicitationSpec; + +/** + * @author Christian Tzolov + */ +public class DefaultElicitationSpec implements ElicitationSpec { + + protected String message; + + protected Type responseType; + + protected Map meta = new HashMap<>(); + + @Override + public McpSyncRequestContext.ElicitationSpec message(String message) { + Assert.hasText(message, "Message must not be empty"); + this.message = message; + return this; + } + + @Override + public McpSyncRequestContext.ElicitationSpec responseType(Type type) { + Assert.notNull(type, "Response type must not be null"); + this.responseType = type; + return this; + } + + @Override + public McpSyncRequestContext.ElicitationSpec meta(Map m) { + Assert.notNull(m, "Meta map must not be null"); + this.meta.putAll(m); + return this; + } + + @Override + public McpSyncRequestContext.ElicitationSpec meta(String k, Object v) { + if (k != null && v != null) { + if (this.meta == null) { + this.meta = new java.util.HashMap<>(); + } + this.meta.put(k, v); + } + return this; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultLoggingSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultLoggingSpec.java new file mode 100644 index 0000000..85c5e83 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultLoggingSpec.java @@ -0,0 +1,60 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import org.springaicommunity.mcp.context.McpRequestContextTypes.LoggingSpec; + +/** + * @author Christian Tzolov + */ +public class DefaultLoggingSpec implements LoggingSpec { + + protected String message; + + protected String logger; + + protected LoggingLevel level = LoggingLevel.INFO; + + protected Map meta = new HashMap<>(); + + @Override + public LoggingSpec message(String message) { + this.message = message; + return this; + } + + @Override + public LoggingSpec logger(String logger) { + this.logger = logger; + return this; + } + + @Override + public LoggingSpec level(LoggingLevel level) { + this.level = level; + return this; + } + + @Override + public LoggingSpec meta(Map m) { + if (m != null) { + this.meta.putAll(m); + } + return this; + } + + @Override + public LoggingSpec meta(String k, Object v) { + if (k != null && v != null) { + this.meta.put(k, v); + } + return this; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java new file mode 100644 index 0000000..3dfc618 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java @@ -0,0 +1,474 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.method.tool.utils.ConcurrentReferenceHashMap; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; +import reactor.core.publisher.Mono; + +/** + * Async (Reactor) implementation of McpAsyncRequestContext that returns Mono of value + * types. + * + * @author Christian Tzolov + */ +public class DefaultMcpAsyncRequestContext implements McpAsyncRequestContext { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpAsyncRequestContext.class); + + private static final Map> typeSchemaCache = new ConcurrentReferenceHashMap<>(256); + + private static TypeReference> MAP_TYPE_REF = new TypeReference>() { + }; + + private final McpSchema.Request request; + + private final McpAsyncServerExchange exchange; + + private DefaultMcpAsyncRequestContext(McpSchema.Request request, McpAsyncServerExchange exchange) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(exchange, "Exchange must not be null"); + this.request = request; + this.exchange = exchange; + } + + // Roots + + @Override + public Mono roots() { + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().roots() == null) { + logger.warn("Roots not supported by the client! Ignoring the roots request for request:" + this.request); + return Mono.empty(); + } + return this.exchange.listRoots(); + } + + // Elicitation + + @Override + public Mono elicitation(Consumer elicitationSpec) { + Assert.notNull(elicitationSpec, "Elicitation spec consumer must not be null"); + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + elicitationSpec.accept(spec); + Assert.hasText(spec.message, "Elicitation message must not be empty"); + Assert.notNull(spec.responseType, "Elicitation response type must not be null"); + + return this.elicitationInternal(spec.message, spec.responseType, spec.meta.isEmpty() ? null : spec.meta); + } + + @Override + public Mono elicitation(String message, Type type) { + return this.elicitationInternal(message, type, null); + } + + @Override + public Mono elicitation(ElicitRequest elicitRequest) { + Assert.notNull(elicitRequest, "Elicit request must not be null"); + + if (this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().elicitation() == null) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request for request:" + + elicitRequest); + return Mono.empty(); + } + + return this.exchange.createElicitation(elicitRequest); + } + + public Mono elicitationInternal(String message, Type type, Map meta) { + Assert.hasText(message, "Elicitation message must not be empty"); + Assert.notNull(type, "Elicitation response type must not be null"); + + Map schema = typeSchemaCache.computeIfAbsent(type, t -> this.generateElicitSchema(t)); + + return this.elicitation(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); + } + + private Map generateElicitSchema(Type type) { + Map schema = JsonParser.fromJson(JsonSchemaGenerator.generateFromType(type), MAP_TYPE_REF); + // remove as elicitation schema does not support it + schema.remove("$schema"); + return schema; + } + + // Sampling + + @Override + public Mono sampling(String... messages) { + return this.sampling(s -> s.message(messages)); + } + + @Override + public Mono sampling(Consumer samplingSpec) { + Assert.notNull(samplingSpec, "Sampling spec consumer must not be null"); + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + samplingSpec.accept(spec); + + var progressToken = this.request.progressToken(); + + if (!Utils.hasText(progressToken)) { + logger.warn("Progress notification not supported by the client!"); + } + return this.sampling(McpSchema.CreateMessageRequest.builder() + .messages(spec.messages) + .modelPreferences(spec.modelPreferences) + .systemPrompt(spec.systemPrompt) + .temperature(spec.temperature) + .maxTokens(spec.maxTokens != null && spec.maxTokens > 0 ? spec.maxTokens : 500) + .stopSequences(spec.stopSequences.isEmpty() ? null : spec.stopSequences) + .includeContext(spec.includeContextStrategy) + .meta(spec.metadata.isEmpty() ? null : spec.metadata) + .progressToken(progressToken) + .meta(spec.meta.isEmpty() ? null : spec.meta) + .build()); + } + + @Override + public Mono sampling(CreateMessageRequest createMessageRequest) { + + // check if supported + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request for messages:" + + createMessageRequest); + return Mono.empty(); + } + + return this.exchange.createMessage(createMessageRequest); + } + + // Progress + + @Override + public Mono progress(int percentage) { + Assert.isTrue(percentage >= 0 && percentage <= 100, "Percentage must be between 0 and 100"); + return this.progress(p -> p.progress(percentage / 100.0).total(1.0).message(null)); + } + + @Override + public Mono progress(Consumer progressSpec) { + + Assert.notNull(progressSpec, "Progress spec consumer must not be null"); + DefaultProgressSpec spec = new DefaultProgressSpec(); + + progressSpec.accept(spec); + + if (!Utils.hasText(this.request.progressToken())) { + logger.warn("Progress notification not supported by the client!"); + return Mono.empty(); + } + + return this.progress(new ProgressNotification(this.request.progressToken(), spec.progress, spec.total, + spec.message, spec.meta)); + } + + @Override + public Mono progress(ProgressNotification progressNotification) { + return this.exchange.progressNotification(progressNotification).then(Mono.empty()); + } + + // Ping + + @Override + public Mono ping() { + return this.exchange.ping(); + } + + // Logging + + @Override + public Mono log(Consumer logSpec) { + Assert.notNull(logSpec, "Logging spec consumer must not be null"); + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + logSpec.accept(spec); + + return this.exchange + .loggingNotification(LoggingMessageNotification.builder() + .data(spec.message) + .level(spec.level) + .logger(spec.logger) + .meta(spec.meta) + .build()) + .then(); + } + + @Override + public Mono debug(String message) { + return this.logInternal(message, LoggingLevel.DEBUG); + } + + @Override + public Mono info(String message) { + return this.logInternal(message, LoggingLevel.INFO); + } + + @Override + public Mono warn(String message) { + return this.logInternal(message, LoggingLevel.WARNING); + } + + @Override + public Mono error(String message) { + return this.logInternal(message, LoggingLevel.ERROR); + } + + private Mono logInternal(String message, LoggingLevel level) { + Assert.hasText(message, "Log message must not be empty"); + return this.exchange + .loggingNotification(LoggingMessageNotification.builder().data(message).level(level).build()) + .then(); + } + + // Getters + + @Override + public McpSchema.Request request() { + return this.request; + } + + @Override + public McpAsyncServerExchange exchange() { + return this.exchange; + } + + @Override + public String sessionId() { + return this.exchange.sessionId(); + } + + @Override + public Implementation clientInfo() { + return this.exchange.getClientInfo(); + } + + @Override + public ClientCapabilities clientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + @Override + public Map requestMeta() { + return this.request.meta(); + } + + @Override + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + + // Builder + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private McpSchema.Request request; + + private McpAsyncServerExchange exchange; + + private boolean isStateless = false; + + private McpTransportContext transportContext; + + private Builder() { + } + + public Builder request(McpSchema.Request request) { + this.request = request; + return this; + } + + public Builder exchange(McpAsyncServerExchange exchange) { + this.exchange = exchange; + return this; + } + + public Builder stateless(boolean isStateless) { + this.isStateless = isStateless; + return this; + } + + public Builder transportContext(McpTransportContext transportContext) { + this.transportContext = transportContext; + return this; + } + + public McpAsyncRequestContext build() { + if (this.isStateless) { + return new StatelessAsyncRequestContext(this.request, this.transportContext); + } + return new DefaultMcpAsyncRequestContext(this.request, this.exchange); + } + + } + + private static class StatelessAsyncRequestContext implements McpAsyncRequestContext { + + private final McpSchema.Request request; + + private McpTransportContext transportContext; + + public StatelessAsyncRequestContext(McpSchema.Request request, McpTransportContext transportContext) { + this.request = request; + this.transportContext = transportContext; + } + + @Override + public Mono roots() { + logger.warn("Roots not supported by the client! Ignoring the roots request"); + return Mono.empty(); + } + + @Override + public Mono elicitation(Consumer elicitationSpec) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono elicitation(String message, Type type) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono elicitation(ElicitRequest elicitRequest) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono sampling(String... messages) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request"); + return Mono.empty(); + } + + @Override + public Mono sampling(Consumer samplingSpec) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request"); + return Mono.empty(); + } + + @Override + public Mono sampling(CreateMessageRequest createMessageRequest) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request"); + return Mono.empty(); + } + + @Override + public Mono progress(int progress) { + logger.warn("Progress not supported by the client! Ignoring the progress request"); + return Mono.empty(); + } + + @Override + public Mono progress(Consumer progressSpec) { + logger.warn("Progress not supported by the client! Ignoring the progress request"); + return Mono.empty(); + } + + @Override + public Mono progress(ProgressNotification progressNotification) { + logger.warn("Progress not supported by the client! Ignoring the progress request"); + return Mono.empty(); + } + + @Override + public Mono ping() { + logger.warn("Ping not supported by the client! Ignoring the ping request"); + return Mono.empty(); + } + + @Override + public Mono log(Consumer logSpec) { + logger.warn("Logging not supported by the client! Ignoring the logging request"); + return Mono.empty(); + } + + @Override + public Mono debug(String message) { + logger.warn("Debug not supported by the client! Ignoring the debug request"); + return Mono.empty(); + } + + @Override + public Mono info(String message) { + logger.warn("Info not supported by the client! Ignoring the info request"); + return Mono.empty(); + } + + @Override + public Mono warn(String message) { + logger.warn("Warn not supported by the client! Ignoring the warn request"); + return Mono.empty(); + } + + @Override + public Mono error(String message) { + logger.warn("Error not supported by the client! Ignoring the error request"); + return Mono.empty(); + } + + // Getters + + public McpSchema.Request request() { + return this.request; + } + + public McpAsyncServerExchange exchange() { + logger.warn("Stateless servers do not support exchange! Returning null"); + return null; + } + + public String sessionId() { + logger.warn("Stateless servers do not support session ID! Returning null"); + return null; + } + + public Implementation clientInfo() { + logger.warn("Stateless servers do not support client info! Returning null"); + return null; + } + + public ClientCapabilities clientCapabilities() { + logger.warn("Stateless servers do not support client capabilities! Returning null"); + return null; + } + + public Map requestMeta() { + return this.request.meta(); + } + + public McpTransportContext transportContext() { + return transportContext; + } + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java new file mode 100644 index 0000000..2ef1804 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java @@ -0,0 +1,458 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.method.tool.utils.ConcurrentReferenceHashMap; +import org.springaicommunity.mcp.method.tool.utils.JsonParser; +import org.springaicommunity.mcp.method.tool.utils.JsonSchemaGenerator; + +/** + * @author Christian Tzolov + */ +public class DefaultMcpSyncRequestContext implements McpSyncRequestContext { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpSyncRequestContext.class); + + private static final Map> typeSchemaCache = new ConcurrentReferenceHashMap<>(256); + + private static TypeReference> MAP_TYPE_REF = new TypeReference>() { + }; + + private final McpSchema.Request request; + + private final McpSyncServerExchange exchange; + + private DefaultMcpSyncRequestContext(McpSchema.Request request, McpSyncServerExchange exchange) { + Assert.notNull(request, "Request must not be null"); + Assert.notNull(exchange, "Exchange must not be null"); + this.request = request; + this.exchange = exchange; + } + + // Roots + + public Optional roots() { + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().roots() == null) { + logger.warn("Roots not supported by the client! Ignoring the roots request for request:" + this.request); + return Optional.empty(); + } + return Optional.of(this.exchange.listRoots()); + } + + // Elicitation + + @Override + public Optional elicitation(Consumer elicitationSpec) { + Assert.notNull(elicitationSpec, "Elicitation spec consumer must not be null"); + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + elicitationSpec.accept(spec); + Assert.hasText(spec.message, "Elicitation message must not be empty"); + Assert.notNull(spec.responseType, "Elicitation response type must not be null"); + + return this.elicitationInternal(spec.message, spec.responseType, spec.meta.isEmpty() ? null : spec.meta); + } + + @Override + public Optional elicitation(String message, Type type) { + return this.elicitationInternal(message, type, null); + } + + @Override + public Optional elicitation(ElicitRequest elicitRequest) { + Assert.notNull(elicitRequest, "Elicit request must not be null"); + + if (this.exchange.getClientCapabilities() == null + || this.exchange.getClientCapabilities().elicitation() == null) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request for request:" + + elicitRequest); + return Optional.empty(); + } + + return Optional.of(this.exchange.createElicitation(elicitRequest)); + } + + public Optional elicitationInternal(String message, Type type, Map meta) { + Assert.hasText(message, "Elicitation message must not be empty"); + Assert.notNull(type, "Elicitation response type must not be null"); + + Map schema = typeSchemaCache.computeIfAbsent(type, t -> this.generateElicitSchema(t)); + + return this.elicitation(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); + } + + private Map generateElicitSchema(Type type) { + Map schema = JsonParser.fromJson(JsonSchemaGenerator.generateFromType(type), MAP_TYPE_REF); + // remove $schema as elicitation schema does not support it + schema.remove("$schema"); + return schema; + } + + // Sampling + + @Override + public Optional sampling(String... messages) { + return this.sampling(s -> s.message(messages)); + } + + @Override + public Optional sampling(Consumer samplingSpec) { + Assert.notNull(samplingSpec, "Sampling spec consumer must not be null"); + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + samplingSpec.accept(spec); + + var progressToken = this.request.progressToken(); + + if (!Utils.hasText(progressToken)) { + logger.warn("Progress notification not supported by the client!"); + } + return this.sampling(McpSchema.CreateMessageRequest.builder() + .messages(spec.messages) + .modelPreferences(spec.modelPreferences) + .systemPrompt(spec.systemPrompt) + .temperature(spec.temperature) + .maxTokens(spec.maxTokens != null && spec.maxTokens > 0 ? spec.maxTokens : 500) + .stopSequences(spec.stopSequences.isEmpty() ? null : spec.stopSequences) + .includeContext(spec.includeContextStrategy) + .meta(spec.metadata.isEmpty() ? null : spec.metadata) + .progressToken(progressToken) + .meta(spec.meta.isEmpty() ? null : spec.meta) + .build()); + } + + @Override + public Optional sampling(CreateMessageRequest createMessageRequest) { + + // check if supported + if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { + logger.warn("Sampling not supported by the client! Ignoring the sampling request for messages:" + + createMessageRequest); + return Optional.empty(); + } + + return Optional.of(this.exchange.createMessage(createMessageRequest)); + } + + // Progress + + @Override + public void progress(int percentage) { + Assert.isTrue(percentage >= 0 && percentage <= 100, "Percentage must be between 0 and 100"); + this.progress(p -> p.progress(percentage / 100.0).total(1.0).message(null)); + } + + @Override + public void progress(Consumer progressSpec) { + + Assert.notNull(progressSpec, "Progress spec consumer must not be null"); + DefaultProgressSpec spec = new DefaultProgressSpec(); + + progressSpec.accept(spec); + + if (!Utils.hasText(this.request.progressToken())) { + logger.warn("Progress notification not supported by the client!"); + return; + } + + this.progress(new ProgressNotification(this.request.progressToken(), spec.progress, spec.total, spec.message, + spec.meta)); + } + + @Override + public void progress(ProgressNotification progressNotification) { + this.exchange.progressNotification(progressNotification); + } + + // Ping + + @Override + public void ping() { + this.exchange.ping(); + } + + // Logging + + @Override + public void log(Consumer logSpec) { + Assert.notNull(logSpec, "Logging spec consumer must not be null"); + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + logSpec.accept(spec); + + this.exchange.loggingNotification(LoggingMessageNotification.builder() + .data(spec.message) + .level(spec.level) + .logger(spec.logger) + .meta(spec.meta) + .build()); + } + + @Override + public void debug(String message) { + this.logInternal(message, LoggingLevel.DEBUG); + } + + @Override + public void info(String message) { + this.logInternal(message, LoggingLevel.INFO); + } + + @Override + public void warn(String message) { + this.logInternal(message, LoggingLevel.WARNING); + } + + @Override + public void error(String message) { + this.logInternal(message, LoggingLevel.ERROR); + } + + private void logInternal(String message, LoggingLevel level) { + Assert.hasText(message, "Log message must not be empty"); + this.exchange.loggingNotification(LoggingMessageNotification.builder().data(message).level(level).build()); + } + + // Getters + + @Override + public McpSchema.Request request() { + return this.request; + } + + @Override + public McpSyncServerExchange exchange() { + return this.exchange; + } + + @Override + public String sessionId() { + return this.exchange.sessionId(); + } + + @Override + public Implementation clientInfo() { + return this.exchange.getClientInfo(); + } + + @Override + public ClientCapabilities clientCapabilities() { + return this.exchange.getClientCapabilities(); + } + + @Override + public Map requestMeta() { + return this.request.meta(); + } + + @Override + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + + // Builder + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private McpSchema.Request request; + + private McpSyncServerExchange exchange; + + private McpTransportContext transportContext; + + private boolean isStateless = false; + + private Builder() { + } + + public Builder request(McpSchema.Request request) { + this.request = request; + return this; + } + + public Builder exchange(McpSyncServerExchange exchange) { + this.exchange = exchange; + return this; + } + + public Builder transportContext(McpTransportContext transportContext) { + this.transportContext = transportContext; + return this; + } + + public Builder stateless(boolean isStateless) { + this.isStateless = isStateless; + return this; + } + + public McpSyncRequestContext build() { + if (this.isStateless) { + return new StatelessMcpSyncRequestContext(this.request, this.transportContext); + } + return new DefaultMcpSyncRequestContext(this.request, this.exchange); + } + + } + + public final static class StatelessMcpSyncRequestContext implements McpSyncRequestContext { + + private static final Logger logger = LoggerFactory.getLogger(StatelessMcpSyncRequestContext.class); + + private final McpSchema.Request request; + + private final McpTransportContext transportContext; + + private StatelessMcpSyncRequestContext(McpSchema.Request request, McpTransportContext transportContext) { + this.request = request; + this.transportContext = transportContext; + } + + @Override + public Optional roots() { + logger.warn("Roots not supported by the client! Ignoring the roots request"); + return Optional.empty(); + } + + @Override + public Optional elicitation(Consumer elicitationSpec) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional elicitation(String message, Type type) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional elicitation(ElicitRequest elicitRequest) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional sampling(String... messages) { + logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); + return Optional.empty(); + } + + @Override + public Optional sampling(Consumer samplingSpec) { + logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); + return Optional.empty(); + } + + @Override + public Optional sampling(CreateMessageRequest createMessageRequest) { + logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); + return Optional.empty(); + } + + @Override + public void progress(int progress) { + logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); + } + + @Override + public void progress(Consumer progressSpec) { + logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); + } + + @Override + public void progress(ProgressNotification progressNotification) { + logger.warn("Stateless servers do not support progress notifications! Ignoring the progress request"); + } + + @Override + public void ping() { + logger.warn("Stateless servers do not support ping! Ignoring the ping request"); + } + + @Override + public void log(Consumer logSpec) { + logger.warn("Stateless servers do not support logging! Ignoring the logging request"); + } + + @Override + public void debug(String message) { + logger.warn("Stateless servers do not support debugging! Ignoring the debugging request"); + } + + @Override + public void info(String message) { + logger.warn("Stateless servers do not support info logging! Ignoring the info request"); + } + + @Override + public void warn(String message) { + logger.warn("Stateless servers do not support warning logging! Ignoring the warning request"); + } + + @Override + public void error(String message) { + logger.warn("Stateless servers do not support error logging! Ignoring the error request"); + } + + public McpSchema.Request request() { + return this.request; + } + + public McpTransportContext transportContext() { + return transportContext; + } + + public String sessionId() { + logger.warn("Stateless servers do not support session ID! Returning null"); + return null; + } + + public Implementation clientInfo() { + logger.warn("Stateless servers do not support client info! Returning null"); + return null; + } + + public ClientCapabilities clientCapabilities() { + logger.warn("Stateless servers do not support client capabilities! Returning null"); + return null; + } + + public Map requestMeta() { + return this.request.meta(); + } + + @Override + public McpSyncServerExchange exchange() { + logger.warn("Stateless servers do not support exchange! Returning null"); + return null; + } + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultProgressSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultProgressSpec.java new file mode 100644 index 0000000..65af97f --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultProgressSpec.java @@ -0,0 +1,59 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import org.springaicommunity.mcp.context.McpRequestContextTypes.ProgressSpec; + +/** + * @author Christian Tzolov + */ +public class DefaultProgressSpec implements ProgressSpec { + + protected double progress = 0.0; + + protected double total = 1.0; + + protected String message; + + protected Map meta = new HashMap<>(); + + @Override + public ProgressSpec progress(double progress) { + this.progress = progress; + return this; + } + + @Override + public ProgressSpec total(double total) { + this.total = total; + return this; + } + + @Override + public ProgressSpec message(String message) { + this.message = message; + return this; + } + + @Override + public ProgressSpec meta(Map m) { + if (m != null) { + this.meta.putAll(m); + } + return this; + } + + @Override + public ProgressSpec meta(String k, Object v) { + if (k != null && v != null) { + this.meta.put(k, v); + } + return this; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultSamplingSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultSamplingSpec.java new file mode 100644 index 0000000..1801b7f --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultSamplingSpec.java @@ -0,0 +1,199 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import io.modelcontextprotocol.spec.McpSchema.AudioContent; +import io.modelcontextprotocol.spec.McpSchema.Content; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; +import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource; +import io.modelcontextprotocol.spec.McpSchema.ImageContent; +import io.modelcontextprotocol.spec.McpSchema.ModelHint; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.ResourceLink; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.util.Assert; +import org.springaicommunity.mcp.context.McpRequestContextTypes.ModelPreferenceSpec; +import org.springaicommunity.mcp.context.McpRequestContextTypes.SamplingSpec; + +/** + * @author Christian Tzolov + */ +public class DefaultSamplingSpec implements SamplingSpec { + + protected List messages = new ArrayList<>(); + + protected ModelPreferences modelPreferences; + + protected String systemPrompt; + + protected Double temperature; + + protected Integer maxTokens; + + protected List stopSequences = new ArrayList<>(); + + protected Map metadata = new HashMap<>(); + + protected Map meta = new HashMap<>(); + + protected ContextInclusionStrategy includeContextStrategy = ContextInclusionStrategy.NONE; + + @Override + public SamplingSpec message(ResourceLink... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(EmbeddedResource... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(AudioContent... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(ImageContent... content) { + return this.messageInternal(content); + } + + @Override + public SamplingSpec message(TextContent... content) { + return this.messageInternal(content); + } + + private SamplingSpec messageInternal(Content... content) { + this.messages.addAll(List.of(content).stream().map(c -> new SamplingMessage(Role.USER, c)).toList()); + return this; + } + + @Override + public SamplingSpec message(SamplingMessage... message) { + this.messages.addAll(List.of(message)); + return this; + } + + @Override + public SamplingSpec modelPreferences(Consumer modelPreferenceSpec) { + var modelPreferencesSpec = new DefaultModelPreferenceSpec(); + modelPreferenceSpec.accept(modelPreferencesSpec); + + this.modelPreferences = ModelPreferences.builder() + .hints(modelPreferencesSpec.modelHints) + .costPriority(modelPreferencesSpec.costPriority) + .speedPriority(modelPreferencesSpec.speedPriority) + .intelligencePriority(modelPreferencesSpec.intelligencePriority) + .build(); + return this; + } + + @Override + public SamplingSpec systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + @Override + public SamplingSpec includeContextStrategy(ContextInclusionStrategy includeContextStrategy) { + this.includeContextStrategy = includeContextStrategy; + return this; + } + + @Override + public SamplingSpec temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + @Override + public SamplingSpec maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + @Override + public SamplingSpec stopSequences(String... stopSequences) { + this.stopSequences.addAll(List.of(stopSequences)); + return this; + } + + @Override + public SamplingSpec metadata(Map m) { + this.metadata.putAll(m); + return this; + } + + @Override + public SamplingSpec metadata(String k, Object v) { + this.metadata.put(k, v); + return this; + } + + @Override + public SamplingSpec meta(Map m) { + this.meta.putAll(m); + return this; + } + + @Override + public SamplingSpec meta(String k, Object v) { + this.meta.put(k, v); + return this; + } + + public static class DefaultModelPreferenceSpec implements ModelPreferenceSpec { + + private List modelHints = new ArrayList<>(); + + private Double costPriority; + + private Double speedPriority; + + private Double intelligencePriority; + + @Override + public ModelPreferenceSpec modelHints(String... models) { + Assert.notNull(models, "Models must not be null"); + this.modelHints.addAll(List.of(models).stream().map(ModelHint::new).toList()); + return this; + } + + @Override + public ModelPreferenceSpec modelHint(String modelHint) { + Assert.notNull(modelHint, "Model hint must not be null"); + this.modelHints.add(new ModelHint(modelHint)); + return this; + } + + @Override + public ModelPreferenceSpec costPriority(Double costPriority) { + this.costPriority = costPriority; + return this; + } + + @Override + public ModelPreferenceSpec speedPriority(Double speedPriority) { + this.speedPriority = speedPriority; + return this; + } + + @Override + public ModelPreferenceSpec intelligencePriority(Double intelligencePriority) { + this.intelligencePriority = intelligencePriority; + return this; + } + + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java new file mode 100644 index 0000000..96dccce --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java @@ -0,0 +1,76 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.function.Consumer; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import reactor.core.publisher.Mono; + +/** + * Async (Reactor) version of McpSyncRequestContext that returns Mono of value types. + * + * @author Christian Tzolov + */ +public interface McpAsyncRequestContext extends McpRequestContextTypes { + + // -------------------------------------- + // Roots + // -------------------------------------- + Mono roots(); + + // -------------------------------------- + // Elicitation + // -------------------------------------- + Mono elicitation(Consumer elicitationSpec); + + Mono elicitation(String message, Type type); + + Mono elicitation(ElicitRequest elicitRequest); + + // -------------------------------------- + // Sampling + // -------------------------------------- + Mono sampling(String... messages); + + Mono sampling(Consumer samplingSpec); + + Mono sampling(CreateMessageRequest createMessageRequest); + + // -------------------------------------- + // Progress + // -------------------------------------- + Mono progress(int progress); + + Mono progress(Consumer progressSpec); + + Mono progress(ProgressNotification progressNotification); + + // -------------------------------------- + // Ping + // -------------------------------------- + Mono ping(); + + // -------------------------------------- + // Logging + // -------------------------------------- + Mono log(Consumer logSpec); + + Mono debug(String message); + + Mono info(String message); + + Mono warn(String message); + + Mono error(String message); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java new file mode 100644 index 0000000..754baf1 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java @@ -0,0 +1,164 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.AudioContent; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; +import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource; +import io.modelcontextprotocol.spec.McpSchema.ImageContent; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.ResourceLink; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; + +/** + * @author Christian Tzolov + */ +public interface McpRequestContextTypes { + + // -------------------------------------- + // Elicitation + // -------------------------------------- + + interface ElicitationSpec { + + /** + * The prompt message to display to the user + */ + ElicitationSpec message(String message); + + /** + * The response type defining the expected response structure. Note that + * elicitation responses are subject to a restricted subset of JSON Schema types. + */ + ElicitationSpec responseType(Type type); + + ElicitationSpec meta(Map m); + + ElicitationSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Sampling + // -------------------------------------- + + interface ModelPreferenceSpec { + + ModelPreferenceSpec modelHints(String... models); + + ModelPreferenceSpec modelHint(String modelHint); + + ModelPreferenceSpec costPriority(Double costPriority); + + ModelPreferenceSpec speedPriority(Double speedPriority); + + ModelPreferenceSpec intelligencePriority(Double intelligencePriority); + + } + + interface SamplingSpec { + + SamplingSpec message(ResourceLink... content); + + SamplingSpec message(EmbeddedResource... content); + + SamplingSpec message(AudioContent... content); + + SamplingSpec message(ImageContent... content); + + SamplingSpec message(TextContent... content); + + default SamplingSpec message(String... text) { + return message(List.of(text).stream().map(t -> new TextContent(t)).toList().toArray(new TextContent[0])); + } + + SamplingSpec message(SamplingMessage... message); + + SamplingSpec modelPreferences(Consumer modelPreferenceSpec); + + SamplingSpec systemPrompt(String systemPrompt); + + SamplingSpec includeContextStrategy(ContextInclusionStrategy includeContextStrategy); + + SamplingSpec temperature(Double temperature); + + SamplingSpec maxTokens(Integer maxTokens); + + SamplingSpec stopSequences(String... stopSequences); + + SamplingSpec metadata(Map m); + + SamplingSpec metadata(String k, Object v); + + SamplingSpec meta(Map m); + + SamplingSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Progress + // -------------------------------------- + + interface ProgressSpec { + + ProgressSpec progress(double progress); + + ProgressSpec total(double total); + + ProgressSpec message(String message); + + ProgressSpec meta(Map m); + + ProgressSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Logging + // -------------------------------------- + + interface LoggingSpec { + + LoggingSpec message(String message); + + LoggingSpec logger(String logger); + + LoggingSpec level(LoggingLevel level); + + LoggingSpec meta(Map m); + + LoggingSpec meta(String k, Object v); + + } + + // -------------------------------------- + // Getters + // -------------------------------------- + McpSchema.Request request(); + + ET exchange(); + + String sessionId(); + + Implementation clientInfo(); + + ClientCapabilities clientCapabilities(); + + Map requestMeta(); + + McpTransportContext transportContext(); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java new file mode 100644 index 0000000..c36a538 --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java @@ -0,0 +1,74 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.lang.reflect.Type; +import java.util.Optional; +import java.util.function.Consumer; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; + +/** + * @author Christian Tzolov + */ +public interface McpSyncRequestContext extends McpRequestContextTypes { + + // -------------------------------------- + // Roots + // -------------------------------------- + Optional roots(); + + // -------------------------------------- + // Elicitation + // -------------------------------------- + Optional elicitation(Consumer elicitationSpec); + + Optional elicitation(String message, Type type); + + Optional elicitation(ElicitRequest elicitRequest); + + // -------------------------------------- + // Sampling + // -------------------------------------- + Optional sampling(String... messages); + + Optional sampling(Consumer samplingSpec); + + Optional sampling(CreateMessageRequest createMessageRequest); + + // -------------------------------------- + // Progress + // -------------------------------------- + void progress(int progress); + + void progress(Consumer progressSpec); + + void progress(ProgressNotification progressNotification); + + // -------------------------------------- + // Ping + // -------------------------------------- + void ping(); + + // -------------------------------------- + // Logging + // -------------------------------------- + void log(Consumer logSpec); + + void debug(String message); + + void info(String message); + + void warn(String message); + + void error(String message); + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java index b1e19be..43cec0a 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractAsyncMcpToolMethodCallback.java @@ -22,6 +22,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import org.reactivestreams.Publisher; import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.McpRequestContextTypes; import org.springaicommunity.mcp.method.tool.utils.JsonParser; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -37,7 +38,8 @@ * McpTransportContext) * @author Christian Tzolov */ -public abstract class AbstractAsyncMcpToolMethodCallback extends AbstractMcpToolMethodCallback { +public abstract class AbstractAsyncMcpToolMethodCallback> + extends AbstractMcpToolMethodCallback { protected final Class toolCallExceptionClass; diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java index f5fa170..3e9ea3f 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractMcpToolMethodCallback.java @@ -28,6 +28,9 @@ import org.springaicommunity.mcp.annotation.McpMeta; import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpRequestContextTypes; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import org.springaicommunity.mcp.method.tool.utils.JsonParser; /** @@ -41,7 +44,7 @@ * McpSyncServerExchange, or McpAsyncServerExchange) * @author Christian Tzolov */ -public abstract class AbstractMcpToolMethodCallback { +public abstract class AbstractMcpToolMethodCallback { protected final Method toolMethod; @@ -89,7 +92,15 @@ protected Object callMethod(Object[] methodArguments) { */ protected Object[] buildMethodArguments(T exchangeOrContext, Map toolInputArguments, CallToolRequest request) { + return Stream.of(this.toolMethod.getParameters()).map(parameter -> { + + if (McpSyncRequestContext.class.isAssignableFrom(parameter.getType()) + || McpAsyncRequestContext.class.isAssignableFrom(parameter.getType())) { + + return this.createRequestContext(exchangeOrContext, request); + } + // Check if parameter is annotated with @McpProgressToken if (parameter.isAnnotationPresent(McpProgressToken.class)) { // Return the progress token from the request @@ -203,4 +214,6 @@ protected Throwable findCauseUsingPlainJava(Throwable throwable) { return rootCause; } + protected abstract RC createRequestContext(T exchange, CallToolRequest request); + } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java index d05d016..0ad552f 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AbstractSyncMcpToolMethodCallback.java @@ -20,6 +20,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.context.McpRequestContextTypes; /** * Abstract base class for creating Function callbacks around synchronous tool methods. @@ -33,7 +34,8 @@ * McpSyncServerExchange) * @author Christian Tzolov */ -public abstract class AbstractSyncMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback { +public abstract class AbstractSyncMcpToolMethodCallback> + extends AbstractAsyncMcpToolMethodCallback { protected AbstractSyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject, Class toolCallExceptionClass) { diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java index ab51cab..d8cdfd8 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallback.java @@ -19,11 +19,12 @@ import java.lang.reflect.Method; import java.util.function.BiFunction; -import org.springaicommunity.mcp.annotation.McpTool; - import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.DefaultMcpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; import reactor.core.publisher.Mono; /** @@ -34,7 +35,8 @@ * * @author Christian Tzolov */ -public final class AsyncMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback +public final class AsyncMcpToolMethodCallback + extends AbstractAsyncMcpToolMethodCallback implements BiFunction> { public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Object toolObject) { @@ -48,7 +50,14 @@ public AsyncMcpToolMethodCallback(ReturnMode returnMode, Method toolMethod, Obje @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpAsyncServerExchange.class.isAssignableFrom(paramType); + return McpAsyncServerExchange.class.isAssignableFrom(paramType) + || McpAsyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpAsyncRequestContext createRequestContext(McpAsyncServerExchange exchange, CallToolRequest request) { + + return DefaultMcpAsyncRequestContext.builder().request(request).exchange(exchange).build(); } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java index 300f9cf..06957d9 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/AsyncStatelessMcpToolMethodCallback.java @@ -18,10 +18,12 @@ import java.util.function.BiFunction; -import org.springaicommunity.mcp.annotation.McpTool; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.DefaultMcpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; import reactor.core.publisher.Mono; /** @@ -33,7 +35,8 @@ * * @author Christian Tzolov */ -public final class AsyncStatelessMcpToolMethodCallback extends AbstractAsyncMcpToolMethodCallback +public final class AsyncStatelessMcpToolMethodCallback + extends AbstractAsyncMcpToolMethodCallback implements BiFunction> { public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, @@ -48,7 +51,18 @@ public AsyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.refl @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpTransportContext.class.isAssignableFrom(paramType); + return McpTransportContext.class.isAssignableFrom(paramType) + || McpAsyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpAsyncRequestContext createRequestContext(McpTransportContext exchange, CallToolRequest request) { + + return DefaultMcpAsyncRequestContext.builder() + .request(request) + .transportContext(exchange) + .stateless(true) + .build(); } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java index eaba1b2..fcb14e3 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallback.java @@ -19,7 +19,8 @@ import java.util.function.BiFunction; import org.springaicommunity.mcp.annotation.McpTool; - +import org.springaicommunity.mcp.context.DefaultMcpSyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -32,7 +33,8 @@ * * @author Christian Tzolov */ -public final class SyncMcpToolMethodCallback extends AbstractSyncMcpToolMethodCallback +public final class SyncMcpToolMethodCallback + extends AbstractSyncMcpToolMethodCallback implements BiFunction { public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, Object toolObject) { @@ -46,7 +48,14 @@ public SyncMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpSyncServerExchange.class.isAssignableFrom(paramType); + return McpSyncServerExchange.class.isAssignableFrom(paramType) + || McpSyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpSyncRequestContext createRequestContext(McpSyncServerExchange exchange, CallToolRequest request) { + + return DefaultMcpSyncRequestContext.builder().request(request).exchange(exchange).build(); } /** diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java index a17cc9f..88029e4 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/SyncStatelessMcpToolMethodCallback.java @@ -18,10 +18,12 @@ import java.util.function.BiFunction; -import org.springaicommunity.mcp.annotation.McpTool; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.context.DefaultMcpSyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; /** * Class for creating Function callbacks around tool methods. @@ -32,7 +34,8 @@ * @author James Ward * @author Christian Tzolov */ -public final class SyncStatelessMcpToolMethodCallback extends AbstractSyncMcpToolMethodCallback +public final class SyncStatelessMcpToolMethodCallback + extends AbstractSyncMcpToolMethodCallback implements BiFunction { public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.reflect.Method toolMethod, @@ -47,7 +50,18 @@ public SyncStatelessMcpToolMethodCallback(ReturnMode returnMode, java.lang.refle @Override protected boolean isExchangeOrContextType(Class paramType) { - return McpTransportContext.class.isAssignableFrom(paramType); + return McpTransportContext.class.isAssignableFrom(paramType) + || McpSyncRequestContext.class.isAssignableFrom(paramType); + } + + @Override + protected McpSyncRequestContext createRequestContext(McpTransportContext exchange, CallToolRequest request) { + + return DefaultMcpSyncRequestContext.builder() + .request(request) + .transportContext(exchange) + .stateless(true) + .build(); } @Override diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java index 814e2da..a480d67 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/method/tool/utils/JsonSchemaGenerator.java @@ -27,7 +27,8 @@ import org.springaicommunity.mcp.annotation.McpMeta; import org.springaicommunity.mcp.annotation.McpProgressToken; import org.springaicommunity.mcp.annotation.McpToolParam; - +import org.springaicommunity.mcp.context.McpAsyncRequestContext; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.fasterxml.jackson.databind.JsonNode; @@ -110,7 +111,9 @@ private static String internalGenerateFromMethodArguments(Method method) { // @McpProgressToken annotated parameters, and McpMeta parameters boolean hasOtherParams = Arrays.stream(method.getParameters()).anyMatch(param -> { Class type = param.getType(); - return !CallToolRequest.class.isAssignableFrom(type) + return !McpSyncRequestContext.class.isAssignableFrom(type) + && !McpAsyncRequestContext.class.isAssignableFrom(type) + && !CallToolRequest.class.isAssignableFrom(type) && !McpSyncServerExchange.class.isAssignableFrom(type) && !McpAsyncServerExchange.class.isAssignableFrom(type) && !param.isAnnotationPresent(McpProgressToken.class) && !McpMeta.class.isAssignableFrom(type); @@ -150,7 +153,9 @@ private static String internalGenerateFromMethodArguments(Method method) { // Skip special parameter types if (parameterType instanceof Class parameterClass - && (ClassUtils.isAssignable(McpSyncServerExchange.class, parameterClass) + && (ClassUtils.isAssignable(McpSyncRequestContext.class, parameterClass) + || ClassUtils.isAssignable(McpAsyncRequestContext.class, parameterClass) + || ClassUtils.isAssignable(McpSyncServerExchange.class, parameterClass) || ClassUtils.isAssignable(McpAsyncServerExchange.class, parameterClass) || ClassUtils.isAssignable(CallToolRequest.class, parameterClass))) { continue; diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java new file mode 100644 index 0000000..f7c703d --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java @@ -0,0 +1,134 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultElicitationSpec}. + * + * @author Christian Tzolov + */ +public class DefaultElicitationSpecTests { + + @Test + public void testMessageSetting() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + spec.message("Test message"); + + assertThat(spec.message).isEqualTo("Test message"); + } + + @Test + public void testMessageWithEmptyString() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + assertThatThrownBy(() -> spec.message("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Message must not be empty"); + } + + @Test + public void testMessageWithNull() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + assertThatThrownBy(() -> spec.message(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Message must not be empty"); + } + + @Test + public void testResponseTypeSetting() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + spec.responseType(String.class); + + assertThat(spec.responseType).isEqualTo(String.class); + } + + @Test + public void testResponseTypeWithNull() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + assertThatThrownBy(() -> spec.responseType(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Response type must not be null"); + } + + @Test + public void testMetaWithMap() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + Map metaMap = Map.of("key1", "value1", "key2", "value2"); + + spec.meta(metaMap); + + assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetaWithNullMap() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + assertThatThrownBy(() -> spec.meta((Map) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Meta map must not be null"); + } + + @Test + public void testMetaWithKeyValue() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + spec.meta("key", "value"); + + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testMetaWithNullKey() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + spec.meta(null, "value"); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithNullValue() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + spec.meta("key", null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaMultipleEntries() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + spec.meta("key1", "value1").meta("key2", "value2").meta("key3", "value3"); + + assertThat(spec.meta).hasSize(3) + .containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("key3", "value3"); + } + + @Test + public void testFluentInterface() { + DefaultElicitationSpec spec = new DefaultElicitationSpec(); + + McpSyncRequestContext.ElicitationSpec result = spec.message("Test message") + .responseType(String.class) + .meta("key", "value"); + + assertThat(result).isSameAs(spec); + assertThat(spec.message).isEqualTo("Test message"); + assertThat(spec.responseType).isEqualTo(String.class); + assertThat(spec.meta).containsEntry("key", "value"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultLoggingSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultLoggingSpecTests.java new file mode 100644 index 0000000..8b13f08 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultLoggingSpecTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultLoggingSpec}. + * + * @author Christian Tzolov + */ +public class DefaultLoggingSpecTests { + + @Test + public void testMessageSetting() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.message("Test log message"); + + assertThat(spec.message).isEqualTo("Test log message"); + } + + @Test + public void testLoggerSetting() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.logger("test-logger"); + + assertThat(spec.logger).isEqualTo("test-logger"); + } + + @Test + public void testLevelSetting() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.level(LoggingLevel.ERROR); + + assertThat(spec.level).isEqualTo(LoggingLevel.ERROR); + } + + @Test + public void testDefaultLevel() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + assertThat(spec.level).isEqualTo(LoggingLevel.INFO); + } + + @Test + public void testMetaWithMap() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + Map metaMap = Map.of("key1", "value1", "key2", "value2"); + + spec.meta(metaMap); + + assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetaWithNullMap() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta((Map) null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithKeyValue() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta("key", "value"); + + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testMetaWithNullKey() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta(null, "value"); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithNullValue() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta("key", null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaMultipleEntries() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.meta("key1", "value1").meta("key2", "value2").meta("key3", "value3"); + + assertThat(spec.meta).hasSize(3) + .containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("key3", "value3"); + } + + @Test + public void testFluentInterface() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + McpRequestContextTypes.LoggingSpec result = spec.message("Test message") + .logger("test-logger") + .level(LoggingLevel.DEBUG) + .meta("key", "value"); + + assertThat(result).isSameAs(spec); + assertThat(spec.message).isEqualTo("Test message"); + assertThat(spec.logger).isEqualTo("test-logger"); + assertThat(spec.level).isEqualTo(LoggingLevel.DEBUG); + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testAllLoggingLevels() { + DefaultLoggingSpec spec = new DefaultLoggingSpec(); + + spec.level(LoggingLevel.DEBUG); + assertThat(spec.level).isEqualTo(LoggingLevel.DEBUG); + + spec.level(LoggingLevel.INFO); + assertThat(spec.level).isEqualTo(LoggingLevel.INFO); + + spec.level(LoggingLevel.WARNING); + assertThat(spec.level).isEqualTo(LoggingLevel.WARNING); + + spec.level(LoggingLevel.ERROR); + assertThat(spec.level).isEqualTo(LoggingLevel.ERROR); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java new file mode 100644 index 0000000..f687e97 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java @@ -0,0 +1,564 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DefaultMcpAsyncRequestContext}. + * + * @author Christian Tzolov + */ +public class DefaultMcpAsyncRequestContextTests { + + private CallToolRequest request; + + private McpAsyncServerExchange exchange; + + private McpAsyncRequestContext context; + + @BeforeEach + public void setUp() { + request = new CallToolRequest("test-tool", Map.of()); + exchange = mock(McpAsyncServerExchange.class); + context = DefaultMcpAsyncRequestContext.builder().request(request).exchange(exchange).build(); + } + + // Builder Tests + + @Test + public void testBuilderWithValidParameters() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + McpAsyncRequestContext ctx = DefaultMcpAsyncRequestContext.builder() + .request(testRequest) + .exchange(exchange) + .build(); + + assertThat(ctx).isNotNull(); + assertThat(ctx.request()).isEqualTo(testRequest); + assertThat(ctx.exchange()).isEqualTo(exchange); + } + + @Test + public void testBuilderWithNullRequest() { + StepVerifier + .create(Mono + .fromCallable(() -> DefaultMcpAsyncRequestContext.builder().request(null).exchange(exchange).build())) + .expectErrorMatches(throwable -> throwable instanceof IllegalArgumentException + && throwable.getMessage().contains("Request must not be null")) + .verify(); + } + + @Test + public void testBuilderWithNullExchange() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + StepVerifier + .create(Mono.fromCallable( + () -> DefaultMcpAsyncRequestContext.builder().request(testRequest).exchange(null).build())) + .expectErrorMatches(throwable -> throwable instanceof IllegalArgumentException + && throwable.getMessage().contains("Exchange must not be null")) + .verify(); + } + + // Roots Tests + + @Test + public void testRootsWhenSupported() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + McpSchema.ClientCapabilities.RootCapabilities roots = mock(McpSchema.ClientCapabilities.RootCapabilities.class); + when(capabilities.roots()).thenReturn(roots); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ListRootsResult expectedResult = mock(ListRootsResult.class); + when(exchange.listRoots()).thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(context.roots()).expectNext(expectedResult).verifyComplete(); + + verify(exchange).listRoots(); + } + + @Test + public void testRootsWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + StepVerifier.create(context.roots()).verifyComplete(); + } + + @Test + public void testRootsWhenCapabilitiesNullRoots() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(capabilities.roots()).thenReturn(null); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + StepVerifier.create(context.roots()).verifyComplete(); + } + + // Elicitation Tests + + @Test + public void testElicitationWithConsumer() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono result = context.elicitation(spec -> { + spec.message("Test message"); + spec.responseType(String.class); + }); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.message()).isEqualTo("Test message"); + assertThat(capturedRequest.requestedSchema()).isNotNull(); + } + + @Test + public void testElicitationWithConsumerAndMeta() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono result = context.elicitation(spec -> { + spec.message("Test message"); + spec.responseType(String.class); + spec.meta("key", "value"); + }); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.meta()).containsEntry("key", "value"); + } + + @Test + public void testElicitationWithNullConsumer() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicitation((java.util.function.Consumer) null); + })).hasMessageContaining("Elicitation spec consumer must not be null"); + } + + @Test + public void testElicitationWithEmptyMessage() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicitation(spec -> { + spec.message(""); + spec.responseType(String.class); + }); + })).hasMessageContaining("Message must not be empty"); + } + + @Test + public void testElicitationWithNullResponseType() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicitation(spec -> { + spec.message("Test message"); + spec.responseType(null); + }); + })).hasMessageContaining("Response type must not be null"); + } + + @Test + public void testElicitationWithMessageAndType() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono result = context.elicitation("Test message", String.class); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + } + + @Test + public void testElicitationWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + when(exchange.createElicitation(elicitRequest)).thenReturn(Mono.just(expectedResult)); + + Mono result = context.elicitation(elicitRequest); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + } + + @Test + public void testElicitationWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + Mono result = context.elicitation(elicitRequest); + + StepVerifier.create(result).verifyComplete(); + } + + // Sampling Tests + + @Test + public void testSamplingWithMessages() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono result = context.sampling("Message 1", "Message 2"); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + } + + @Test + public void testSamplingWithConsumer() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono result = context.sampling(spec -> { + spec.message(new TextContent("Test message")); + spec.systemPrompt("System prompt"); + spec.temperature(0.7); + spec.maxTokens(100); + }); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateMessageRequest.class); + verify(exchange).createMessage(captor.capture()); + + CreateMessageRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.systemPrompt()).isEqualTo("System prompt"); + assertThat(capturedRequest.temperature()).isEqualTo(0.7); + assertThat(capturedRequest.maxTokens()).isEqualTo(100); + } + + @Test + public void testSamplingWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + when(exchange.createMessage(createRequest)).thenReturn(Mono.just(expectedResult)); + + Mono result = context.sampling(createRequest); + + StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + } + + @Test + public void testSamplingWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + Mono result = context.sampling(createRequest); + + StepVerifier.create(result).verifyComplete(); + } + + // Progress Tests + + @Test + public void testProgressWithPercentage() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpAsyncRequestContext contextWithToken = DefaultMcpAsyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + when(exchange.progressNotification(any(ProgressNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(contextWithToken.progress(50)).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.5); + assertThat(notification.total()).isEqualTo(1.0); + } + + @Test + public void testProgressWithInvalidPercentage() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.progress(-1); + })).hasMessageContaining("Percentage must be between 0 and 100"); + + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.progress(101); + })).hasMessageContaining("Percentage must be between 0 and 100"); + } + + @Test + public void testProgressWithConsumer() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpAsyncRequestContext contextWithToken = DefaultMcpAsyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + when(exchange.progressNotification(any(ProgressNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(contextWithToken.progress(spec -> { + spec.progress(0.75); + spec.total(1.0); + spec.message("Processing..."); + })).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.75); + assertThat(notification.total()).isEqualTo(1.0); + assertThat(notification.message()).isEqualTo("Processing..."); + } + + @Test + public void testProgressWithNotification() { + ProgressNotification notification = new ProgressNotification("token-123", 0.5, 1.0, "Test", null); + when(exchange.progressNotification(notification)).thenReturn(Mono.empty()); + + StepVerifier.create(context.progress(notification)).verifyComplete(); + + verify(exchange).progressNotification(notification); + } + + @Test + public void testProgressWithoutToken() { + // request already has no progress token (null by default) + // Should not throw, just log warning and return empty + StepVerifier.create(context.progress(50)).verifyComplete(); + } + + // Ping Tests + + @Test + public void testPing() { + when(exchange.ping()).thenReturn(Mono.just(new Object())); + + StepVerifier.create(context.ping()).expectNextCount(1).verifyComplete(); + + verify(exchange).ping(); + } + + // Logging Tests + + @Test + public void testLogWithConsumer() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.log(spec -> { + spec.message("Test log message"); + spec.level(LoggingLevel.INFO); + spec.logger("test-logger"); + })).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Test log message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + assertThat(notification.logger()).isEqualTo("test-logger"); + } + + @Test + public void testDebug() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.debug("Debug message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Debug message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.DEBUG); + } + + @Test + public void testInfo() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.info("Info message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Info message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + } + + @Test + public void testWarn() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.warn("Warning message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Warning message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.WARNING); + } + + @Test + public void testError() { + when(exchange.loggingNotification(any(LoggingMessageNotification.class))).thenReturn(Mono.empty()); + + StepVerifier.create(context.error("Error message")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Error message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.ERROR); + } + + @Test + public void testLogWithEmptyMessage() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.debug(""); + })).hasMessageContaining("Log message must not be empty"); + } + + // Getter Tests + + @Test + public void testGetRequest() { + assertThat(context.request()).isEqualTo(request); + } + + @Test + public void testGetExchange() { + assertThat(context.exchange()).isEqualTo(exchange); + } + + @Test + public void testGetSessionId() { + when(exchange.sessionId()).thenReturn("session-123"); + + assertThat(context.sessionId()).isEqualTo("session-123"); + } + + @Test + public void testGetClientInfo() { + Implementation clientInfo = mock(Implementation.class); + when(exchange.getClientInfo()).thenReturn(clientInfo); + + assertThat(context.clientInfo()).isEqualTo(clientInfo); + } + + @Test + public void testGetClientCapabilities() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.clientCapabilities()).isEqualTo(capabilities); + } + + @Test + public void testGetRequestMeta() { + Map meta = Map.of("key", "value"); + CallToolRequest requestWithMeta = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .meta(meta) + .build(); + McpAsyncRequestContext contextWithMeta = DefaultMcpAsyncRequestContext.builder() + .request(requestWithMeta) + .exchange(exchange) + .build(); + + assertThat(contextWithMeta.requestMeta()).isEqualTo(meta); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java new file mode 100644 index 0000000..ebb4823 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java @@ -0,0 +1,547 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; +import java.util.Optional; + +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.Implementation; +import io.modelcontextprotocol.spec.McpSchema.ListRootsResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DefaultMcpSyncRequestContext}. + * + * @author Christian Tzolov + */ +public class DefaultMcpSyncRequestContextTests { + + private CallToolRequest request; + + private McpSyncServerExchange exchange; + + private McpSyncRequestContext context; + + @BeforeEach + public void setUp() { + request = new CallToolRequest("test-tool", Map.of()); + exchange = mock(McpSyncServerExchange.class); + context = DefaultMcpSyncRequestContext.builder().request(request).exchange(exchange).build(); + } + + // Builder Tests + + @Test + public void testBuilderWithValidParameters() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + McpSyncRequestContext ctx = DefaultMcpSyncRequestContext.builder() + .request(testRequest) + .exchange(exchange) + .build(); + + assertThat(ctx).isNotNull(); + assertThat(ctx.request()).isEqualTo(testRequest); + assertThat(ctx.exchange()).isEqualTo(exchange); + } + + @Test + public void testBuilderWithNullRequest() { + assertThatThrownBy(() -> DefaultMcpSyncRequestContext.builder().request(null).exchange(exchange).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Request must not be null"); + } + + @Test + public void testBuilderWithNullExchange() { + CallToolRequest testRequest = new CallToolRequest("test-tool", Map.of()); + assertThatThrownBy(() -> DefaultMcpSyncRequestContext.builder().request(testRequest).exchange(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Exchange must not be null"); + } + + // Roots Tests + + @Test + public void testRootsWhenSupported() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + McpSchema.ClientCapabilities.RootCapabilities roots = mock(McpSchema.ClientCapabilities.RootCapabilities.class); + when(capabilities.roots()).thenReturn(roots); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ListRootsResult expectedResult = mock(ListRootsResult.class); + when(exchange.listRoots()).thenReturn(expectedResult); + + Optional result = context.roots(); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + verify(exchange).listRoots(); + } + + @Test + public void testRootsWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + Optional result = context.roots(); + + assertThat(result).isEmpty(); + } + + @Test + public void testRootsWhenCapabilitiesNullRoots() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(capabilities.roots()).thenReturn(null); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Optional result = context.roots(); + + assertThat(result).isEmpty(); + } + + // Elicitation Tests + + @Test + public void testElicitationWithConsumer() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional result = context.elicitation(spec -> { + spec.message("Test message"); + spec.responseType(String.class); + }); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.message()).isEqualTo("Test message"); + assertThat(capturedRequest.requestedSchema()).isNotNull(); + } + + @Test + public void testElicitationWithConsumerAndMeta() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional result = context.elicitation(spec -> { + spec.message("Test message"); + spec.responseType(String.class); + spec.meta("key", "value"); + }); + + assertThat(result).isPresent(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); + verify(exchange).createElicitation(captor.capture()); + + ElicitRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.meta()).containsEntry("key", "value"); + } + + @Test + public void testElicitationWithNullConsumer() { + assertThatThrownBy( + () -> context.elicitation((java.util.function.Consumer) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Elicitation spec consumer must not be null"); + } + + @Test + public void testElicitationWithEmptyMessage() { + assertThatThrownBy(() -> context.elicitation(spec -> { + spec.message(""); + spec.responseType(String.class); + })).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("Message must not be empty"); + } + + @Test + public void testElicitationWithNullResponseType() { + assertThatThrownBy(() -> context.elicitation(spec -> { + spec.message("Test message"); + spec.responseType(null); + })).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("Response type must not be null"); + } + + @Test + public void testElicitationWithMessageAndType() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional result = context.elicitation("Test message", String.class); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + } + + @Test + public void testElicitationWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + when(exchange.createElicitation(elicitRequest)).thenReturn(expectedResult); + + Optional result = context.elicitation(elicitRequest); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + } + + @Test + public void testElicitationWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + ElicitRequest elicitRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema(Map.of("type", "string")) + .build(); + + Optional result = context.elicitation(elicitRequest); + + assertThat(result).isEmpty(); + } + + // Sampling Tests + + @Test + public void testSamplingWithMessages() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); + + Optional result = context.sampling("Message 1", "Message 2"); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + } + + @Test + public void testSamplingWithConsumer() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); + + Optional result = context.sampling(spec -> { + spec.message(new TextContent("Test message")); + spec.systemPrompt("System prompt"); + spec.temperature(0.7); + spec.maxTokens(100); + }); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateMessageRequest.class); + verify(exchange).createMessage(captor.capture()); + + CreateMessageRequest capturedRequest = captor.getValue(); + assertThat(capturedRequest.systemPrompt()).isEqualTo("System prompt"); + assertThat(capturedRequest.temperature()).isEqualTo(0.7); + assertThat(capturedRequest.maxTokens()).isEqualTo(100); + } + + @Test + public void testSamplingWithRequest() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Sampling sampling = mock(ClientCapabilities.Sampling.class); + when(capabilities.sampling()).thenReturn(sampling); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + CreateMessageResult expectedResult = mock(CreateMessageResult.class); + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + when(exchange.createMessage(createRequest)).thenReturn(expectedResult); + + Optional result = context.sampling(createRequest); + + assertThat(result).isPresent(); + assertThat(result.get()).isEqualTo(expectedResult); + } + + @Test + public void testSamplingWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + CreateMessageRequest createRequest = CreateMessageRequest.builder() + .messages(java.util.List.of(new SamplingMessage(Role.USER, new TextContent("Test")))) + .maxTokens(500) + .build(); + + Optional result = context.sampling(createRequest); + + assertThat(result).isEmpty(); + } + + // Progress Tests + + @Test + public void testProgressWithPercentage() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpSyncRequestContext contextWithToken = DefaultMcpSyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + contextWithToken.progress(50); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.5); + assertThat(notification.total()).isEqualTo(1.0); + } + + @Test + public void testProgressWithInvalidPercentage() { + assertThatThrownBy(() -> context.progress(-1)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Percentage must be between 0 and 100"); + + assertThatThrownBy(() -> context.progress(101)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Percentage must be between 0 and 100"); + } + + @Test + public void testProgressWithConsumer() { + CallToolRequest requestWithToken = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .progressToken("token-123") + .build(); + McpSyncRequestContext contextWithToken = DefaultMcpSyncRequestContext.builder() + .request(requestWithToken) + .exchange(exchange) + .build(); + + contextWithToken.progress(spec -> { + spec.progress(0.75); + spec.total(1.0); + spec.message("Processing..."); + }); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ProgressNotification.class); + verify(exchange).progressNotification(captor.capture()); + + ProgressNotification notification = captor.getValue(); + assertThat(notification.progressToken()).isEqualTo("token-123"); + assertThat(notification.progress()).isEqualTo(0.75); + assertThat(notification.total()).isEqualTo(1.0); + assertThat(notification.message()).isEqualTo("Processing..."); + } + + @Test + public void testProgressWithNotification() { + ProgressNotification notification = new ProgressNotification("token-123", 0.5, 1.0, "Test", null); + + context.progress(notification); + + verify(exchange).progressNotification(notification); + } + + @Test + public void testProgressWithoutToken() { + // request already has no progress token (null by default) + // Should not throw, just log warning + context.progress(50); + } + + // Ping Tests + + @Test + public void testPing() { + context.ping(); + + verify(exchange).ping(); + } + + // Logging Tests + + @Test + public void testLogWithConsumer() { + context.log(spec -> { + spec.message("Test log message"); + spec.level(LoggingLevel.INFO); + spec.logger("test-logger"); + }); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Test log message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + assertThat(notification.logger()).isEqualTo("test-logger"); + } + + @Test + public void testDebug() { + context.debug("Debug message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Debug message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.DEBUG); + } + + @Test + public void testInfo() { + context.info("Info message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Info message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.INFO); + } + + @Test + public void testWarn() { + context.warn("Warning message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Warning message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.WARNING); + } + + @Test + public void testError() { + context.error("Error message"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(LoggingMessageNotification.class); + verify(exchange).loggingNotification(captor.capture()); + + LoggingMessageNotification notification = captor.getValue(); + assertThat(notification.data()).isEqualTo("Error message"); + assertThat(notification.level()).isEqualTo(LoggingLevel.ERROR); + } + + @Test + public void testLogWithEmptyMessage() { + assertThatThrownBy(() -> context.debug("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Log message must not be empty"); + } + + // Getter Tests + + @Test + public void testGetRequest() { + assertThat(context.request()).isEqualTo(request); + } + + @Test + public void testGetExchange() { + assertThat(context.exchange()).isEqualTo(exchange); + } + + @Test + public void testGetSessionId() { + when(exchange.sessionId()).thenReturn("session-123"); + + assertThat(context.sessionId()).isEqualTo("session-123"); + } + + @Test + public void testGetClientInfo() { + Implementation clientInfo = mock(Implementation.class); + when(exchange.getClientInfo()).thenReturn(clientInfo); + + assertThat(context.clientInfo()).isEqualTo(clientInfo); + } + + @Test + public void testGetClientCapabilities() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + assertThat(context.clientCapabilities()).isEqualTo(capabilities); + } + + @Test + public void testGetRequestMeta() { + Map meta = Map.of("key", "value"); + CallToolRequest requestWithMeta = CallToolRequest.builder() + .name("test-tool") + .arguments(Map.of()) + .meta(meta) + .build(); + McpSyncRequestContext contextWithMeta = DefaultMcpSyncRequestContext.builder() + .request(requestWithMeta) + .exchange(exchange) + .build(); + + assertThat(contextWithMeta.requestMeta()).isEqualTo(meta); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultProgressSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultProgressSpecTests.java new file mode 100644 index 0000000..113a076 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultProgressSpecTests.java @@ -0,0 +1,167 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultProgressSpec}. + * + * @author Christian Tzolov + */ +public class DefaultProgressSpecTests { + + @Test + public void testDefaultValues() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + assertThat(spec.progress).isEqualTo(0.0); + assertThat(spec.total).isEqualTo(1.0); + assertThat(spec.message).isNull(); + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testProgressSetting() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.progress(0.5); + + assertThat(spec.progress).isEqualTo(0.5); + } + + @Test + public void testTotalSetting() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.total(100.0); + + assertThat(spec.total).isEqualTo(100.0); + } + + @Test + public void testMessageSetting() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.message("Processing..."); + + assertThat(spec.message).isEqualTo("Processing..."); + } + + @Test + public void testMetaWithMap() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + Map metaMap = new HashMap<>(); + metaMap.put("key1", "value1"); + metaMap.put("key2", "value2"); + + spec.meta(metaMap); + + assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetaWithNullMap() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.meta((Map) null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithKeyValue() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta("key", "value"); + + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testMetaWithNullKey() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta(null, "value"); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaWithNullValue() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta("key", null); + + assertThat(spec.meta).isEmpty(); + } + + @Test + public void testMetaMultipleEntries() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + spec.meta("key1", "value1").meta("key2", "value2").meta("key3", "value3"); + + assertThat(spec.meta).hasSize(3) + .containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("key3", "value3"); + } + + @Test + public void testFluentInterface() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + spec.meta = new HashMap<>(); + + McpRequestContextTypes.ProgressSpec result = spec.progress(0.75) + .total(1.0) + .message("Processing...") + .meta("key", "value"); + + assertThat(result).isSameAs(spec); + assertThat(spec.progress).isEqualTo(0.75); + assertThat(spec.total).isEqualTo(1.0); + assertThat(spec.message).isEqualTo("Processing..."); + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testProgressBoundaries() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.progress(0.0); + assertThat(spec.progress).isEqualTo(0.0); + + spec.progress(1.0); + assertThat(spec.progress).isEqualTo(1.0); + + spec.progress(0.5); + assertThat(spec.progress).isEqualTo(0.5); + } + + @Test + public void testTotalValues() { + DefaultProgressSpec spec = new DefaultProgressSpec(); + + spec.total(50.0); + assertThat(spec.total).isEqualTo(50.0); + + spec.total(100.0); + assertThat(spec.total).isEqualTo(100.0); + + spec.total(1.0); + assertThat(spec.total).isEqualTo(1.0); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultSamplingSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultSamplingSpecTests.java new file mode 100644 index 0000000..b5adde3 --- /dev/null +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultSamplingSpecTests.java @@ -0,0 +1,215 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultSamplingSpec}. + * + * @author Christian Tzolov + */ +public class DefaultSamplingSpecTests { + + @Test + public void testDefaultValues() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + assertThat(spec.messages).isEmpty(); + assertThat(spec.modelPreferences).isNull(); + assertThat(spec.systemPrompt).isNull(); + assertThat(spec.temperature).isNull(); + assertThat(spec.maxTokens).isNull(); + assertThat(spec.stopSequences).isEmpty(); + assertThat(spec.metadata).isEmpty(); + assertThat(spec.meta).isEmpty(); + assertThat(spec.includeContextStrategy).isEqualTo(ContextInclusionStrategy.NONE); + } + + @Test + public void testMessageWithTextContent() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + TextContent content = new TextContent("Test message"); + + spec.message(content); + + assertThat(spec.messages).hasSize(1); + assertThat(spec.messages.get(0).role()).isEqualTo(Role.USER); + assertThat(spec.messages.get(0).content()).isEqualTo(content); + } + + @Test + public void testMessageWithMultipleTextContent() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + TextContent content1 = new TextContent("Message 1"); + TextContent content2 = new TextContent("Message 2"); + + spec.message(content1, content2); + + assertThat(spec.messages).hasSize(2); + } + + @Test + public void testMessageWithSamplingMessage() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + SamplingMessage message = new SamplingMessage(Role.ASSISTANT, new TextContent("Assistant message")); + + spec.message(message); + + assertThat(spec.messages).hasSize(1); + assertThat(spec.messages.get(0)).isEqualTo(message); + } + + @Test + public void testSystemPrompt() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.systemPrompt("System instructions"); + + assertThat(spec.systemPrompt).isEqualTo("System instructions"); + } + + @Test + public void testTemperature() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.temperature(0.7); + + assertThat(spec.temperature).isEqualTo(0.7); + } + + @Test + public void testMaxTokens() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.maxTokens(1000); + + assertThat(spec.maxTokens).isEqualTo(1000); + } + + @Test + public void testStopSequences() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.stopSequences("STOP", "END"); + + assertThat(spec.stopSequences).containsExactly("STOP", "END"); + } + + @Test + public void testIncludeContextStrategy() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.includeContextStrategy(ContextInclusionStrategy.ALL_SERVERS); + + assertThat(spec.includeContextStrategy).isEqualTo(ContextInclusionStrategy.ALL_SERVERS); + } + + @Test + public void testMetadataWithMap() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + Map metadataMap = Map.of("key1", "value1", "key2", "value2"); + + spec.metadata(metadataMap); + + assertThat(spec.metadata).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetadataWithKeyValue() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.metadata("key", "value"); + + assertThat(spec.metadata).containsEntry("key", "value"); + } + + @Test + public void testMetaWithMap() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + Map metaMap = Map.of("key1", "value1", "key2", "value2"); + + spec.meta(metaMap); + + assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + public void testMetaWithKeyValue() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.meta("key", "value"); + + assertThat(spec.meta).containsEntry("key", "value"); + } + + @Test + public void testModelPreferences() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + spec.modelPreferences(prefs -> { + prefs.modelHint("gpt-4"); + prefs.costPriority(0.5); + prefs.speedPriority(0.8); + prefs.intelligencePriority(0.9); + }); + + assertThat(spec.modelPreferences).isNotNull(); + assertThat(spec.modelPreferences.hints()).hasSize(1); + assertThat(spec.modelPreferences.costPriority()).isEqualTo(0.5); + assertThat(spec.modelPreferences.speedPriority()).isEqualTo(0.8); + assertThat(spec.modelPreferences.intelligencePriority()).isEqualTo(0.9); + } + + @Test + public void testFluentInterface() { + DefaultSamplingSpec spec = new DefaultSamplingSpec(); + + McpRequestContextTypes.SamplingSpec result = spec.message(new TextContent("Test")) + .systemPrompt("System") + .temperature(0.7) + .maxTokens(100) + .stopSequences("STOP") + .metadata("key", "value") + .meta("metaKey", "metaValue"); + + assertThat(result).isSameAs(spec); + assertThat(spec.messages).hasSize(1); + assertThat(spec.systemPrompt).isEqualTo("System"); + assertThat(spec.temperature).isEqualTo(0.7); + assertThat(spec.maxTokens).isEqualTo(100); + assertThat(spec.stopSequences).containsExactly("STOP"); + assertThat(spec.metadata).containsEntry("key", "value"); + assertThat(spec.meta).containsEntry("metaKey", "metaValue"); + } + + // ModelPreferenceSpec Tests + + @Test + public void testModelPreferenceSpecWithNullModelHint() { + DefaultSamplingSpec.DefaultModelPreferenceSpec spec = new DefaultSamplingSpec.DefaultModelPreferenceSpec(); + + assertThatThrownBy(() -> spec.modelHint(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Model hint must not be null"); + } + + @Test + public void testModelPreferenceSpecWithNullModelHints() { + DefaultSamplingSpec.DefaultModelPreferenceSpec spec = new DefaultSamplingSpec.DefaultModelPreferenceSpec(); + + assertThatThrownBy(() -> spec.modelHints((String[]) null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Models must not be null"); + } + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java index 0f5d9cf..7f0e5b4 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/AsyncMcpToolMethodCallbackTests.java @@ -18,6 +18,7 @@ import org.springaicommunity.mcp.annotation.McpMeta; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.annotation.McpToolParam; +import org.springaicommunity.mcp.context.McpAsyncRequestContext; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -73,6 +74,11 @@ public Mono monoToolWithExchange(McpAsyncServerExchange exchange, String return Mono.just("Exchange tool: " + message); } + @McpTool(name = "context-mono-tool", description = "Mono tool with context parameter") + public Mono monoToolWithContext(McpAsyncRequestContext context, String message) { + return Mono.just("Context tool: " + message); + } + @McpTool(name = "list-mono-tool", description = "Mono tool with list parameter") public Mono processListMono(List items) { return Mono.just("Items: " + String.join(", ", items)); @@ -664,12 +670,34 @@ public void testIsExchangeType() throws Exception { // Test that McpAsyncServerExchange is recognized as exchange type assertThat(callback.isExchangeOrContextType(McpAsyncServerExchange.class)).isTrue(); + // Test that McpAsyncRequestContext is recognized as context type + assertThat(callback.isExchangeOrContextType(McpAsyncRequestContext.class)).isTrue(); + // Test that other types are not recognized as exchange type assertThat(callback.isExchangeOrContextType(String.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Integer.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Object.class)).isFalse(); } + @Test + public void testMonoToolWithContextParameter() throws Exception { + TestAsyncToolProvider provider = new TestAsyncToolProvider(); + Method method = TestAsyncToolProvider.class.getMethod("monoToolWithContext", McpAsyncRequestContext.class, + String.class); + AsyncMcpToolMethodCallback callback = new AsyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpAsyncServerExchange exchange = mock(McpAsyncServerExchange.class); + CallToolRequest request = new CallToolRequest("context-mono-tool", Map.of("message", "hello")); + + StepVerifier.create(callback.apply(exchange, request)).assertNext(result -> { + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Context tool: hello"); + }).verifyComplete(); + } + @Test public void testMonoToolWithOptionalParameters() throws Exception { TestAsyncToolProvider provider = new TestAsyncToolProvider(); diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java index 1f84686..066debc 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/method/tool/SyncMcpToolMethodCallbackTests.java @@ -17,6 +17,7 @@ import org.junit.jupiter.api.Test; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.annotation.McpToolParam; +import org.springaicommunity.mcp.context.McpSyncRequestContext; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -56,6 +57,11 @@ public String toolWithExchange(McpSyncServerExchange exchange, String message) { return "Exchange tool: " + message; } + @McpTool(name = "context-tool", description = "Tool with context parameter") + public String toolWithContext(McpSyncRequestContext context, String message) { + return "Context tool: " + message; + } + @McpTool(name = "list-tool", description = "Tool with list parameter") public String processList(List items) { return "Items: " + String.join(", ", items); @@ -443,12 +449,33 @@ public void testIsExchangeType() throws Exception { // Test that McpSyncServerExchange is recognized as exchange type assertThat(callback.isExchangeOrContextType(McpSyncServerExchange.class)).isTrue(); + // Test that McpSyncRequestContext is recognized as context type + assertThat(callback.isExchangeOrContextType(McpSyncRequestContext.class)).isTrue(); + // Test that other types are not recognized as exchange type assertThat(callback.isExchangeOrContextType(String.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Integer.class)).isFalse(); assertThat(callback.isExchangeOrContextType(Object.class)).isFalse(); } + @Test + public void testToolWithContextParameter() throws Exception { + TestToolProvider provider = new TestToolProvider(); + Method method = TestToolProvider.class.getMethod("toolWithContext", McpSyncRequestContext.class, String.class); + SyncMcpToolMethodCallback callback = new SyncMcpToolMethodCallback(ReturnMode.TEXT, method, provider); + + McpSyncServerExchange exchange = mock(McpSyncServerExchange.class); + CallToolRequest request = new CallToolRequest("context-tool", Map.of("message", "hello")); + + CallToolResult result = callback.apply(exchange, request); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isFalse(); + assertThat(result.content()).hasSize(1); + assertThat(result.content().get(0)).isInstanceOf(TextContent.class); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo("Context tool: hello"); + } + @Test public void testToolWithInvalidJsonConversion() throws Exception { TestToolProvider provider = new TestToolProvider(); From 56dfc29cc98fd02781d6bb50c4e5f46c4870d590 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 5 Oct 2025 22:30:06 +0200 Subject: [PATCH 2/4] refactor: simplify elicitation API with TypeReference-based methods Replace Consumer pattern with direct TypeReference parameters for better type safety and simpler API. Add StructuredElicitResult for type-safe responses and automatic Map-to-POJO conversion using Jackson. Signed-off-by: Christian Tzolov --- .../mcp/context/DefaultElicitationSpec.java | 57 ------ .../DefaultMcpAsyncRequestContext.java | 35 ++-- .../context/DefaultMcpSyncRequestContext.java | 52 ++++-- .../mcp/context/McpAsyncRequestContext.java | 10 +- .../mcp/context/McpRequestContextTypes.java | 24 --- .../mcp/context/McpSyncRequestContext.java | 8 +- .../mcp/context/StructuredElicitResult.java | 9 + .../context/DefaultElicitationSpecTests.java | 134 ------------- .../DefaultMcpAsyncRequestContextTests.java | 176 +++++++++++++++--- .../DefaultMcpSyncRequestContextTests.java | 149 ++++++++++++--- 10 files changed, 347 insertions(+), 307 deletions(-) delete mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java delete mode 100644 mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java deleted file mode 100644 index 9d9f74a..0000000 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package org.springaicommunity.mcp.context; - -import java.lang.reflect.Type; -import java.util.HashMap; -import java.util.Map; - -import io.modelcontextprotocol.util.Assert; -import org.springaicommunity.mcp.context.McpRequestContextTypes.ElicitationSpec; - -/** - * @author Christian Tzolov - */ -public class DefaultElicitationSpec implements ElicitationSpec { - - protected String message; - - protected Type responseType; - - protected Map meta = new HashMap<>(); - - @Override - public McpSyncRequestContext.ElicitationSpec message(String message) { - Assert.hasText(message, "Message must not be empty"); - this.message = message; - return this; - } - - @Override - public McpSyncRequestContext.ElicitationSpec responseType(Type type) { - Assert.notNull(type, "Response type must not be null"); - this.responseType = type; - return this; - } - - @Override - public McpSyncRequestContext.ElicitationSpec meta(Map m) { - Assert.notNull(m, "Meta map must not be null"); - this.meta.putAll(m); - return this; - } - - @Override - public McpSyncRequestContext.ElicitationSpec meta(String k, Object v) { - if (k != null && v != null) { - if (this.meta == null) { - this.meta = new java.util.HashMap<>(); - } - this.meta.put(k, v); - } - return this; - } - -} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java index 3dfc618..eb0a568 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java @@ -9,6 +9,8 @@ import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.spec.McpSchema; @@ -71,19 +73,27 @@ public Mono roots() { // Elicitation @Override - public Mono elicitation(Consumer elicitationSpec) { - Assert.notNull(elicitationSpec, "Elicitation spec consumer must not be null"); - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - elicitationSpec.accept(spec); - Assert.hasText(spec.message, "Elicitation message must not be empty"); - Assert.notNull(spec.responseType, "Elicitation response type must not be null"); - - return this.elicitationInternal(spec.message, spec.responseType, spec.meta.isEmpty() ? null : spec.meta); + public Mono> elicitation(TypeReference type, String message, + Map meta) { + Assert.notNull(type, "Elicitation response type must not be null"); + Assert.hasText(message, "Elicitation message must not be empty"); + + return this.elicitationInternal(message, type.getType(), meta) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type.getType()), + er.meta())); } @Override - public Mono elicitation(String message, Type type) { - return this.elicitationInternal(message, type, null); + public Mono elicitation(TypeReference type) { + Assert.notNull(type, "Elicitation response type must not be null"); + return this.elicitationInternal("Please provide the required information.", type.getType(), null) + .map(er -> convertMapToType(er.content(), type.getType())); + } + + private static T convertMapToType(Map map, Type targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); } @Override @@ -346,13 +356,14 @@ public Mono roots() { } @Override - public Mono elicitation(Consumer elicitationSpec) { + public Mono> elicitation(TypeReference type, String message, + Map meta) { logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); return Mono.empty(); } @Override - public Mono elicitation(String message, Type type) { + public Mono elicitation(TypeReference type) { logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); return Mono.empty(); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java index 2ef1804..95d11aa 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java @@ -10,6 +10,8 @@ import java.util.function.Consumer; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema; @@ -67,19 +69,38 @@ public Optional roots() { // Elicitation @Override - public Optional elicitation(Consumer elicitationSpec) { - Assert.notNull(elicitationSpec, "Elicitation spec consumer must not be null"); - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - elicitationSpec.accept(spec); - Assert.hasText(spec.message, "Elicitation message must not be empty"); - Assert.notNull(spec.responseType, "Elicitation response type must not be null"); - - return this.elicitationInternal(spec.message, spec.responseType, spec.meta.isEmpty() ? null : spec.meta); + public Optional elicitation(TypeReference type) { + Assert.notNull(type, "Elicitation response type must not be null"); + + Optional elicitResult = this.elicitationInternal("Please provide the required information.", + type.getType(), null); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(convertMapToType(elicitResult.get().content(), type)); } @Override - public Optional elicitation(String message, Type type) { - return this.elicitationInternal(message, type, null); + public Optional> elicitation(TypeReference type, String message, + Map meta) { + Assert.notNull(type, "Elicitation response type must not be null"); + + Optional elicitResult = this.elicitationInternal(message, type.getType(), meta); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), type), elicitResult.get().meta())); + } + + private static T convertMapToType(Map map, TypeReference targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); } @Override @@ -93,10 +114,12 @@ public Optional elicitation(ElicitRequest elicitRequest) { return Optional.empty(); } - return Optional.of(this.exchange.createElicitation(elicitRequest)); + ElicitResult elicitResult = this.exchange.createElicitation(elicitRequest); + + return Optional.of(elicitResult); } - public Optional elicitationInternal(String message, Type type, Map meta) { + private Optional elicitationInternal(String message, Type type, Map meta) { Assert.hasText(message, "Elicitation message must not be empty"); Assert.notNull(type, "Elicitation response type must not be null"); @@ -340,13 +363,14 @@ public Optional roots() { } @Override - public Optional elicitation(Consumer elicitationSpec) { + public Optional elicitation(TypeReference type) { logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); return Optional.empty(); } @Override - public Optional elicitation(String message, Type type) { + public Optional> elicitation(TypeReference type, String message, + Map meta) { logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); return Optional.empty(); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java index 96dccce..21d1cd2 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java @@ -4,9 +4,10 @@ package org.springaicommunity.mcp.context; -import java.lang.reflect.Type; +import java.util.Map; import java.util.function.Consumer; +import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; @@ -18,7 +19,7 @@ /** * Async (Reactor) version of McpSyncRequestContext that returns Mono of value types. - * + * * @author Christian Tzolov */ public interface McpAsyncRequestContext extends McpRequestContextTypes { @@ -31,9 +32,10 @@ public interface McpAsyncRequestContext extends McpRequestContextTypes elicitation(Consumer elicitationSpec); - Mono elicitation(String message, Type type); + Mono elicitation(TypeReference type); + + Mono> elicitation(TypeReference type, String message, Map meta); Mono elicitation(ElicitRequest elicitRequest); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java index 754baf1..f4c7257 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java @@ -4,7 +4,6 @@ package org.springaicommunity.mcp.context; -import java.lang.reflect.Type; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -27,29 +26,6 @@ */ public interface McpRequestContextTypes { - // -------------------------------------- - // Elicitation - // -------------------------------------- - - interface ElicitationSpec { - - /** - * The prompt message to display to the user - */ - ElicitationSpec message(String message); - - /** - * The response type defining the expected response structure. Note that - * elicitation responses are subject to a restricted subset of JSON Schema types. - */ - ElicitationSpec responseType(Type type); - - ElicitationSpec meta(Map m); - - ElicitationSpec meta(String k, Object v); - - } - // -------------------------------------- // Sampling // -------------------------------------- diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java index c36a538..327978d 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java @@ -4,10 +4,11 @@ package org.springaicommunity.mcp.context; -import java.lang.reflect.Type; +import java.util.Map; import java.util.Optional; import java.util.function.Consumer; +import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; @@ -29,9 +30,10 @@ public interface McpSyncRequestContext extends McpRequestContextTypes elicitation(Consumer elicitationSpec); + Optional elicitation(TypeReference type); - Optional elicitation(String message, Type type); + Optional> elicitation(TypeReference type, String message, + Map meta); Optional elicitation(ElicitRequest elicitRequest); diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java new file mode 100644 index 0000000..a71081c --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java @@ -0,0 +1,9 @@ +package org.springaicommunity.mcp.context; + +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema.ElicitResult.Action; + +public record StructuredElicitResult(Action action, T structuredContent, Map meta) { + +} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java deleted file mode 100644 index f7c703d..0000000 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultElicitationSpecTests.java +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - */ - -package org.springaicommunity.mcp.context; - -import java.util.Map; - -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Tests for {@link DefaultElicitationSpec}. - * - * @author Christian Tzolov - */ -public class DefaultElicitationSpecTests { - - @Test - public void testMessageSetting() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - spec.message("Test message"); - - assertThat(spec.message).isEqualTo("Test message"); - } - - @Test - public void testMessageWithEmptyString() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - assertThatThrownBy(() -> spec.message("")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Message must not be empty"); - } - - @Test - public void testMessageWithNull() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - assertThatThrownBy(() -> spec.message(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Message must not be empty"); - } - - @Test - public void testResponseTypeSetting() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - spec.responseType(String.class); - - assertThat(spec.responseType).isEqualTo(String.class); - } - - @Test - public void testResponseTypeWithNull() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - assertThatThrownBy(() -> spec.responseType(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Response type must not be null"); - } - - @Test - public void testMetaWithMap() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - Map metaMap = Map.of("key1", "value1", "key2", "value2"); - - spec.meta(metaMap); - - assertThat(spec.meta).containsEntry("key1", "value1").containsEntry("key2", "value2"); - } - - @Test - public void testMetaWithNullMap() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - assertThatThrownBy(() -> spec.meta((Map) null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Meta map must not be null"); - } - - @Test - public void testMetaWithKeyValue() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - spec.meta("key", "value"); - - assertThat(spec.meta).containsEntry("key", "value"); - } - - @Test - public void testMetaWithNullKey() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - spec.meta(null, "value"); - - assertThat(spec.meta).isEmpty(); - } - - @Test - public void testMetaWithNullValue() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - spec.meta("key", null); - - assertThat(spec.meta).isEmpty(); - } - - @Test - public void testMetaMultipleEntries() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - spec.meta("key1", "value1").meta("key2", "value2").meta("key3", "value3"); - - assertThat(spec.meta).hasSize(3) - .containsEntry("key1", "value1") - .containsEntry("key2", "value2") - .containsEntry("key3", "value3"); - } - - @Test - public void testFluentInterface() { - DefaultElicitationSpec spec = new DefaultElicitationSpec(); - - McpSyncRequestContext.ElicitationSpec result = spec.message("Test message") - .responseType(String.class) - .meta("key", "value"); - - assertThat(result).isSameAs(spec); - assertThat(spec.message).isEqualTo("Test message"); - assertThat(spec.responseType).isEqualTo(String.class); - assertThat(spec.meta).containsEntry("key", "value"); - } - -} diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java index f687e97..285c9b3 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java @@ -6,6 +6,7 @@ import java.util.Map; +import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -126,21 +127,29 @@ public void testRootsWhenCapabilitiesNullRoots() { // Elicitation Tests @Test - public void testElicitationWithConsumer() { + public void testElicitationWithMessageAndMeta() { ClientCapabilities capabilities = mock(ClientCapabilities.class); ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); when(capabilities.elicitation()).thenReturn(elicitation); when(exchange.getClientCapabilities()).thenReturn(capabilities); + Map contentMap = Map.of("name", "John", "age", 30); ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono result = context.elicitation(spec -> { - spec.message("Test message"); - spec.responseType(String.class); - }); + Mono>> result = context + .elicitation(new TypeReference>() { + }, "Test message", null); - StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(structuredResult.structuredContent()).isNotNull(); + assertThat(structuredResult.structuredContent()).containsEntry("name", "John"); + assertThat(structuredResult.structuredContent()).containsEntry("age", 30); + }).verifyComplete(); ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); verify(exchange).createElicitation(captor.capture()); @@ -151,22 +160,32 @@ public void testElicitationWithConsumer() { } @Test - public void testElicitationWithConsumerAndMeta() { + public void testElicitationWithMetadata() { ClientCapabilities capabilities = mock(ClientCapabilities.class); ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); when(capabilities.elicitation()).thenReturn(elicitation); when(exchange.getClientCapabilities()).thenReturn(capabilities); + record Person(String name, int age) { + } + + Map contentMap = Map.of("name", "Jane", "age", 25); ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono result = context.elicitation(spec -> { - spec.message("Test message"); - spec.responseType(String.class); - spec.meta("key", "value"); - }); + Map meta = Map.of("key", "value"); + Mono> result = context.elicitation(new TypeReference() { + }, "Test message", meta); - StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(structuredResult.structuredContent()).isNotNull(); + assertThat(structuredResult.structuredContent().name()).isEqualTo("Jane"); + assertThat(structuredResult.structuredContent().age()).isEqualTo(25); + }).verifyComplete(); ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); verify(exchange).createElicitation(captor.capture()); @@ -176,45 +195,142 @@ public void testElicitationWithConsumerAndMeta() { } @Test - public void testElicitationWithNullConsumer() { + public void testElicitationWithNullTypeReference() { assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicitation((java.util.function.Consumer) null); - })).hasMessageContaining("Elicitation spec consumer must not be null"); + context.elicitation(null, "Test message", null); + })).hasMessageContaining("Elicitation response type must not be null"); } @Test public void testElicitationWithEmptyMessage() { assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicitation(spec -> { - spec.message(""); - spec.responseType(String.class); - }); - })).hasMessageContaining("Message must not be empty"); + context.elicitation(new TypeReference() { + }, "", null); + })).hasMessageContaining("Elicitation message must not be empty"); } @Test - public void testElicitationWithNullResponseType() { + public void testElicitationWithNullMessage() { assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicitation(spec -> { - spec.message("Test message"); - spec.responseType(null); - }); - })).hasMessageContaining("Response type must not be null"); + context.elicitation(new TypeReference() { + }, null, null); + })).hasMessageContaining("Elicitation message must not be empty"); + } + + @Test + public void testElicitationReturnsEmptyWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + Mono>> result = context + .elicitation(new TypeReference>() { + }, "Test message", null); + + StepVerifier.create(result).verifyComplete(); } @Test - public void testElicitationWithMessageAndType() { + public void testElicitationReturnsResultWhenActionIsNotAccept() { ClientCapabilities capabilities = mock(ClientCapabilities.class); ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); when(capabilities.elicitation()).thenReturn(elicitation); when(exchange.getClientCapabilities()).thenReturn(capabilities); + Map contentMap = Map.of(); ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.DECLINE); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono result = context.elicitation("Test message", String.class); + Mono>> result = context + .elicitation(new TypeReference>() { + }, "Test message", null); - StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.DECLINE); + assertThat(structuredResult.structuredContent()).isNotNull(); + }).verifyComplete(); + } + + @Test + public void testElicitationConvertsComplexTypes() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + record Address(String street, String city) { + } + record PersonWithAddress(String name, int age, Address address) { + } + + Map addressMap = Map.of("street", "123 Main St", "city", "Springfield"); + Map contentMap = Map.of("name", "John", "age", 30, "address", addressMap); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono> result = context + .elicitation(new TypeReference() { + }, "Test message", null); + + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(structuredResult.structuredContent()).isNotNull(); + assertThat(structuredResult.structuredContent().name()).isEqualTo("John"); + assertThat(structuredResult.structuredContent().age()).isEqualTo(30); + assertThat(structuredResult.structuredContent().address()).isNotNull(); + assertThat(structuredResult.structuredContent().address().street()).isEqualTo("123 Main St"); + assertThat(structuredResult.structuredContent().address().city()).isEqualTo("Springfield"); + }).verifyComplete(); + } + + @Test + public void testElicitationHandlesListTypes() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("items", + java.util.List.of(Map.of("name", "Item1"), Map.of("name", "Item2"))); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono>> result = context + .elicitation(new TypeReference>() { + }, "Test message", null); + + StepVerifier.create(result).assertNext(structuredResult -> { + assertThat(structuredResult.structuredContent()).containsKey("items"); + }).verifyComplete(); + } + + @Test + public void testElicitationWithTypeReference() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("result", "success", "data", "test value"); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); + + Mono> result = context.elicitation(new TypeReference>() { + }); + + StepVerifier.create(result).assertNext(map -> { + assertThat(map).containsEntry("result", "success"); + assertThat(map).containsEntry("data", "test value"); + }).verifyComplete(); } @Test diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java index ebb4823..b807c21 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java @@ -7,6 +7,7 @@ import java.util.Map; import java.util.Optional; +import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; @@ -126,22 +127,28 @@ public void testRootsWhenCapabilitiesNullRoots() { // Elicitation Tests @Test - public void testElicitationWithConsumer() { + public void testElicitationWithTypeAndMessage() { ClientCapabilities capabilities = mock(ClientCapabilities.class); ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); when(capabilities.elicitation()).thenReturn(elicitation); when(exchange.getClientCapabilities()).thenReturn(capabilities); + Map contentMap = Map.of("name", "John", "age", 30); ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional result = context.elicitation(spec -> { - spec.message("Test message"); - spec.responseType(String.class); - }); + Optional>> result = context + .elicitation(new TypeReference>() { + }, "Test message", null); assertThat(result).isPresent(); - assertThat(result.get()).isEqualTo(expectedResult); + assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.get().structuredContent()).isNotNull(); + assertThat(result.get().structuredContent()).containsEntry("name", "John"); + assertThat(result.get().structuredContent()).containsEntry("age", 30); ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); verify(exchange).createElicitation(captor.capture()); @@ -152,22 +159,33 @@ public void testElicitationWithConsumer() { } @Test - public void testElicitationWithConsumerAndMeta() { + public void testElicitationWithTypeMessageAndMeta() { ClientCapabilities capabilities = mock(ClientCapabilities.class); ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); when(capabilities.elicitation()).thenReturn(elicitation); when(exchange.getClientCapabilities()).thenReturn(capabilities); + record Person(String name, int age) { + } + + Map contentMap = Map.of("name", "Jane", "age", 25); + Map requestMeta = Map.of("key", "value"); + Map resultMeta = Map.of("resultKey", "resultValue"); ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(resultMeta); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional result = context.elicitation(spec -> { - spec.message("Test message"); - spec.responseType(String.class); - spec.meta("key", "value"); - }); + Optional> result = context.elicitation(new TypeReference() { + }, "Test message", requestMeta); assertThat(result).isPresent(); + assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.get().structuredContent()).isNotNull(); + assertThat(result.get().structuredContent().name()).isEqualTo("Jane"); + assertThat(result.get().structuredContent().age()).isEqualTo(25); + assertThat(result.get().meta()).containsEntry("resultKey", "resultValue"); ArgumentCaptor captor = ArgumentCaptor.forClass(ElicitRequest.class); verify(exchange).createElicitation(captor.capture()); @@ -177,43 +195,116 @@ public void testElicitationWithConsumerAndMeta() { } @Test - public void testElicitationWithNullConsumer() { - assertThatThrownBy( - () -> context.elicitation((java.util.function.Consumer) null)) + public void testElicitationWithNullResponseType() { + assertThatThrownBy(() -> context.elicitation((TypeReference) null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Elicitation spec consumer must not be null"); + .hasMessageContaining("Elicitation response type must not be null"); } @Test - public void testElicitationWithEmptyMessage() { - assertThatThrownBy(() -> context.elicitation(spec -> { - spec.message(""); - spec.responseType(String.class); - })).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("Message must not be empty"); + public void testElicitationWithTypeReturnsEmptyWhenNotSupported() { + when(exchange.getClientCapabilities()).thenReturn(null); + + Optional> result = context.elicitation(new TypeReference>() { + }); + + assertThat(result).isEmpty(); } @Test - public void testElicitationWithNullResponseType() { - assertThatThrownBy(() -> context.elicitation(spec -> { - spec.message("Test message"); - spec.responseType(null); - })).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("Response type must not be null"); + public void testElicitationWithTypeReturnsEmptyWhenActionIsNotAccept() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.DECLINE); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional>> result = context + .elicitation(new TypeReference>() { + }, "Test message", null); + + assertThat(result).isEmpty(); } @Test - public void testElicitationWithMessageAndType() { + public void testElicitationWithTypeConvertsComplexTypes() { ClientCapabilities capabilities = mock(ClientCapabilities.class); ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); when(capabilities.elicitation()).thenReturn(elicitation); when(exchange.getClientCapabilities()).thenReturn(capabilities); + record Address(String street, String city) { + } + record PersonWithAddress(String name, int age, Address address) { + } + + Map addressMap = Map.of("street", "123 Main St", "city", "Springfield"); + Map contentMap = Map.of("name", "John", "age", 30, "address", addressMap); ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional result = context.elicitation("Test message", String.class); + Optional> result = context + .elicitation(new TypeReference() { + }, "Test message", null); assertThat(result).isPresent(); - assertThat(result.get()).isEqualTo(expectedResult); + assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.get().structuredContent()).isNotNull(); + assertThat(result.get().structuredContent().name()).isEqualTo("John"); + assertThat(result.get().structuredContent().age()).isEqualTo(30); + assertThat(result.get().structuredContent().address()).isNotNull(); + assertThat(result.get().structuredContent().address().street()).isEqualTo("123 Main St"); + assertThat(result.get().structuredContent().address().city()).isEqualTo("Springfield"); + } + + @Test + public void testElicitationWithTypeHandlesListTypes() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("items", + java.util.List.of(Map.of("name", "Item1"), Map.of("name", "Item2"))); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(expectedResult.meta()).thenReturn(null); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional>> result = context + .elicitation(new TypeReference>() { + }, "Test message", null); + + assertThat(result).isPresent(); + assertThat(result.get().structuredContent()).containsKey("items"); + } + + @Test + public void testElicitationWithTypeReference() { + ClientCapabilities capabilities = mock(ClientCapabilities.class); + ClientCapabilities.Elicitation elicitation = mock(ClientCapabilities.Elicitation.class); + when(capabilities.elicitation()).thenReturn(elicitation); + when(exchange.getClientCapabilities()).thenReturn(capabilities); + + Map contentMap = Map.of("result", "success", "data", "test value"); + ElicitResult expectedResult = mock(ElicitResult.class); + when(expectedResult.action()).thenReturn(ElicitResult.Action.ACCEPT); + when(expectedResult.content()).thenReturn(contentMap); + when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); + + Optional> result = context.elicitation(new TypeReference>() { + }); + + assertThat(result).isPresent(); + assertThat(result.get()).containsEntry("result", "success"); + assertThat(result.get()).containsEntry("data", "test value"); } @Test From abd02608d7f3c7db736afe95168a95565928ddac Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 5 Oct 2025 22:36:31 +0200 Subject: [PATCH 3/4] Update README Signed-off-by: Christian Tzolov --- README.md | 58 ++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index de87f65..d651a5e 100644 --- a/README.md +++ b/README.md @@ -897,7 +897,7 @@ When a method parameter is of type `McpSyncRequestContext` or `McpAsyncRequestCo **Synchronous Context Example:** ```java -public record ElicitReturnType(String message) {} +public record UserInfo(String name, String email, Number age) {} @McpTool(name = "process-with-context", description = "Process data with unified context") public String processWithContext( @@ -920,11 +920,20 @@ public String processWithContext( // Use exchange for additional operations... } - // Perform elicitation if needed - Optional userInput = context.elicitation(spec -> { - spec.message("Please provide additional information"); - spec.regurnType(ElicitReturnType.class); - }); + // Perform elicitation with default message - returns just the typed content + Optional userInfo = context.elicitation(new TypeReference() {}); + + // Or perform elicitation with custom message and metadata - returns structured result + Optional> structuredResult = context.elicitation( + new TypeReference() {}, + "Please provide your information", + Map.of("context", "user-registration") + ); + + if (structuredResult.isPresent() && structuredResult.get().action() == ElicitResult.Action.ACCEPT) { + UserInfo info = structuredResult.get().structuredContent(); + return "Processed: " + data + " for user " + info.name(); + } return "Processed: " + data; } @@ -968,8 +977,7 @@ public GetPromptResult generateWithContext( **Asynchronous Context Example:** ```java - -public record ElicitReturnType(String message) {} +public record UserInfo(String name, String email, int age) {} @McpTool(name = "async-process-with-context", description = "Async process with unified context") public Mono asyncProcessWithContext( @@ -992,13 +1000,21 @@ public Mono asyncProcessWithContext( .thenReturn(processedData); }) .flatMap(processedData -> { - // Perform elicitation if needed (returns Mono) - return context.elicitation(spec -> { - spec.message("Please provide additional information"); - spec.returnType(ElicitReturnType.class); - }) - .map(result -> "Processed: " + processedData + " with user input"); - }); + // Perform elicitation with default message - returns Mono + return context.elicitation(new TypeReference() {}) + .map(userInfo -> "Processed: " + processedData + " for user " + userInfo.name()); + }) + .switchIfEmpty(Mono.fromCallable(() -> { + // Or perform elicitation with custom message and metadata - returns Mono> + return context.elicitation( + new TypeReference() {}, + "Please provide your information", + Map.of("context", "user-registration") + ) + .filter(result -> result.action() == ElicitResult.Action.ACCEPT) + .map(result -> "Processed: " + data + " for user " + result.structuredContent().name()) + .defaultIfEmpty("Processed: " + data); + }).flatMap(mono -> mono)); } @McpResource(uri = "async-data://{id}", name = "Async Data Resource", @@ -1044,7 +1060,9 @@ public Mono asyncGenerateWithContext( - `log(Consumer)` - Send log messages with custom configuration - `debug(String)`, `info(String)`, `warn(String)`, `error(String)` - Convenience logging methods - `progress(int)`, `progress(Consumer)` - Send progress updates -- `elicitation(...)` - Request user input with various configuration options +- `elicitation(TypeReference)` - Request user input with default message, returns typed content directly +- `elicitation(TypeReference, String, Map)` - Request user input with custom message and metadata, returns `StructuredElicitResult` with action, typed content, and metadata +- `elicitation(ElicitRequest)` - Request user input with full control over the elicitation request - `sampling(...)` - Request LLM sampling with various configuration options - `roots()` - Access root directories (returns `Optional`) - `ping()` - Send ping to check connection @@ -1954,7 +1972,7 @@ public class AsyncElicitationHandler { public class MyMcpClient { public static McpSyncClient createSyncClientWithElicitation(ElicitationHandler elicitationHandler) { - Function elicitationHandler = + Function elicitationHandlerFunc = new SyncMcpElicitationProvider(List.of(elicitationHandler)).getElicitationHandler(); McpSyncClient client = McpClient.sync(transport) @@ -1962,14 +1980,14 @@ public class MyMcpClient { .elicitation() // Enable elicitation support // Other capabilities... .build()) - .elicitationHandler(elicitationHandler) + .elicitationHandler(elicitationHandlerFunc) .build(); return client; } public static McpAsyncClient createAsyncClientWithElicitation(AsyncElicitationHandler asyncElicitationHandler) { - Function> elicitationHandler = + Function> elicitationHandlerFunc = new AsyncMcpElicitationProvider(List.of(asyncElicitationHandler)).getElicitationHandler(); McpAsyncClient client = McpClient.async(transport) @@ -1977,7 +1995,7 @@ public class MyMcpClient { .elicitation() // Enable elicitation support // Other capabilities... .build()) - .elicitationHandler(elicitationHandler) + .elicitationHandler(elicitationHandlerFunc) .build(); return client; From 35f022e26c3a01c49efca6e5262e0beed4a6ba3c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 6 Oct 2025 09:56:55 +0200 Subject: [PATCH 4/4] refactor: rename elicitation/sampling methods and standardize API - Rename `elicitation()` methods to `elicit()` across all context interfaces - Rename `sampling()` methods to `sample()` across all context interfaces - Standardize elicit() to always return `StructuredElicitResult` instead of raw type - Replace positional parameters with builder-style `Consumer` configuration - Add support for both `Class` and `TypeReference` type parameters - Introduce `ElicitationSpec` interface with fluent API for message and metadata - Update README documentation with new method signatures and examples - Update all test cases to use new API methods - Add copyright header to StructuredElicitResult class Breaking Changes: - `elicitation(TypeReference)` now returns `Optional>` instead of `Optional` - `elicitation(TypeReference, String, Map)` replaced with `elicit(Consumer, TypeReference)` - `sampling(...)` renamed to `sample(...)` - All elicit methods now consistently return structured results with action, content, and metadata Signed-off-by: Christian Tzolov --- README.md | 25 ++-- .../mcp/context/DefaultElicitationSpec.java | 48 ++++++++ .../DefaultMcpAsyncRequestContext.java | 85 +++++++++----- .../context/DefaultMcpSyncRequestContext.java | 108 ++++++++++++++---- .../mcp/context/McpAsyncRequestContext.java | 16 ++- .../mcp/context/McpRequestContextTypes.java | 10 ++ .../mcp/context/McpSyncRequestContext.java | 18 +-- .../mcp/context/StructuredElicitResult.java | 10 ++ .../DefaultMcpAsyncRequestContextTests.java | 73 ++++++------ .../DefaultMcpSyncRequestContextTests.java | 53 +++++---- 10 files changed, 312 insertions(+), 134 deletions(-) create mode 100644 mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java diff --git a/README.md b/README.md index d651a5e..e3f3b54 100644 --- a/README.md +++ b/README.md @@ -920,14 +920,13 @@ public String processWithContext( // Use exchange for additional operations... } - // Perform elicitation with default message - returns just the typed content - Optional userInfo = context.elicitation(new TypeReference() {}); + // Perform elicitation with default message - returns StructuredElicitResult + Optional> result = context.elicit(new TypeReference() {}); - // Or perform elicitation with custom message and metadata - returns structured result - Optional> structuredResult = context.elicitation( - new TypeReference() {}, - "Please provide your information", - Map.of("context", "user-registration") + // Or perform elicitation with custom configuration - returns StructuredElicitResult + Optional> structuredResult = context.elicit( + e -> e.message("Please provide your information").meta("context", "user-registration"), + new TypeReference() {} ); if (structuredResult.isPresent() && structuredResult.get().action() == ElicitResult.Action.ACCEPT) { @@ -964,7 +963,7 @@ public GetPromptResult generateWithContext( context.info("Generating prompt for topic: " + topic); // Perform sampling if needed - Optional samplingResult = context.sampling( + Optional samplingResult = context.sample( "What are the key points about " + topic + "?" ); @@ -1060,10 +1059,12 @@ public Mono asyncGenerateWithContext( - `log(Consumer)` - Send log messages with custom configuration - `debug(String)`, `info(String)`, `warn(String)`, `error(String)` - Convenience logging methods - `progress(int)`, `progress(Consumer)` - Send progress updates -- `elicitation(TypeReference)` - Request user input with default message, returns typed content directly -- `elicitation(TypeReference, String, Map)` - Request user input with custom message and metadata, returns `StructuredElicitResult` with action, typed content, and metadata -- `elicitation(ElicitRequest)` - Request user input with full control over the elicitation request -- `sampling(...)` - Request LLM sampling with various configuration options +- `elicit(TypeReference)` - Request user input with default message, returns `StructuredElicitResult` with action, typed content, and metadata +- `elicit(Class)` - Request user input with default message using Class type, returns `StructuredElicitResult` +- `elicit(Consumer, TypeReference)` - Request user input with custom configuration, returns `StructuredElicitResult` +- `elicit(Consumer, Class)` - Request user input with custom configuration using Class type, returns `StructuredElicitResult` +- `elicit(ElicitRequest)` - Request user input with full control over the elicitation request +- `sample(...)` - Request LLM sampling with various configuration options - `roots()` - Access root directories (returns `Optional`) - `ping()` - Send ping to check connection diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java new file mode 100644 index 0000000..674f53d --- /dev/null +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultElicitationSpec.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package org.springaicommunity.mcp.context; + +import java.util.HashMap; +import java.util.Map; + +import org.springaicommunity.mcp.context.McpRequestContextTypes.ElicitationSpec; + +public class DefaultElicitationSpec implements ElicitationSpec { + + protected String message; + + protected Map meta = new HashMap<>(); + + protected String message() { + return message; + } + + protected Map meta() { + return meta; + } + + @Override + public ElicitationSpec message(String message) { + this.message = message; + return this; + } + + @Override + public ElicitationSpec meta(Map m) { + if (m != null) { + this.meta.putAll(m); + } + return this; + } + + @Override + public ElicitationSpec meta(String k, Object v) { + if (k != null && v != null) { + this.meta.put(k, v); + } + return this; + } + +} diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java index eb0a568..716c195 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContext.java @@ -73,31 +73,41 @@ public Mono roots() { // Elicitation @Override - public Mono> elicitation(TypeReference type, String message, - Map meta) { + public Mono> elicit(Consumer spec, TypeReference type) { Assert.notNull(type, "Elicitation response type must not be null"); - Assert.hasText(message, "Elicitation message must not be empty"); + Assert.notNull(spec, "Elicitation spec consumer must not be null"); + DefaultElicitationSpec elicitationSpec = new DefaultElicitationSpec(); + spec.accept(elicitationSpec); + return this.elicitationInternal(elicitationSpec.message, type.getType(), elicitationSpec.meta) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); + } - return this.elicitationInternal(message, type.getType(), meta) - .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type.getType()), - er.meta())); + @Override + public Mono> elicit(Consumer spec, Class type) { + Assert.notNull(type, "Elicitation response type must not be null"); + Assert.notNull(spec, "Elicitation spec consumer must not be null"); + DefaultElicitationSpec elicitationSpec = new DefaultElicitationSpec(); + spec.accept(elicitationSpec); + return this.elicitationInternal(elicitationSpec.message, type, elicitationSpec.meta) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); } @Override - public Mono elicitation(TypeReference type) { + public Mono> elicit(TypeReference type) { Assert.notNull(type, "Elicitation response type must not be null"); return this.elicitationInternal("Please provide the required information.", type.getType(), null) - .map(er -> convertMapToType(er.content(), type.getType())); + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); } - private static T convertMapToType(Map map, Type targetType) { - ObjectMapper mapper = new ObjectMapper(); - JavaType javaType = mapper.getTypeFactory().constructType(targetType); - return mapper.convertValue(map, javaType); + @Override + public Mono> elicit(Class type) { + Assert.notNull(type, "Elicitation response type must not be null"); + return this.elicitationInternal("Please provide the required information.", type, null) + .map(er -> new StructuredElicitResult(er.action(), convertMapToType(er.content(), type), er.meta())); } @Override - public Mono elicitation(ElicitRequest elicitRequest) { + public Mono elicit(ElicitRequest elicitRequest) { Assert.notNull(elicitRequest, "Elicit request must not be null"); if (this.exchange.getClientCapabilities() == null @@ -116,7 +126,7 @@ public Mono elicitationInternal(String message, Type type, Map schema = typeSchemaCache.computeIfAbsent(type, t -> this.generateElicitSchema(t)); - return this.elicitation(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); + return this.elicit(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); } private Map generateElicitSchema(Type type) { @@ -126,15 +136,27 @@ private Map generateElicitSchema(Type type) { return schema; } + private static T convertMapToType(Map map, Class targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + + private static T convertMapToType(Map map, TypeReference targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + // Sampling @Override - public Mono sampling(String... messages) { - return this.sampling(s -> s.message(messages)); + public Mono sample(String... messages) { + return this.sample(s -> s.message(messages)); } @Override - public Mono sampling(Consumer samplingSpec) { + public Mono sample(Consumer samplingSpec) { Assert.notNull(samplingSpec, "Sampling spec consumer must not be null"); DefaultSamplingSpec spec = new DefaultSamplingSpec(); samplingSpec.accept(spec); @@ -144,7 +166,7 @@ public Mono sampling(Consumer samplingSpec) { if (!Utils.hasText(progressToken)) { logger.warn("Progress notification not supported by the client!"); } - return this.sampling(McpSchema.CreateMessageRequest.builder() + return this.sample(McpSchema.CreateMessageRequest.builder() .messages(spec.messages) .modelPreferences(spec.modelPreferences) .systemPrompt(spec.systemPrompt) @@ -159,7 +181,7 @@ public Mono sampling(Consumer samplingSpec) { } @Override - public Mono sampling(CreateMessageRequest createMessageRequest) { + public Mono sample(CreateMessageRequest createMessageRequest) { // check if supported if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { @@ -356,38 +378,49 @@ public Mono roots() { } @Override - public Mono> elicitation(TypeReference type, String message, - Map meta) { + public Mono> elicit(Consumer spec, TypeReference returnType) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono> elicit(TypeReference type) { + logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); + return Mono.empty(); + } + + @Override + public Mono> elicit(Consumer spec, Class returnType) { logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); return Mono.empty(); } @Override - public Mono elicitation(TypeReference type) { + public Mono> elicit(Class type) { logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); return Mono.empty(); } @Override - public Mono elicitation(ElicitRequest elicitRequest) { + public Mono elicit(ElicitRequest elicitRequest) { logger.warn("Elicitation not supported by the client! Ignoring the elicitation request"); return Mono.empty(); } @Override - public Mono sampling(String... messages) { + public Mono sample(String... messages) { logger.warn("Sampling not supported by the client! Ignoring the sampling request"); return Mono.empty(); } @Override - public Mono sampling(Consumer samplingSpec) { + public Mono sample(Consumer samplingSpec) { logger.warn("Sampling not supported by the client! Ignoring the sampling request"); return Mono.empty(); } @Override - public Mono sampling(CreateMessageRequest createMessageRequest) { + public Mono sample(CreateMessageRequest createMessageRequest) { logger.warn("Sampling not supported by the client! Ignoring the sampling request"); return Mono.empty(); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java index 95d11aa..0450b8a 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContext.java @@ -69,25 +69,26 @@ public Optional roots() { // Elicitation @Override - public Optional elicitation(TypeReference type) { + public Optional> elicit(Class type) { Assert.notNull(type, "Elicitation response type must not be null"); - Optional elicitResult = this.elicitationInternal("Please provide the required information.", - type.getType(), null); + Optional elicitResult = this.elicitationInternal("Please provide the required information.", type, + null); if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { return Optional.empty(); } - return Optional.of(convertMapToType(elicitResult.get().content(), type)); + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), type), elicitResult.get().meta())); } @Override - public Optional> elicitation(TypeReference type, String message, - Map meta) { + public Optional> elicit(TypeReference type) { Assert.notNull(type, "Elicitation response type must not be null"); - Optional elicitResult = this.elicitationInternal(message, type.getType(), meta); + Optional elicitResult = this.elicitationInternal("Please provide the required information.", + type.getType(), null); if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { return Optional.empty(); @@ -97,14 +98,47 @@ public Optional> elicitation(TypeReference type convertMapToType(elicitResult.get().content(), type), elicitResult.get().meta())); } - private static T convertMapToType(Map map, TypeReference targetType) { - ObjectMapper mapper = new ObjectMapper(); - JavaType javaType = mapper.getTypeFactory().constructType(targetType); - return mapper.convertValue(map, javaType); + @Override + public Optional> elicit(Consumer params, Class returnType) { + Assert.notNull(returnType, "Elicitation response type must not be null"); + Assert.notNull(params, "Elicitation params must not be null"); + + DefaultElicitationSpec paramSpec = new DefaultElicitationSpec(); + params.accept(paramSpec); + + Optional elicitResult = this.elicitationInternal(paramSpec.message(), returnType, + paramSpec.meta()); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), returnType), elicitResult.get().meta())); } @Override - public Optional elicitation(ElicitRequest elicitRequest) { + public Optional> elicit(Consumer params, + TypeReference returnType) { + Assert.notNull(returnType, "Elicitation response type must not be null"); + Assert.notNull(params, "Elicitation params must not be null"); + + DefaultElicitationSpec paramSpec = new DefaultElicitationSpec(); + params.accept(paramSpec); + + Optional elicitResult = this.elicitationInternal(paramSpec.message(), returnType.getType(), + paramSpec.meta()); + + if (!elicitResult.isPresent() || elicitResult.get().action() != ElicitResult.Action.ACCEPT) { + return Optional.empty(); + } + + return Optional.of(new StructuredElicitResult<>(elicitResult.get().action(), + convertMapToType(elicitResult.get().content(), returnType), elicitResult.get().meta())); + } + + @Override + public Optional elicit(ElicitRequest elicitRequest) { Assert.notNull(elicitRequest, "Elicit request must not be null"); if (this.exchange.getClientCapabilities() == null @@ -125,7 +159,7 @@ private Optional elicitationInternal(String message, Type type, Ma Map schema = typeSchemaCache.computeIfAbsent(type, t -> this.generateElicitSchema(t)); - return this.elicitation(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); + return this.elicit(ElicitRequest.builder().message(message).requestedSchema(schema).meta(meta).build()); } private Map generateElicitSchema(Type type) { @@ -135,15 +169,27 @@ private Map generateElicitSchema(Type type) { return schema; } + private static T convertMapToType(Map map, Class targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + + private static T convertMapToType(Map map, TypeReference targetType) { + ObjectMapper mapper = new ObjectMapper(); + JavaType javaType = mapper.getTypeFactory().constructType(targetType); + return mapper.convertValue(map, javaType); + } + // Sampling @Override - public Optional sampling(String... messages) { - return this.sampling(s -> s.message(messages)); + public Optional sample(String... messages) { + return this.sample(s -> s.message(messages)); } @Override - public Optional sampling(Consumer samplingSpec) { + public Optional sample(Consumer samplingSpec) { Assert.notNull(samplingSpec, "Sampling spec consumer must not be null"); DefaultSamplingSpec spec = new DefaultSamplingSpec(); samplingSpec.accept(spec); @@ -153,7 +199,7 @@ public Optional sampling(Consumer samplingSpe if (!Utils.hasText(progressToken)) { logger.warn("Progress notification not supported by the client!"); } - return this.sampling(McpSchema.CreateMessageRequest.builder() + return this.sample(McpSchema.CreateMessageRequest.builder() .messages(spec.messages) .modelPreferences(spec.modelPreferences) .systemPrompt(spec.systemPrompt) @@ -168,7 +214,7 @@ public Optional sampling(Consumer samplingSpe } @Override - public Optional sampling(CreateMessageRequest createMessageRequest) { + public Optional sample(CreateMessageRequest createMessageRequest) { // check if supported if (this.exchange.getClientCapabilities() == null || this.exchange.getClientCapabilities().sampling() == null) { @@ -363,38 +409,50 @@ public Optional roots() { } @Override - public Optional elicitation(TypeReference type) { + public Optional> elicit(Class type) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional> elicit(Consumer params, Class returnType) { + logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); + return Optional.empty(); + } + + @Override + public Optional> elicit(TypeReference type) { logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); return Optional.empty(); } @Override - public Optional> elicitation(TypeReference type, String message, - Map meta) { + public Optional> elicit(Consumer params, + TypeReference returnType) { logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); return Optional.empty(); } @Override - public Optional elicitation(ElicitRequest elicitRequest) { + public Optional elicit(ElicitRequest elicitRequest) { logger.warn("Stateless servers do not support elicitation! Ignoring the elicitation request"); return Optional.empty(); } @Override - public Optional sampling(String... messages) { + public Optional sample(String... messages) { logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); return Optional.empty(); } @Override - public Optional sampling(Consumer samplingSpec) { + public Optional sample(Consumer samplingSpec) { logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); return Optional.empty(); } @Override - public Optional sampling(CreateMessageRequest createMessageRequest) { + public Optional sample(CreateMessageRequest createMessageRequest) { logger.warn("Stateless servers do not support sampling! Ignoring the sampling request"); return Optional.empty(); } diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java index 21d1cd2..08bad44 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpAsyncRequestContext.java @@ -33,20 +33,24 @@ public interface McpAsyncRequestContext extends McpRequestContextTypes Mono elicitation(TypeReference type); + Mono> elicit(Class type); - Mono> elicitation(TypeReference type, String message, Map meta); + Mono> elicit(TypeReference type); - Mono elicitation(ElicitRequest elicitRequest); + Mono> elicit(Consumer spec, TypeReference returnType); + + Mono> elicit(Consumer spec, Class returnType); + + Mono elicit(ElicitRequest elicitRequest); // -------------------------------------- // Sampling // -------------------------------------- - Mono sampling(String... messages); + Mono sample(String... messages); - Mono sampling(Consumer samplingSpec); + Mono sample(Consumer samplingSpec); - Mono sampling(CreateMessageRequest createMessageRequest); + Mono sample(CreateMessageRequest createMessageRequest); // -------------------------------------- // Progress diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java index f4c7257..70c9b40 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpRequestContextTypes.java @@ -26,6 +26,16 @@ */ public interface McpRequestContextTypes { + interface ElicitationSpec { + + ElicitationSpec message(String message); + + ElicitationSpec meta(Map m); + + ElicitationSpec meta(String k, Object v); + + } + // -------------------------------------- // Sampling // -------------------------------------- diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java index 327978d..38e74f1 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/McpSyncRequestContext.java @@ -4,7 +4,6 @@ package org.springaicommunity.mcp.context; -import java.util.Map; import java.util.Optional; import java.util.function.Consumer; @@ -30,21 +29,24 @@ public interface McpSyncRequestContext extends McpRequestContextTypes Optional elicitation(TypeReference type); + Optional> elicit(Class type); - Optional> elicitation(TypeReference type, String message, - Map meta); + Optional> elicit(TypeReference type); - Optional elicitation(ElicitRequest elicitRequest); + Optional> elicit(Consumer params, Class returnType); + + Optional> elicit(Consumer params, TypeReference returnType); + + Optional elicit(ElicitRequest elicitRequest); // -------------------------------------- // Sampling // -------------------------------------- - Optional sampling(String... messages); + Optional sample(String... messages); - Optional sampling(Consumer samplingSpec); + Optional sample(Consumer samplingSpec); - Optional sampling(CreateMessageRequest createMessageRequest); + Optional sample(CreateMessageRequest createMessageRequest); // -------------------------------------- // Progress diff --git a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java index a71081c..0afd5f5 100644 --- a/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java +++ b/mcp-annotations/src/main/java/org/springaicommunity/mcp/context/StructuredElicitResult.java @@ -1,9 +1,19 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package org.springaicommunity.mcp.context; import java.util.Map; import io.modelcontextprotocol.spec.McpSchema.ElicitResult.Action; +/** + * A record representing the result of a structured elicit action. + * + * @param the type of the structured content + * @author Christian Tzolov + */ public record StructuredElicitResult(Action action, T structuredContent, Map meta) { } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java index 285c9b3..f48a216 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpAsyncRequestContextTests.java @@ -140,9 +140,9 @@ public void testElicitationWithMessageAndMeta() { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono>> result = context - .elicitation(new TypeReference>() { - }, "Test message", null); + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); StepVerifier.create(result).assertNext(structuredResult -> { assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); @@ -177,8 +177,9 @@ record Person(String name, int age) { when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); Map meta = Map.of("key", "value"); - Mono> result = context.elicitation(new TypeReference() { - }, "Test message", meta); + Mono> result = context.elicit(e -> e.message("Test message").meta(meta), + new TypeReference() { + }); StepVerifier.create(result).assertNext(structuredResult -> { assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); @@ -197,23 +198,30 @@ record Person(String name, int age) { @Test public void testElicitationWithNullTypeReference() { assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicitation(null, "Test message", null); + context.elicit((TypeReference) null); + })).hasMessageContaining("Elicitation response type must not be null"); + } + + @Test + public void testElicitationWithNullClassType() { + assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { + context.elicit((Class) null); })).hasMessageContaining("Elicitation response type must not be null"); } @Test public void testElicitationWithEmptyMessage() { assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicitation(new TypeReference() { - }, "", null); + context.elicit(e -> e.message("").meta(null), new TypeReference() { + }); })).hasMessageContaining("Elicitation message must not be empty"); } @Test public void testElicitationWithNullMessage() { assertThat(org.junit.jupiter.api.Assertions.assertThrows(IllegalArgumentException.class, () -> { - context.elicitation(new TypeReference() { - }, null, null); + context.elicit(e -> e.message(null).meta(null), new TypeReference() { + }); })).hasMessageContaining("Elicitation message must not be empty"); } @@ -221,9 +229,9 @@ public void testElicitationWithNullMessage() { public void testElicitationReturnsEmptyWhenNotSupported() { when(exchange.getClientCapabilities()).thenReturn(null); - Mono>> result = context - .elicitation(new TypeReference>() { - }, "Test message", null); + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); StepVerifier.create(result).verifyComplete(); } @@ -242,9 +250,9 @@ public void testElicitationReturnsResultWhenActionIsNotAccept() { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono>> result = context - .elicitation(new TypeReference>() { - }, "Test message", null); + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); StepVerifier.create(result).assertNext(structuredResult -> { assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.DECLINE); @@ -272,9 +280,9 @@ record PersonWithAddress(String name, int age, Address address) { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono> result = context - .elicitation(new TypeReference() { - }, "Test message", null); + Mono> result = context.elicit(e -> e.message("Test message"), + new TypeReference() { + }); StepVerifier.create(result).assertNext(structuredResult -> { assertThat(structuredResult.action()).isEqualTo(ElicitResult.Action.ACCEPT); @@ -302,9 +310,9 @@ public void testElicitationHandlesListTypes() { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono>> result = context - .elicitation(new TypeReference>() { - }, "Test message", null); + Mono>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); StepVerifier.create(result).assertNext(structuredResult -> { assertThat(structuredResult.structuredContent()).containsKey("items"); @@ -324,12 +332,13 @@ public void testElicitationWithTypeReference() { when(expectedResult.content()).thenReturn(contentMap); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono> result = context.elicitation(new TypeReference>() { - }); + Mono>> result = context + .elicit(new TypeReference>() { + }); StepVerifier.create(result).assertNext(map -> { - assertThat(map).containsEntry("result", "success"); - assertThat(map).containsEntry("data", "test value"); + assertThat(map.structuredContent()).containsEntry("result", "success"); + assertThat(map.structuredContent()).containsEntry("data", "test value"); }).verifyComplete(); } @@ -348,7 +357,7 @@ public void testElicitationWithRequest() { when(exchange.createElicitation(elicitRequest)).thenReturn(Mono.just(expectedResult)); - Mono result = context.elicitation(elicitRequest); + Mono result = context.elicit(elicitRequest); StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); } @@ -362,7 +371,7 @@ public void testElicitationWhenNotSupported() { .requestedSchema(Map.of("type", "string")) .build(); - Mono result = context.elicitation(elicitRequest); + Mono result = context.elicit(elicitRequest); StepVerifier.create(result).verifyComplete(); } @@ -379,7 +388,7 @@ public void testSamplingWithMessages() { CreateMessageResult expectedResult = mock(CreateMessageResult.class); when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono result = context.sampling("Message 1", "Message 2"); + Mono result = context.sample("Message 1", "Message 2"); StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); } @@ -394,7 +403,7 @@ public void testSamplingWithConsumer() { CreateMessageResult expectedResult = mock(CreateMessageResult.class); when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(Mono.just(expectedResult)); - Mono result = context.sampling(spec -> { + Mono result = context.sample(spec -> { spec.message(new TextContent("Test message")); spec.systemPrompt("System prompt"); spec.temperature(0.7); @@ -427,7 +436,7 @@ public void testSamplingWithRequest() { when(exchange.createMessage(createRequest)).thenReturn(Mono.just(expectedResult)); - Mono result = context.sampling(createRequest); + Mono result = context.sample(createRequest); StepVerifier.create(result).expectNext(expectedResult).verifyComplete(); } @@ -441,7 +450,7 @@ public void testSamplingWhenNotSupported() { .maxTokens(500) .build(); - Mono result = context.sampling(createRequest); + Mono result = context.sample(createRequest); StepVerifier.create(result).verifyComplete(); } diff --git a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java index b807c21..b2ece1a 100644 --- a/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java +++ b/mcp-annotations/src/test/java/org/springaicommunity/mcp/context/DefaultMcpSyncRequestContextTests.java @@ -140,9 +140,9 @@ public void testElicitationWithTypeAndMessage() { when(expectedResult.meta()).thenReturn(null); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional>> result = context - .elicitation(new TypeReference>() { - }, "Test message", null); + Optional>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); assertThat(result).isPresent(); assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); @@ -177,8 +177,9 @@ record Person(String name, int age) { when(expectedResult.meta()).thenReturn(resultMeta); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional> result = context.elicitation(new TypeReference() { - }, "Test message", requestMeta); + Optional> result = context + .elicit(e -> e.message("Test message").meta(requestMeta), new TypeReference() { + }); assertThat(result).isPresent(); assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); @@ -196,7 +197,7 @@ record Person(String name, int age) { @Test public void testElicitationWithNullResponseType() { - assertThatThrownBy(() -> context.elicitation((TypeReference) null)) + assertThatThrownBy(() -> context.elicit((TypeReference) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Elicitation response type must not be null"); } @@ -205,8 +206,9 @@ public void testElicitationWithNullResponseType() { public void testElicitationWithTypeReturnsEmptyWhenNotSupported() { when(exchange.getClientCapabilities()).thenReturn(null); - Optional> result = context.elicitation(new TypeReference>() { - }); + Optional>> result = context + .elicit(new TypeReference>() { + }); assertThat(result).isEmpty(); } @@ -222,9 +224,9 @@ public void testElicitationWithTypeReturnsEmptyWhenActionIsNotAccept() { when(expectedResult.action()).thenReturn(ElicitResult.Action.DECLINE); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional>> result = context - .elicitation(new TypeReference>() { - }, "Test message", null); + Optional>> result = context.elicit(e -> e.message("Test message"), + new TypeReference>() { + }); assertThat(result).isEmpty(); } @@ -250,8 +252,8 @@ record PersonWithAddress(String name, int age, Address address) { when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); Optional> result = context - .elicitation(new TypeReference() { - }, "Test message", null); + .elicit(e -> e.message("Test message").meta(null), new TypeReference() { + }); assertThat(result).isPresent(); assertThat(result.get().action()).isEqualTo(ElicitResult.Action.ACCEPT); @@ -279,8 +281,8 @@ public void testElicitationWithTypeHandlesListTypes() { when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); Optional>> result = context - .elicitation(new TypeReference>() { - }, "Test message", null); + .elicit(e -> e.message("Test message").meta(null), new TypeReference>() { + }); assertThat(result).isPresent(); assertThat(result.get().structuredContent()).containsKey("items"); @@ -299,12 +301,13 @@ public void testElicitationWithTypeReference() { when(expectedResult.content()).thenReturn(contentMap); when(exchange.createElicitation(any(ElicitRequest.class))).thenReturn(expectedResult); - Optional> result = context.elicitation(new TypeReference>() { - }); + Optional>> result = context + .elicit(e -> e.message("Test message").meta(null), new TypeReference>() { + }); assertThat(result).isPresent(); - assertThat(result.get()).containsEntry("result", "success"); - assertThat(result.get()).containsEntry("data", "test value"); + assertThat(result.get().structuredContent()).containsEntry("result", "success"); + assertThat(result.get().structuredContent()).containsEntry("data", "test value"); } @Test @@ -322,7 +325,7 @@ public void testElicitationWithRequest() { when(exchange.createElicitation(elicitRequest)).thenReturn(expectedResult); - Optional result = context.elicitation(elicitRequest); + Optional result = context.elicit(elicitRequest); assertThat(result).isPresent(); assertThat(result.get()).isEqualTo(expectedResult); @@ -337,7 +340,7 @@ public void testElicitationWhenNotSupported() { .requestedSchema(Map.of("type", "string")) .build(); - Optional result = context.elicitation(elicitRequest); + Optional result = context.elicit(elicitRequest); assertThat(result).isEmpty(); } @@ -354,7 +357,7 @@ public void testSamplingWithMessages() { CreateMessageResult expectedResult = mock(CreateMessageResult.class); when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); - Optional result = context.sampling("Message 1", "Message 2"); + Optional result = context.sample("Message 1", "Message 2"); assertThat(result).isPresent(); assertThat(result.get()).isEqualTo(expectedResult); @@ -370,7 +373,7 @@ public void testSamplingWithConsumer() { CreateMessageResult expectedResult = mock(CreateMessageResult.class); when(exchange.createMessage(any(CreateMessageRequest.class))).thenReturn(expectedResult); - Optional result = context.sampling(spec -> { + Optional result = context.sample(spec -> { spec.message(new TextContent("Test message")); spec.systemPrompt("System prompt"); spec.temperature(0.7); @@ -404,7 +407,7 @@ public void testSamplingWithRequest() { when(exchange.createMessage(createRequest)).thenReturn(expectedResult); - Optional result = context.sampling(createRequest); + Optional result = context.sample(createRequest); assertThat(result).isPresent(); assertThat(result.get()).isEqualTo(expectedResult); @@ -419,7 +422,7 @@ public void testSamplingWhenNotSupported() { .maxTokens(500) .build(); - Optional result = context.sampling(createRequest); + Optional result = context.sample(createRequest); assertThat(result).isEmpty(); }