From 200658eb7dfb775ed51deabdaa34d16af4b83ebc Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Thu, 30 Oct 2025 21:23:49 +0100 Subject: [PATCH 01/11] Add ClientMcpSyncHandlersRegistry Signed-off-by: Daniel Garnier-Moiroux --- .../McpClientAutoConfiguration.java | 43 +- ...entAnnotationScannerAutoConfiguration.java | 10 + ...SpecificationFactoryAutoConfiguration.java | 66 --- .../McpSyncAnnotationCustomizer.java | 179 ------- .../McpSyncAnnotationCustomizerTests.java | 366 -------------- .../spring/ClientMcpSyncHandlersRegistry.java | 364 ++++++++++++++ .../ClientMcpSyncHandlersRegistryTests.java | 447 ++++++++++++++++++ 7 files changed, 840 insertions(+), 635 deletions(-) delete mode 100644 auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizer.java delete mode 100644 auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizerTests.java create mode 100644 mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java create mode 100644 mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java index 893c1910d70..bbcc95689f7 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java @@ -24,22 +24,15 @@ import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema; import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; +import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpAsyncAnnotationCustomizer; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpSyncAnnotationCustomizer; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpAsyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpSyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; @@ -161,7 +154,8 @@ private String connectedClientName(String clientName, String serverConnectionNam matchIfMissing = true) public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientConfigurer, McpClientCommonProperties commonProperties, - ObjectProvider> transportsProvider) { + ObjectProvider> transportsProvider, + ClientMcpSyncHandlersRegistry clientMcpSyncHandlersRegistry) { List mcpSyncClients = new ArrayList<>(); @@ -176,7 +170,22 @@ public List mcpSyncClients(McpSyncClientConfigurer mcpSyncClientC McpClient.SyncSpec spec = McpClient.sync(namedTransport.transport()) .clientInfo(clientInfo) - .requestTimeout(commonProperties.getRequestTimeout()); + .requestTimeout(commonProperties.getRequestTimeout()) + .sampling(samplingRequest -> clientMcpSyncHandlersRegistry.handleSampling(namedTransport.name(), + samplingRequest)) + .elicitation(elicitationRequest -> clientMcpSyncHandlersRegistry + .handleElicitation(namedTransport.name(), elicitationRequest)) + .loggingConsumer(loggingMessageNotification -> clientMcpSyncHandlersRegistry + .handleLogging(namedTransport.name(), loggingMessageNotification)) + .progressConsumer(progressNotification -> clientMcpSyncHandlersRegistry + .handleProgress(namedTransport.name(), progressNotification)) + .toolsChangeConsumer(newTools -> clientMcpSyncHandlersRegistry + .handleToolListChanged(namedTransport.name(), newTools)) + .promptsChangeConsumer(newPrompts -> clientMcpSyncHandlersRegistry + .handlePromptListChanged(namedTransport.name(), newPrompts)) + .resourcesChangeConsumer(newResources -> clientMcpSyncHandlersRegistry + .handleResourceListChanged(namedTransport.name(), newResources)) + .capabilities(clientMcpSyncHandlersRegistry.getCapabilities(namedTransport.name())); spec = mcpSyncClientConfigurer.configure(namedTransport.name(), spec); @@ -222,20 +231,6 @@ McpSyncClientConfigurer mcpSyncClientConfigurer(ObjectProvider loggingSpecs, - List samplingSpecs, List elicitationSpecs, - List progressSpecs, - List syncToolListChangedSpecifications, - List syncResourceListChangedSpecifications, - List syncPromptListChangedSpecifications) { - return new McpSyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, progressSpecs, - syncToolListChangedSpecifications, syncResourceListChangedSpecifications, - syncPromptListChangedSpecifications); - } - // Async client configuration @Bean diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java index 8ce05bcbe07..81b19e2ce2f 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java @@ -27,9 +27,11 @@ import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpToolListChanged; +import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor; import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanPostProcessor; import org.springframework.ai.mcp.annotation.spring.scan.AbstractMcpAnnotatedBeans; +import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; @@ -58,6 +60,14 @@ public class McpClientAnnotationScannerAutoConfiguration { McpSampling.class, McpElicitation.class, McpProgress.class, McpToolListChanged.class, McpResourceListChanged.class, McpPromptListChanged.class); + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", + matchIfMissing = true) + public ClientMcpSyncHandlersRegistry mcpHandlersRegistry() { + return new ClientMcpSyncHandlersRegistry(); + } + @Bean @ConditionalOnMissingBean public ClientMcpAnnotatedBeans clientAnnotatedBeans() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java index 620028f0e63..52b88ab740a 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java @@ -18,30 +18,16 @@ import java.util.List; -import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpLogging; -import org.springaicommunity.mcp.annotation.McpProgress; -import org.springaicommunity.mcp.annotation.McpPromptListChanged; -import org.springaicommunity.mcp.annotation.McpResourceListChanged; -import org.springaicommunity.mcp.annotation.McpSampling; -import org.springaicommunity.mcp.annotation.McpToolListChanged; import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; import org.springframework.ai.mcp.annotation.spring.AsyncMcpAnnotationProviders; -import org.springframework.ai.mcp.annotation.spring.SyncMcpAnnotationProviders; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -60,58 +46,6 @@ havingValue = "true", matchIfMissing = true) public class McpClientSpecificationFactoryAutoConfiguration { - @Configuration(proxyBeanMethods = false) - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", - matchIfMissing = true) - static class SyncClientSpecificationConfiguration { - - @Bean - List loggingSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .loggingSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpLogging.class)); - } - - @Bean - List samplingSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .samplingSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpSampling.class)); - } - - @Bean - List elicitationSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .elicitationSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpElicitation.class)); - } - - @Bean - List progressSpecs(ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders - .progressSpecifications(beansWithMcpMethodAnnotations.getBeansByAnnotation(McpProgress.class)); - } - - @Bean - List syncToolListChangedSpecs( - ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders.toolListChangedSpecifications( - beansWithMcpMethodAnnotations.getBeansByAnnotation(McpToolListChanged.class)); - } - - @Bean - List syncResourceListChangedSpecs( - ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders.resourceListChangedSpecifications( - beansWithMcpMethodAnnotations.getBeansByAnnotation(McpResourceListChanged.class)); - } - - @Bean - List syncPromptListChangedSpecs( - ClientMcpAnnotatedBeans beansWithMcpMethodAnnotations) { - return SyncMcpAnnotationProviders.promptListChangedSpecifications( - beansWithMcpMethodAnnotations.getBeansByAnnotation(McpPromptListChanged.class)); - } - - } - @Configuration(proxyBeanMethods = false) @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") static class AsyncClientSpecificationConfiguration { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizer.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizer.java deleted file mode 100644 index 69d19bfe1c0..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizer.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Stream; - -import io.modelcontextprotocol.client.McpClient.SyncSpec; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; - -import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; -import org.springframework.util.CollectionUtils; - -/** - * @author Christian Tzolov - */ -public class McpSyncAnnotationCustomizer implements McpSyncClientCustomizer { - - private static final Logger logger = LoggerFactory.getLogger(McpSyncAnnotationCustomizer.class); - - private final List syncSamplingSpecifications; - - private final List syncLoggingSpecifications; - - private final List syncElicitationSpecifications; - - private final List syncProgressSpecifications; - - private final List syncToolListChangedSpecifications; - - private final List syncResourceListChangedSpecifications; - - private final List syncPromptListChangedSpecifications; - - // Tracking registered specifications per client - private final Map clientElicitationSpecs = new ConcurrentHashMap<>(); - - private final Map clientSamplingSpecs = new ConcurrentHashMap<>(); - - public McpSyncAnnotationCustomizer(List syncSamplingSpecifications, - List syncLoggingSpecifications, - List syncElicitationSpecifications, - List syncProgressSpecifications, - List syncToolListChangedSpecifications, - List syncResourceListChangedSpecifications, - List syncPromptListChangedSpecifications) { - - this.syncSamplingSpecifications = syncSamplingSpecifications; - this.syncLoggingSpecifications = syncLoggingSpecifications; - this.syncElicitationSpecifications = syncElicitationSpecifications; - this.syncProgressSpecifications = syncProgressSpecifications; - this.syncToolListChangedSpecifications = syncToolListChangedSpecifications; - this.syncResourceListChangedSpecifications = syncResourceListChangedSpecifications; - this.syncPromptListChangedSpecifications = syncPromptListChangedSpecifications; - } - - @Override - public void customize(String name, SyncSpec clientSpec) { - - if (!CollectionUtils.isEmpty(this.syncElicitationSpecifications)) { - this.syncElicitationSpecifications.forEach(elicitationSpec -> { - Stream.of(elicitationSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - // Check if client already has an elicitation spec - if (this.clientElicitationSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has an elicitationSpec registered. Only one elicitationSpec is allowed per client."); - } - - this.clientElicitationSpecs.put(name, Boolean.TRUE); - clientSpec.elicitation(elicitationSpec.elicitationHandler()); - - logger.info("Registered elicitationSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncSamplingSpecifications)) { - this.syncSamplingSpecifications.forEach(samplingSpec -> { - Stream.of(samplingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - - // Check if client already has a sampling spec - if (this.clientSamplingSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has a samplingSpec registered. Only one samplingSpec is allowed per client."); - } - this.clientSamplingSpecs.put(name, Boolean.TRUE); - - clientSpec.sampling(samplingSpec.samplingHandler()); - - logger.info("Registered samplingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncLoggingSpecifications)) { - this.syncLoggingSpecifications.forEach(loggingSpec -> { - Stream.of(loggingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.loggingConsumer(loggingSpec.loggingHandler()); - logger.info("Registered loggingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncProgressSpecifications)) { - this.syncProgressSpecifications.forEach(progressSpec -> { - Stream.of(progressSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.progressConsumer(progressSpec.progressHandler()); - logger.info("Registered progressSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncToolListChangedSpecifications)) { - this.syncToolListChangedSpecifications.forEach(toolListChangedSpec -> { - Stream.of(toolListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.toolsChangeConsumer(toolListChangedSpec.toolListChangeHandler()); - logger.info("Registered toolListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncResourceListChangedSpecifications)) { - this.syncResourceListChangedSpecifications.forEach(resourceListChangedSpec -> { - Stream.of(resourceListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.resourcesChangeConsumer(resourceListChangedSpec.resourceListChangeHandler()); - logger.info("Registered resourceListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.syncPromptListChangedSpecifications)) { - this.syncPromptListChangedSpecifications.forEach(promptListChangedSpec -> { - Stream.of(promptListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.promptsChangeConsumer(promptListChangedSpec.promptListChangeHandler()); - logger.info("Registered promptListChangedSpec for client '{}'.", name); - } - }); - }); - } - } - -} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizerTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizerTests.java deleted file mode 100644 index 2e6f2f39b53..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpSyncAnnotationCustomizerTests.java +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import io.modelcontextprotocol.client.McpClient.SyncSpec; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.springaicommunity.mcp.method.changed.prompt.SyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.SyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.SyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.SyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.SyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.SyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.SyncSamplingSpecification; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -@ExtendWith(MockitoExtension.class) -class McpSyncAnnotationCustomizerTests { - - @Mock - private SyncSpec syncSpec; - - private List samplingSpecs; - - private List loggingSpecs; - - private List elicitationSpecs; - - private List progressSpecs; - - private List toolListChangedSpecs; - - private List resourceListChangedSpecs; - - private List promptListChangedSpecs; - - @BeforeEach - void setUp() { - this.samplingSpecs = new ArrayList<>(); - this.loggingSpecs = new ArrayList<>(); - this.elicitationSpecs = new ArrayList<>(); - this.progressSpecs = new ArrayList<>(); - this.toolListChangedSpecs = new ArrayList<>(); - this.resourceListChangedSpecs = new ArrayList<>(); - this.promptListChangedSpecs = new ArrayList<>(); - } - - @Test - void constructorShouldInitializeAllFields() { - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - assertThat(customizer).isNotNull(); - } - - @Test - void constructorShouldAcceptNullLists() { - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(null, null, null, null, null, null, - null); - - assertThat(customizer).isNotNull(); - } - - @Test - void customizeShouldNotRegisterAnythingWhenAllListsAreEmpty() { - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - customizer.customize("test-client", this.syncSpec); - - verifyNoInteractions(this.syncSpec); - } - - @Test - void customizeShouldNotRegisterElicitationSpecForNonMatchingClient() { - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - when(elicitationSpec.clients()).thenReturn(new String[] { "other-client" }); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - customizer.customize("test-client", this.syncSpec); - - verifyNoInteractions(this.syncSpec); - } - - @Test - void customizeShouldThrowExceptionWhenDuplicateElicitationSpecRegistered() { - SyncElicitationSpecification elicitationSpec1 = mock(SyncElicitationSpecification.class); - SyncElicitationSpecification elicitationSpec2 = mock(SyncElicitationSpecification.class); - - when(elicitationSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(elicitationSpec1.elicitationHandler()).thenReturn(request -> null); - when(elicitationSpec2.clients()).thenReturn(new String[] { "test-client" }); - // No need to stub elicitationSpec2.elicitationHandler() as exception is thrown - // before it's accessed - - this.elicitationSpecs.addAll(Arrays.asList(elicitationSpec1, elicitationSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - assertThatThrownBy(() -> customizer.customize("test-client", this.syncSpec)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has an elicitationSpec registered"); - } - - @Test - void customizeShouldThrowExceptionWhenDuplicateSamplingSpecRegistered() { - SyncSamplingSpecification samplingSpec1 = mock(SyncSamplingSpecification.class); - SyncSamplingSpecification samplingSpec2 = mock(SyncSamplingSpecification.class); - - when(samplingSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(samplingSpec1.samplingHandler()).thenReturn(request -> null); - when(samplingSpec2.clients()).thenReturn(new String[] { "test-client" }); - // No need to stub samplingSpec2.samplingHandler() as exception is thrown before - // it's accessed - - this.samplingSpecs.addAll(Arrays.asList(samplingSpec1, samplingSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - assertThatThrownBy(() -> customizer.customize("test-client", this.syncSpec)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has a samplingSpec registered"); - } - - @Test - void customizeShouldSkipSpecificationsWithNonMatchingClientIds() { - // Setup specs with different client IDs - SyncLoggingSpecification loggingSpec = mock(SyncLoggingSpecification.class); - SyncProgressSpecification progressSpec = mock(SyncProgressSpecification.class); - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - - when(loggingSpec.clients()).thenReturn(new String[] { "other-client" }); - when(progressSpec.clients()).thenReturn(new String[] { "another-client" }); - when(elicitationSpec.clients()).thenReturn(new String[] { "different-client" }); - - this.loggingSpecs.add(loggingSpec); - this.progressSpecs.add(progressSpec); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - customizer.customize("target-client", this.syncSpec); - - // None of the specifications should be registered since client IDs don't match - verifyNoInteractions(this.syncSpec); - } - - @Test - void customizeShouldAllowElicitationSpecForDifferentClients() { - SyncElicitationSpecification elicitationSpec1 = mock(SyncElicitationSpecification.class); - SyncElicitationSpecification elicitationSpec2 = mock(SyncElicitationSpecification.class); - - when(elicitationSpec1.clients()).thenReturn(new String[] { "client1" }); - when(elicitationSpec1.elicitationHandler()).thenReturn(request -> null); - when(elicitationSpec2.clients()).thenReturn(new String[] { "client2" }); - when(elicitationSpec2.elicitationHandler()).thenReturn(request -> null); - - this.elicitationSpecs.addAll(Arrays.asList(elicitationSpec1, elicitationSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception since they are for different clients - SyncSpec syncSpec1 = mock(SyncSpec.class); - customizer.customize("client1", syncSpec1); - - SyncSpec syncSpec2 = mock(SyncSpec.class); - customizer.customize("client2", syncSpec2); - - // No exception should be thrown, indicating successful registration for different - // clients - } - - @Test - void customizeShouldAllowSamplingSpecForDifferentClients() { - SyncSamplingSpecification samplingSpec1 = mock(SyncSamplingSpecification.class); - SyncSamplingSpecification samplingSpec2 = mock(SyncSamplingSpecification.class); - - when(samplingSpec1.clients()).thenReturn(new String[] { "client1" }); - when(samplingSpec1.samplingHandler()).thenReturn(request -> null); - when(samplingSpec2.clients()).thenReturn(new String[] { "client2" }); - when(samplingSpec2.samplingHandler()).thenReturn(request -> null); - - this.samplingSpecs.addAll(Arrays.asList(samplingSpec1, samplingSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception since they are for different clients - SyncSpec syncSpec1 = mock(SyncSpec.class); - customizer.customize("client1", syncSpec1); - - SyncSpec syncSpec2 = mock(SyncSpec.class); - customizer.customize("client2", syncSpec2); - - // No exception should be thrown, indicating successful registration for different - // clients - } - - @Test - void customizeShouldPreventMultipleElicitationCallsForSameClient() { - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - when(elicitationSpec.clients()).thenReturn(new String[] { "test-client" }); - when(elicitationSpec.elicitationHandler()).thenReturn(request -> null); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // First call should succeed - customizer.customize("test-client", this.syncSpec); - - // Second call should throw exception - SyncSpec syncSpec2 = mock(SyncSpec.class); - assertThatThrownBy(() -> customizer.customize("test-client", syncSpec2)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has an elicitationSpec registered"); - } - - @Test - void customizeShouldPreventMultipleSamplingCallsForSameClient() { - SyncSamplingSpecification samplingSpec = mock(SyncSamplingSpecification.class); - when(samplingSpec.clients()).thenReturn(new String[] { "test-client" }); - when(samplingSpec.samplingHandler()).thenReturn(request -> null); - this.samplingSpecs.add(samplingSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // First call should succeed - customizer.customize("test-client", this.syncSpec); - - // Second call should throw exception - SyncSpec syncSpec2 = mock(SyncSpec.class); - assertThatThrownBy(() -> customizer.customize("test-client", syncSpec2)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has a samplingSpec registered"); - } - - @Test - void customizeShouldPerformCaseInsensitiveClientIdMatching() { - SyncElicitationSpecification elicitationSpec = mock(SyncElicitationSpecification.class); - when(elicitationSpec.clients()).thenReturn(new String[] { "TEST-CLIENT" }); - when(elicitationSpec.elicitationHandler()).thenReturn(request -> null); - this.elicitationSpecs.add(elicitationSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should register elicitation spec when client ID matches case-insensitively - customizer.customize("test-client", this.syncSpec); - - // Verify that a subsequent call for the same client (case-insensitive) throws - // exception - SyncSpec syncSpec2 = mock(SyncSpec.class); - assertThatThrownBy(() -> customizer.customize("test-client", syncSpec2)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Client 'test-client' already has an elicitationSpec registered"); - } - - @Test - void customizeShouldHandleEmptyClientName() { - SyncLoggingSpecification loggingSpec = mock(SyncLoggingSpecification.class); - when(loggingSpec.clients()).thenReturn(new String[] { "" }); - when(loggingSpec.loggingHandler()).thenReturn(message -> { - }); - this.loggingSpecs.add(loggingSpec); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception when customizing for empty client name - customizer.customize("", this.syncSpec); - - } - - @Test - void customizeShouldAllowMultipleLoggingSpecsForSameClient() { - SyncLoggingSpecification loggingSpec1 = mock(SyncLoggingSpecification.class); - SyncLoggingSpecification loggingSpec2 = mock(SyncLoggingSpecification.class); - - when(loggingSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(loggingSpec1.loggingHandler()).thenReturn(message -> { - }); - when(loggingSpec2.clients()).thenReturn(new String[] { "test-client" }); - when(loggingSpec2.loggingHandler()).thenReturn(message -> { - }); - - this.loggingSpecs.addAll(Arrays.asList(loggingSpec1, loggingSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception for multiple logging specs for same client - customizer.customize("test-client", this.syncSpec); - - } - - @Test - void customizeShouldAllowMultipleProgressSpecsForSameClient() { - SyncProgressSpecification progressSpec1 = mock(SyncProgressSpecification.class); - SyncProgressSpecification progressSpec2 = mock(SyncProgressSpecification.class); - - when(progressSpec1.clients()).thenReturn(new String[] { "test-client" }); - when(progressSpec1.progressHandler()).thenReturn(notification -> { - }); - when(progressSpec2.clients()).thenReturn(new String[] { "test-client" }); - when(progressSpec2.progressHandler()).thenReturn(notification -> { - }); - - this.progressSpecs.addAll(Arrays.asList(progressSpec1, progressSpec2)); - - McpSyncAnnotationCustomizer customizer = new McpSyncAnnotationCustomizer(this.samplingSpecs, this.loggingSpecs, - this.elicitationSpecs, this.progressSpecs, this.toolListChangedSpecs, this.resourceListChangedSpecs, - this.promptListChangedSpecs); - - // Should not throw exception for multiple progress specs for same client - customizer.customize("test-client", this.syncSpec); - } - -} diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java new file mode 100644 index 00000000000..b686d8aa180 --- /dev/null +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java @@ -0,0 +1,364 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.spec.McpSchema; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Registry of methods annotated with MCP Client annotations (sampling, logging, etc.). + * All beans in the application context are scanned to find these methods automatically. + * They are then exposed by the registry by client name. + *

+ * The scanning happens in two phases: + *

+ * First, once bean definitions are available, all bean types are scanned for the presence + * of MCP annotations. In particular, this is used to prepare the result + * {@link #getCapabilities(String)}, which is then used by MCP client auto-configurations + * to configure the client capabilities without needing to instantiate the beans. + *

+ * Second, after all singleton beans have been instantiated, all annotated beans are + * scanned again, MCP handlers are created to match the annotations, and stored by client. + * + * @see McpSampling + * @see McpElicitation + * @see McpLogging + * @see McpProgress + * @see McpToolListChanged + * @see McpPromptListChanged + * @see McpResourceListChanged + * @author Daniel Garnier-Moiroux + * @since 1.1.0 + */ +public class ClientMcpSyncHandlersRegistry implements BeanFactoryPostProcessor, SmartInitializingSingleton { + + private static final Class[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class, + McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class, + McpPromptListChanged.class, McpResourceListChanged.class }; + + private final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null, + null); + + private Map capabilitiesPerClient = new HashMap<>(); + + private ConfigurableListableBeanFactory beanFactory; + + private final Set allAnnotatedBeans = new HashSet<>(); + + private final Map> samplingHandlers = new HashMap<>(); + + private final Map> elicitationHandlers = new HashMap<>(); + + private final Map>> loggingHandlers = new HashMap<>(); + + private final Map>> progressHandlers = new HashMap<>(); + + private final Map>>> toolListChangedHandlers = new HashMap<>(); + + private final Map>>> promptListChangedHandlers = new HashMap<>(); + + private final Map>>> resourceListChangedHandlers = new HashMap<>(); + + /** + * Obtain the MCP capabilities declared for a given MCP client. Capabilities are + * registered with the {@link McpSampling} and {@link McpElicitation} annotations. + */ + public McpSchema.ClientCapabilities getCapabilities(String clientName) { + return this.capabilitiesPerClient.getOrDefault(clientName, this.EMPTY_CAPABILITIES); + } + + /** + * Invoke the sampling handler for a given MCP client. + * + * @see McpSampling + */ + public McpSchema.CreateMessageResult handleSampling(String name, McpSchema.CreateMessageRequest samplingRequest) { + var handler = this.samplingHandlers.get(name); + if (handler != null) { + return handler.apply(samplingRequest); + } + // TODO: handle null + return null; + } + + /** + * Invoke the elicitation handler for a given MCP client. + * + * @see McpElicitation + */ + public McpSchema.ElicitResult handleElicitation(String name, McpSchema.ElicitRequest elicitationRequest) { + var handler = this.elicitationHandlers.get(name); + if (handler != null) { + return handler.apply(elicitationRequest); + } + // TODO: handle null + return null; + } + + /** + * Invoke all elicitation handlers for a given MCP client, sequentially. + * + * @see McpLogging + */ + public void handleLogging(String name, McpSchema.LoggingMessageNotification loggingMessageNotification) { + var consumers = this.loggingHandlers.get(name); + if (consumers == null) { + // TODO handle + return; + } + for (var consumer : consumers) { + consumer.accept(loggingMessageNotification); + } + } + + /** + * Invoke all progress handlers for a given MCP client, sequentially. + * + * @see McpProgress + */ + public void handleProgress(String name, McpSchema.ProgressNotification progressNotification) { + var consumers = this.progressHandlers.get(name); + if (consumers == null) { + // TODO handle + return; + } + for (var consumer : consumers) { + consumer.accept(progressNotification); + } + } + + /** + * Invoke all tool list changed handlers for a given MCP client, sequentially. + * + * @see McpToolListChanged + */ + public void handleToolListChanged(String name, List updatedTools) { + var consumers = this.toolListChangedHandlers.get(name); + if (consumers == null) { + // TODO handle + return; + } + for (var consumer : consumers) { + consumer.accept(updatedTools); + } + } + + /** + * Invoke all prompt list changed handlers for a given MCP client, sequentially. + * + * @see McpPromptListChanged + */ + public void handlePromptListChanged(String name, List updatedPrompts) { + var consumers = this.promptListChangedHandlers.get(name); + if (consumers == null) { + // TODO handle + return; + } + for (var consumer : consumers) { + consumer.accept(updatedPrompts); + } + } + + /** + * Invoke all resource list changed handlers for a given MCP client, sequentially. + * + * @see McpResourceListChanged + */ + public void handleResourceListChanged(String name, List updatedResources) { + var consumers = this.resourceListChangedHandlers.get(name); + if (consumers == null) { + // TODO handle + return; + } + for (var consumer : consumers) { + consumer.accept(updatedResources); + } + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + Map> elicitationClientToAnnotatedBeans = new HashMap<>(); + Map> samplingClientToAnnotatedBeans = new HashMap<>(); + for (var beanName : beanFactory.getBeanDefinitionNames()) { + var definition = beanFactory.getBeanDefinition(beanName); + var foundAnnotations = scan(definition.getResolvableType().toClass()); + if (!foundAnnotations.isEmpty()) { + this.allAnnotatedBeans.add(beanName); + } + for (var foundAnnotation : foundAnnotations) { + if (foundAnnotation instanceof McpSampling sampling) { + for (var client : sampling.clients()) { + samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + else if (foundAnnotation instanceof McpElicitation elicitation) { + for (var client : elicitation.clients()) { + elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + } + } + + for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) { + if (elicitationEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client" + .formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue()))); + } + } + for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) { + if (samplingEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client" + .formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue()))); + } + } + + Map capsPerClient = new HashMap<>(); + for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling(); + } + for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder()) + .elicitation(); + } + + this.capabilitiesPerClient = capsPerClient.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); + } + + private List scan(Class beanClass) { + List foundAnnotations = new ArrayList<>(); + + // Scan all methods in the bean class + ReflectionUtils.doWithMethods(beanClass, method -> { + for (var annotationType : CLIENT_MCP_ANNOTATIONS) { + Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType); + if (annotation != null) { + foundAnnotations.add(annotation); + } + } + }); + return foundAnnotations; + } + + @Override + public void afterSingletonsInstantiated() { + // Use a set in case multiple handlers are registered in the same bean + Map, Set> beansByAnnotation = new HashMap<>(); + for (var annotation : CLIENT_MCP_ANNOTATIONS) { + beansByAnnotation.put(annotation, new HashSet<>()); + } + + for (var beanName : this.allAnnotatedBeans) { + var bean = this.beanFactory.getBean(beanName); + var annotations = scan(bean.getClass()); + for (var annotation : annotations) { + beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean); + } + } + + var samplingSpecs = SyncMcpAnnotationProviders + .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); + for (var samplingSpec : samplingSpecs) { + for (var client : samplingSpec.clients()) { + this.samplingHandlers.put(client, samplingSpec.samplingHandler()); + } + } + + var elicitationSpecs = SyncMcpAnnotationProviders + .elicitationSpecifications(new ArrayList<>(beansByAnnotation.get(McpElicitation.class))); + for (var elicitationSpec : elicitationSpecs) { + for (var client : elicitationSpec.clients()) { + this.elicitationHandlers.put(client, elicitationSpec.elicitationHandler()); + } + } + + var loggingSpecs = SyncMcpAnnotationProviders + .loggingSpecifications(new ArrayList<>(beansByAnnotation.get(McpLogging.class))); + for (var loggingSpec : loggingSpecs) { + for (var client : loggingSpec.clients()) { + this.loggingHandlers.computeIfAbsent(client, k -> new ArrayList<>()).add(loggingSpec.loggingHandler()); + } + } + + var progressSpecs = SyncMcpAnnotationProviders + .progressSpecifications(new ArrayList<>(beansByAnnotation.get(McpProgress.class))); + for (var progressSpec : progressSpecs) { + for (var client : progressSpec.clients()) { + this.progressHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(progressSpec.progressHandler()); + } + } + + var toolsListChangedSpecs = SyncMcpAnnotationProviders + .toolListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpToolListChanged.class))); + for (var toolsListChangedSpec : toolsListChangedSpecs) { + for (var client : toolsListChangedSpec.clients()) { + this.toolListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(toolsListChangedSpec.toolListChangeHandler()); + } + } + + var promptListChangedSpecs = SyncMcpAnnotationProviders + .promptListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpPromptListChanged.class))); + for (var promptListChangedSpec : promptListChangedSpecs) { + for (var client : promptListChangedSpec.clients()) { + this.promptListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(promptListChangedSpec.promptListChangeHandler()); + } + } + + var resourceListChangedSpecs = SyncMcpAnnotationProviders + .resourceListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpResourceListChanged.class))); + for (var resourceListChangedSpec : resourceListChangedSpecs) { + for (var client : resourceListChangedSpec.clients()) { + this.resourceListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(resourceListChangedSpec.resourceListChangeHandler()); + } + } + + } + +} diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java new file mode 100644 index 00000000000..8986434b3fb --- /dev/null +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java @@ -0,0 +1,447 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; + +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; + +class ClientMcpSyncHandlersRegistryTests { + + @Test + void getCapabilitiesPerClient() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(ClientCapabilitiesConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-2").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-3").elicitation()).isNotNull(); + + assertThat(registry.getCapabilities("client-1").sampling()).isNotNull(); + assertThat(registry.getCapabilities("client-2").sampling()).isNull(); + assertThat(registry.getCapabilities("client-3").sampling()).isNull(); + + assertThat(registry.getCapabilities("client-1").roots()).isNull(); + assertThat(registry.getCapabilities("client-2").roots()).isNull(); + assertThat(registry.getCapabilities("client-3").roots()).isNull(); + + assertThat(registry.getCapabilities("client-1").experimental()).isNull(); + assertThat(registry.getCapabilities("client-2").experimental()).isNull(); + assertThat(registry.getCapabilities("client-3").experimental()).isNull(); + + assertThat(registry.getCapabilities("client-unknown").sampling()).isNull(); + assertThat(registry.getCapabilities("client-unknown").elicitation()).isNull(); + assertThat(registry.getCapabilities("client-unknown").roots()).isNull(); + } + + @Test + void twoHandlersElicitation() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanElicitation() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("elicitationConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [elicitationConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSampling() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanSampling() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("samplingConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [samplingConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void elicitation() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + var response = registry.handleElicitation("client-1", request); + + assertThat(response.content()).hasSize(1).containsEntry("message", "Elicit request"); + assertThat(response.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + } + + @Test + void sampling() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + var response = registry.handleSampling("client-1", request); + + assertThat(response.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(response.model()).isEqualTo("testgpt-42.5"); + McpSchema.TextContent content = (McpSchema.TextContent) response.content(); + assertThat(content.text()).isEqualTo("Tell a joke"); + } + + @Test + void logging() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var logRequest = McpSchema.LoggingMessageNotification.builder() + .data("Hello world") + .logger("log-me") + .level(McpSchema.LoggingLevel.INFO) + .build(); + + registry.handleLogging("client-1", logRequest); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleLoggingMessage", logRequest), + new HandlersConfiguration.Call("handleLoggingMessageAgain", logRequest)); + } + + @Test + void progress() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var progressRequest = new McpSchema.ProgressNotification("progress-12345", 13.37, 100., "progressing ..."); + + registry.handleProgress("client-1", progressRequest); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleProgress", progressRequest), + new HandlersConfiguration.Call("handleProgressAgain", progressRequest)); + } + + @Test + void toolListChanged() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + + registry.handleToolListChanged("client-1", updatedTools); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleToolListChanged", updatedTools), + new HandlersConfiguration.Call("handleToolListChangedAgain", updatedTools)); + } + + @Test + void promptListChanged() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedTools = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + + registry.handlePromptListChanged("client-1", updatedTools); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handlePromptListChanged", updatedTools), + new HandlersConfiguration.Call("handlePromptListChangedAgain", updatedTools)); + } + + @Test + void resourceListChanged() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleResourceListChanged("client-1", updatedResources); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleResourceListChanged", updatedResources), + new HandlersConfiguration.Call("handleResourceListChangedAgain", updatedResources)); + } + + @Test + @Disabled + void missingHandler() { + fail("TODO"); + } + + static class ClientCapabilitiesConfiguration { + + @McpElicitation(clients = { "client-1", "client-2" }) + public McpSchema.ElicitResult elicitationHandler1(McpSchema.ElicitRequest request) { + return null; + } + + @McpElicitation(clients = { "client-3" }) + public McpSchema.ElicitResult elicitationHandler2(McpSchema.ElicitRequest request) { + return null; + } + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + static class DoubleElicitationHandlerConfiguration { + + static class First { + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler1(McpSchema.ElicitRequest request) { + return null; + } + + } + + static class Second { + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler2(McpSchema.ElicitRequest request) { + return null; + } + + } + + static class TwoHandlers { + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler1(McpSchema.ElicitRequest request) { + return null; + } + + @McpElicitation(clients = { "client-1" }) + public McpSchema.ElicitResult elicitationHandler2(McpSchema.ElicitRequest request) { + return null; + } + + } + + } + + static class DoubleSamplingHandlerConfiguration { + + static class First { + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler1(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + static class Second { + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler2(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + static class TwoHandlers { + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler1(McpSchema.CreateMessageRequest request) { + return null; + } + + @McpSampling(clients = { "client-1" }) + public McpSchema.CreateMessageResult samplingHandler2(McpSchema.CreateMessageRequest request) { + return null; + } + + } + + } + + static class HandlersConfiguration { + + private final List calls = new ArrayList<>(); + + HandlersConfiguration() { + } + + List getCalls() { + return Collections.unmodifiableList(this.calls); + } + + @McpElicitation(clients = { "client-1" }) + McpSchema.ElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + return McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build(); + } + + @McpSampling(clients = { "client-1" }) + McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest request) { + return McpSchema.CreateMessageResult.builder() + .message(((McpSchema.TextContent) request.messages().get(0).content()).text()) + .model("testgpt-42.5") + .build(); + } + + @McpLogging(clients = { "client-1" }) + void handleLoggingMessage(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessage", notification)); + } + + @McpLogging(clients = { "client-1" }) + void handleLoggingMessageAgain(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessageAgain", notification)); + } + + @McpProgress(clients = { "client-1" }) + void handleProgress(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgress", notification)); + } + + @McpProgress(clients = { "client-1" }) + void handleProgressAgain(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgressAgain", notification)); + } + + @McpToolListChanged(clients = { "client-1" }) + void handleToolListChanged(List updatedTools) { + this.calls.add(new Call("handleToolListChanged", updatedTools)); + } + + @McpToolListChanged(clients = { "client-1" }) + void handleToolListChangedAgain(List updatedTools) { + this.calls.add(new Call("handleToolListChangedAgain", updatedTools)); + } + + @McpPromptListChanged(clients = { "client-1" }) + void handlePromptListChanged(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChanged", updatedPrompts)); + } + + @McpPromptListChanged(clients = { "client-1" }) + void handlePromptListChangedAgain(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChangedAgain", updatedPrompts)); + } + + @McpResourceListChanged(clients = { "client-1" }) + void handleResourceListChanged(List updatedResources) { + this.calls.add(new Call("handleResourceListChanged", updatedResources)); + } + + @McpResourceListChanged(clients = { "client-1" }) + void handleResourceListChangedAgain(List updatedResources) { + this.calls.add(new Call("handleResourceListChangedAgain", updatedResources)); + } + + // Record calls made to this object + record Call(String name, Object callRequest) { + } + + } + +} From c573cb0f137dcd2d4c400748e8fa2ffff08564a4 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Tue, 4 Nov 2025 19:26:07 +0100 Subject: [PATCH 02/11] Add ClientMcpAsyncHandlersRegistry Signed-off-by: Daniel Garnier-Moiroux --- .../McpClientAutoConfiguration.java | 41 +- .../McpAsyncAnnotationCustomizer.java | 181 ------- ...entAnnotationScannerAutoConfiguration.java | 10 +- ...SpecificationFactoryAutoConfiguration.java | 91 ---- ...ot.autoconfigure.AutoConfiguration.imports | 1 - ...lientListChangedAnnotationsScanningIT.java | 3 +- .../McpToolsConfigurationTests.java | 2 - .../StreamableMcpAnnotations2IT.java | 4 +- .../StreamableMcpAnnotationsIT.java | 4 +- .../StreamableMcpAnnotationsManualIT.java | 3 +- .../StreamableMcpAnnotationsWithLLMIT.java | 5 +- .../ClientMcpAsyncHandlersRegistry.java | 356 ++++++++++++++ .../ClientMcpAsyncHandlersRegistryTests.java | 459 ++++++++++++++++++ 13 files changed, 849 insertions(+), 311 deletions(-) delete mode 100644 auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpAsyncAnnotationCustomizer.java delete mode 100644 auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java create mode 100644 mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java create mode 100644 mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java index bbcc95689f7..7d2f33fbf1c 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfiguration.java @@ -23,16 +23,9 @@ import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema; -import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; +import org.springframework.ai.mcp.annotation.spring.ClientMcpAsyncHandlersRegistry; import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpAsyncAnnotationCustomizer; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpAsyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpSyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; @@ -237,7 +230,8 @@ McpSyncClientConfigurer mcpSyncClientConfigurer(ObjectProvider mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncClientConfigurer, McpClientCommonProperties commonProperties, - ObjectProvider> transportsProvider) { + ObjectProvider> transportsProvider, + ClientMcpAsyncHandlersRegistry clientMcpAsyncHandlersRegistry) { List mcpAsyncClients = new ArrayList<>(); @@ -252,7 +246,22 @@ public List mcpAsyncClients(McpAsyncClientConfigurer mcpAsyncCli McpClient.AsyncSpec spec = McpClient.async(namedTransport.transport()) .clientInfo(clientInfo) - .requestTimeout(commonProperties.getRequestTimeout()); + .requestTimeout(commonProperties.getRequestTimeout()) + .sampling(samplingRequest -> clientMcpAsyncHandlersRegistry.handleSampling(namedTransport.name(), + samplingRequest)) + .elicitation(elicitationRequest -> clientMcpAsyncHandlersRegistry + .handleElicitation(namedTransport.name(), elicitationRequest)) + .loggingConsumer(loggingMessageNotification -> clientMcpAsyncHandlersRegistry + .handleLogging(namedTransport.name(), loggingMessageNotification)) + .progressConsumer(progressNotification -> clientMcpAsyncHandlersRegistry + .handleProgress(namedTransport.name(), progressNotification)) + .toolsChangeConsumer(newTools -> clientMcpAsyncHandlersRegistry + .handleToolListChanged(namedTransport.name(), newTools)) + .promptsChangeConsumer(newPrompts -> clientMcpAsyncHandlersRegistry + .handlePromptListChanged(namedTransport.name(), newPrompts)) + .resourcesChangeConsumer(newResources -> clientMcpAsyncHandlersRegistry + .handleResourceListChanged(namedTransport.name(), newResources)) + .capabilities(clientMcpAsyncHandlersRegistry.getCapabilities(namedTransport.name())); spec = mcpAsyncClientConfigurer.configure(namedTransport.name(), spec); @@ -282,18 +291,6 @@ McpAsyncClientConfigurer mcpAsyncClientConfigurer(ObjectProvider loggingSpecs, - List samplingSpecs, List elicitationSpecs, - List progressSpecs, - List toolListChangedSpecs, - List resourceListChangedSpecs, - List promptListChangedSpecs) { - return new McpAsyncAnnotationCustomizer(samplingSpecs, loggingSpecs, elicitationSpecs, progressSpecs, - toolListChangedSpecs, resourceListChangedSpecs, promptListChangedSpecs); - } - /** * Record class that implements {@link AutoCloseable} to ensure proper cleanup of MCP * clients. diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpAsyncAnnotationCustomizer.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpAsyncAnnotationCustomizer.java deleted file mode 100644 index 292942a2d63..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpAsyncAnnotationCustomizer.java +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Stream; - -import io.modelcontextprotocol.client.McpClient.AsyncSpec; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; - -import org.springframework.ai.mcp.customizer.McpAsyncClientCustomizer; -import org.springframework.util.CollectionUtils; - -/** - * @author Christian Tzolov - */ -public class McpAsyncAnnotationCustomizer implements McpAsyncClientCustomizer { - - private static final Logger logger = LoggerFactory.getLogger(McpAsyncAnnotationCustomizer.class); - - private final List asyncSamplingSpecifications; - - private final List asyncLoggingSpecifications; - - private final List asyncElicitationSpecifications; - - private final List asyncProgressSpecifications; - - private final List asyncToolListChangedSpecifications; - - private final List asyncResourceListChangedSpecifications; - - private final List asyncPromptListChangedSpecifications; - - // Tracking registered specifications per client - private final Map clientElicitationSpecs = new ConcurrentHashMap<>(); - - private final Map clientSamplingSpecs = new ConcurrentHashMap<>(); - - public McpAsyncAnnotationCustomizer(List asyncSamplingSpecifications, - List asyncLoggingSpecifications, - List asyncElicitationSpecifications, - List asyncProgressSpecifications, - List asyncToolListChangedSpecifications, - List asyncResourceListChangedSpecifications, - List asyncPromptListChangedSpecifications) { - - this.asyncSamplingSpecifications = asyncSamplingSpecifications; - this.asyncLoggingSpecifications = asyncLoggingSpecifications; - this.asyncElicitationSpecifications = asyncElicitationSpecifications; - this.asyncProgressSpecifications = asyncProgressSpecifications; - this.asyncToolListChangedSpecifications = asyncToolListChangedSpecifications; - this.asyncResourceListChangedSpecifications = asyncResourceListChangedSpecifications; - this.asyncPromptListChangedSpecifications = asyncPromptListChangedSpecifications; - } - - @Override - public void customize(String name, AsyncSpec clientSpec) { - - if (!CollectionUtils.isEmpty(this.asyncElicitationSpecifications)) { - this.asyncElicitationSpecifications.forEach(elicitationSpec -> { - Stream.of(elicitationSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - - // Check if client already has an elicitation spec - if (this.clientElicitationSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has an elicitationSpec registered. Only one elicitationSpec is allowed per client."); - } - - this.clientElicitationSpecs.put(name, Boolean.TRUE); - clientSpec.elicitation(elicitationSpec.elicitationHandler()); - - logger.info("Registered elicitationSpec for client '{}'.", name); - - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncSamplingSpecifications)) { - this.asyncSamplingSpecifications.forEach(samplingSpec -> { - Stream.of(samplingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - - // Check if client already has a sampling spec - if (this.clientSamplingSpecs.containsKey(name)) { - throw new IllegalArgumentException("Client '" + name - + "' already has a samplingSpec registered. Only one samplingSpec is allowed per client."); - } - this.clientSamplingSpecs.put(name, Boolean.TRUE); - - clientSpec.sampling(samplingSpec.samplingHandler()); - - logger.info("Registered samplingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncLoggingSpecifications)) { - this.asyncLoggingSpecifications.forEach(loggingSpec -> { - Stream.of(loggingSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.loggingConsumer(loggingSpec.loggingHandler()); - logger.info("Registered loggingSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncProgressSpecifications)) { - this.asyncProgressSpecifications.forEach(progressSpec -> { - Stream.of(progressSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.progressConsumer(progressSpec.progressHandler()); - logger.info("Registered progressSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncToolListChangedSpecifications)) { - this.asyncToolListChangedSpecifications.forEach(toolListChangedSpec -> { - Stream.of(toolListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.toolsChangeConsumer(toolListChangedSpec.toolListChangeHandler()); - logger.info("Registered toolListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncResourceListChangedSpecifications)) { - this.asyncResourceListChangedSpecifications.forEach(resourceListChangedSpec -> { - Stream.of(resourceListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.resourcesChangeConsumer(resourceListChangedSpec.resourceListChangeHandler()); - logger.info("Registered resourceListChangedSpec for client '{}'.", name); - } - }); - }); - } - - if (!CollectionUtils.isEmpty(this.asyncPromptListChangedSpecifications)) { - this.asyncPromptListChangedSpecifications.forEach(promptListChangedSpec -> { - Stream.of(promptListChangedSpec.clients()).forEach(clientId -> { - if (clientId.equalsIgnoreCase(name)) { - clientSpec.promptsChangeConsumer(promptListChangedSpec.promptListChangeHandler()); - logger.info("Registered promptListChangedSpec for client '{}'.", name); - } - }); - }); - } - } - -} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java index 81b19e2ce2f..47ad28101b0 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java @@ -27,6 +27,7 @@ import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpToolListChanged; +import org.springframework.ai.mcp.annotation.spring.ClientMcpAsyncHandlersRegistry; import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor; import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanPostProcessor; @@ -64,10 +65,17 @@ public class McpClientAnnotationScannerAutoConfiguration { @ConditionalOnMissingBean @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) - public ClientMcpSyncHandlersRegistry mcpHandlersRegistry() { + public ClientMcpSyncHandlersRegistry clientMcpSyncHandlersRegistry() { return new ClientMcpSyncHandlersRegistry(); } + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") + public ClientMcpAsyncHandlersRegistry clientMcpAsyncHandlersRegistry() { + return new ClientMcpAsyncHandlersRegistry(); + } + @Bean @ConditionalOnMissingBean public ClientMcpAnnotatedBeans clientAnnotatedBeans() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java deleted file mode 100644 index 52b88ab740a..00000000000 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientSpecificationFactoryAutoConfiguration.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.mcp.client.common.autoconfigure.annotations; - -import java.util.List; - -import org.springaicommunity.mcp.annotation.McpLogging; -import org.springaicommunity.mcp.method.changed.prompt.AsyncPromptListChangedSpecification; -import org.springaicommunity.mcp.method.changed.resource.AsyncResourceListChangedSpecification; -import org.springaicommunity.mcp.method.changed.tool.AsyncToolListChangedSpecification; -import org.springaicommunity.mcp.method.elicitation.AsyncElicitationSpecification; -import org.springaicommunity.mcp.method.logging.AsyncLoggingSpecification; -import org.springaicommunity.mcp.method.progress.AsyncProgressSpecification; -import org.springaicommunity.mcp.method.sampling.AsyncSamplingSpecification; - -import org.springframework.ai.mcp.annotation.spring.AsyncMcpAnnotationProviders; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans; -import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; -import org.springframework.boot.autoconfigure.AutoConfiguration; -import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -/** - * @author Christian Tzolov - * @author Fu Jian - */ -@AutoConfiguration(after = McpClientAnnotationScannerAutoConfiguration.class) -@ConditionalOnClass(McpLogging.class) -@ConditionalOnProperty(prefix = McpClientAnnotationScannerProperties.CONFIG_PREFIX, name = "enabled", - havingValue = "true", matchIfMissing = true) -public class McpClientSpecificationFactoryAutoConfiguration { - - @Configuration(proxyBeanMethods = false) - @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") - static class AsyncClientSpecificationConfiguration { - - @Bean - List loggingSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.loggingSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List samplingSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.samplingSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List elicitationSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.elicitationSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List progressSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.progressSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List asyncToolListChangedSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.toolListChangedSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List asyncResourceListChangedSpecs( - ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.resourceListChangedSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - @Bean - List asyncPromptListChangedSpecs(ClientMcpAnnotatedBeans beanRegistry) { - return AsyncMcpAnnotationProviders.promptListChangedSpecifications(beanRegistry.getAllAnnotatedBeans()); - } - - } - -} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 120dd1beab9..38cd4021d5c 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -17,5 +17,4 @@ org.springframework.ai.mcp.client.common.autoconfigure.StdioTransportAutoConfiguration org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration -org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java index d00e3cc6b35..04bd0cf7e73 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java @@ -45,8 +45,7 @@ public class McpClientListChangedAnnotationsScanningIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(McpClientAnnotationScannerAutoConfiguration.class)); @ParameterizedTest @ValueSource(strings = { "SYNC", "ASYNC" }) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java index 674f2663a5b..2dc90ca4d88 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java @@ -31,7 +31,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.model.chat.client.autoconfigure.ChatClientAutoConfiguration; import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; import org.springframework.ai.tool.ToolCallback; @@ -74,7 +73,6 @@ void mcpClientSupportsSampling() { McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class, // Tool callbacks ToolCallingAutoConfiguration.class, // Chat client for sampling diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java index 35f1d67937e..a7a56ba9e5a 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotations2IT.java @@ -72,7 +72,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -109,8 +108,7 @@ public class StreamableMcpAnnotations2IT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, - McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class)); + McpClientAnnotationScannerAutoConfiguration.class)); @Test void clientServerCapabilities() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java index 4a0da8b3ac7..cd34bac36c6 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsIT.java @@ -73,7 +73,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -110,8 +109,7 @@ public class StreamableMcpAnnotationsIT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, - McpClientAnnotationScannerAutoConfiguration.class, - McpClientSpecificationFactoryAutoConfiguration.class)); + McpClientAnnotationScannerAutoConfiguration.class)); @Test void clientServerCapabilities() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java index 1f6c2490267..9d83b94e6df 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsManualIT.java @@ -75,7 +75,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -117,7 +116,7 @@ public class StreamableMcpAnnotationsManualIT { .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, // MCP Annotations - McpClientAnnotationScannerAutoConfiguration.class, McpClientSpecificationFactoryAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, // Anthropic ChatClient Builder AnthropicChatAutoConfiguration.class, ChatClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class)); diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java index 9403b2e0bf4..3799a52ddb9 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java @@ -51,7 +51,6 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; -import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -98,8 +97,8 @@ public class StreamableMcpAnnotationsWithLLMIT { .withPropertyValues("spring.ai.anthropic.apiKey=" + System.getenv("ANTHROPIC_API_KEY")) .withConfiguration(anthropicAutoConfig(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class, - McpClientAnnotationScannerAutoConfiguration.class, McpClientSpecificationFactoryAutoConfiguration.class, - AnthropicChatAutoConfiguration.class, ChatClientAutoConfiguration.class)); + McpClientAnnotationScannerAutoConfiguration.class, AnthropicChatAutoConfiguration.class, + ChatClientAutoConfiguration.class)); private static AutoConfigurations anthropicAutoConfig(Class... additional) { Class[] dependencies = { SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class, diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java new file mode 100644 index 00000000000..540ad367edb --- /dev/null +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java @@ -0,0 +1,356 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.spec.McpSchema; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.SmartInitializingSingleton; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Registry of methods annotated with MCP Client annotations (sampling, logging, etc.). + * All beans in the application context are scanned to find these methods automatically. + * They are then exposed by the registry by client name. + *

+ * The scanning happens in two phases: + *

+ * First, once bean definitions are available, all bean types are scanned for the presence + * of MCP annotations. In particular, this is used to prepare the result + * {@link #getCapabilities(String)}, which is then used by MCP client auto-configurations + * to configure the client capabilities without needing to instantiate the beans. + *

+ * Second, after all singleton beans have been instantiated, all annotated beans are + * scanned again, MCP handlers are created to match the annotations, and stored by client. + * + * @see McpSampling + * @see McpElicitation + * @see McpLogging + * @see McpProgress + * @see McpToolListChanged + * @see McpPromptListChanged + * @see McpResourceListChanged + * @author Daniel Garnier-Moiroux + * @since 1.1.0 + */ +public class ClientMcpAsyncHandlersRegistry implements BeanFactoryPostProcessor, SmartInitializingSingleton { + + private static final Class[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class, + McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class, + McpPromptListChanged.class, McpResourceListChanged.class }; + + private final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null, + null); + + private Map capabilitiesPerClient = new HashMap<>(); + + private ConfigurableListableBeanFactory beanFactory; + + private final Set allAnnotatedBeans = new HashSet<>(); + + private final Map>> samplingHandlers = new HashMap<>(); + + private final Map>> elicitationHandlers = new HashMap<>(); + + private final Map>>> loggingHandlers = new HashMap<>(); + + private final Map>>> progressHandlers = new HashMap<>(); + + private final Map, Mono>>> toolListChangedHandlers = new HashMap<>(); + + private final Map, Mono>>> promptListChangedHandlers = new HashMap<>(); + + private final Map, Mono>>> resourceListChangedHandlers = new HashMap<>(); + + /** + * Obtain the MCP capabilities declared for a given MCP client. Capabilities are + * registered with the {@link McpSampling} and {@link McpElicitation} annotations. + */ + public McpSchema.ClientCapabilities getCapabilities(String clientName) { + return this.capabilitiesPerClient.getOrDefault(clientName, this.EMPTY_CAPABILITIES); + } + + /** + * Invoke the sampling handler for a given MCP client. + * + * @see McpSampling + */ + public Mono handleSampling(String name, + McpSchema.CreateMessageRequest samplingRequest) { + var handler = this.samplingHandlers.get(name); + if (handler != null) { + return handler.apply(samplingRequest); + } + // TODO: handle null + return Mono.empty(); + } + + /** + * Invoke the elicitation handler for a given MCP client. + * + * @see McpElicitation + */ + public Mono handleElicitation(String name, McpSchema.ElicitRequest elicitationRequest) { + var handler = this.elicitationHandlers.get(name); + if (handler != null) { + return handler.apply(elicitationRequest); + } + // TODO: handle null + return Mono.empty(); + } + + /** + * Invoke all elicitation handlers for a given MCP client, sequentially. + * + * @see McpLogging + */ + public Mono handleLogging(String name, McpSchema.LoggingMessageNotification loggingMessageNotification) { + var consumers = this.loggingHandlers.get(name); + if (consumers == null) { + // TODO handle + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(loggingMessageNotification)).then(); + } + + /** + * Invoke all progress handlers for a given MCP client, sequentially. + * + * @see McpProgress + */ + public Mono handleProgress(String name, McpSchema.ProgressNotification progressNotification) { + var consumers = this.progressHandlers.get(name); + if (consumers == null) { + // TODO handle + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(progressNotification)).then(); + } + + /** + * Invoke all tool list changed handlers for a given MCP client, sequentially. + * + * @see McpToolListChanged + */ + public Mono handleToolListChanged(String name, List updatedTools) { + var consumers = this.toolListChangedHandlers.get(name); + if (consumers == null) { + // TODO handle + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedTools)).then(); + } + + /** + * Invoke all prompt list changed handlers for a given MCP client, sequentially. + * + * @see McpPromptListChanged + */ + public Mono handlePromptListChanged(String name, List updatedPrompts) { + var consumers = this.promptListChangedHandlers.get(name); + if (consumers == null) { + // TODO handle + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedPrompts)).then(); + } + + /** + * Invoke all resource list changed handlers for a given MCP client, sequentially. + * + * @see McpResourceListChanged + */ + public Mono handleResourceListChanged(String name, List updatedResources) { + var consumers = this.resourceListChangedHandlers.get(name); + if (consumers == null) { + // TODO handle + return Mono.empty(); + } + return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedResources)).then(); + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + Map> elicitationClientToAnnotatedBeans = new HashMap<>(); + Map> samplingClientToAnnotatedBeans = new HashMap<>(); + for (var beanName : beanFactory.getBeanDefinitionNames()) { + var definition = beanFactory.getBeanDefinition(beanName); + var foundAnnotations = scan(definition.getResolvableType().toClass()); + if (!foundAnnotations.isEmpty()) { + this.allAnnotatedBeans.add(beanName); + } + for (var foundAnnotation : foundAnnotations) { + if (foundAnnotation instanceof McpSampling sampling) { + for (var client : sampling.clients()) { + samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + else if (foundAnnotation instanceof McpElicitation elicitation) { + for (var client : elicitation.clients()) { + elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + } + } + + for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) { + if (elicitationEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client" + .formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue()))); + } + } + for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) { + if (samplingEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client" + .formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue()))); + } + } + + Map capsPerClient = new HashMap<>(); + for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling(); + } + for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder()) + .elicitation(); + } + + this.capabilitiesPerClient = capsPerClient.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); + } + + private List scan(Class beanClass) { + List foundAnnotations = new ArrayList<>(); + + // Scan all methods in the bean class + ReflectionUtils.doWithMethods(beanClass, method -> { + for (var annotationType : CLIENT_MCP_ANNOTATIONS) { + Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType); + if (annotation != null) { + foundAnnotations.add(annotation); + } + } + }); + return foundAnnotations; + } + + @Override + public void afterSingletonsInstantiated() { + // Use a set in case multiple handlers are registered in the same bean + Map, Set> beansByAnnotation = new HashMap<>(); + for (var annotation : CLIENT_MCP_ANNOTATIONS) { + beansByAnnotation.put(annotation, new HashSet<>()); + } + + for (var beanName : this.allAnnotatedBeans) { + var bean = this.beanFactory.getBean(beanName); + var annotations = scan(bean.getClass()); + for (var annotation : annotations) { + beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean); + } + } + + var samplingSpecs = AsyncMcpAnnotationProviders + .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); + for (var samplingSpec : samplingSpecs) { + for (var client : samplingSpec.clients()) { + this.samplingHandlers.put(client, samplingSpec.samplingHandler()); + } + } + + var elicitationSpecs = AsyncMcpAnnotationProviders + .elicitationSpecifications(new ArrayList<>(beansByAnnotation.get(McpElicitation.class))); + for (var elicitationSpec : elicitationSpecs) { + for (var client : elicitationSpec.clients()) { + this.elicitationHandlers.put(client, elicitationSpec.elicitationHandler()); + } + } + + var loggingSpecs = AsyncMcpAnnotationProviders + .loggingSpecifications(new ArrayList<>(beansByAnnotation.get(McpLogging.class))); + for (var loggingSpec : loggingSpecs) { + for (var client : loggingSpec.clients()) { + this.loggingHandlers.computeIfAbsent(client, k -> new ArrayList<>()).add(loggingSpec.loggingHandler()); + } + } + + var progressSpecs = AsyncMcpAnnotationProviders + .progressSpecifications(new ArrayList<>(beansByAnnotation.get(McpProgress.class))); + for (var progressSpec : progressSpecs) { + for (var client : progressSpec.clients()) { + this.progressHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(progressSpec.progressHandler()); + } + } + + var toolsListChangedSpecs = AsyncMcpAnnotationProviders + .toolListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpToolListChanged.class))); + for (var toolsListChangedSpec : toolsListChangedSpecs) { + for (var client : toolsListChangedSpec.clients()) { + this.toolListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(toolsListChangedSpec.toolListChangeHandler()); + } + } + + var promptListChangedSpecs = AsyncMcpAnnotationProviders + .promptListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpPromptListChanged.class))); + for (var promptListChangedSpec : promptListChangedSpecs) { + for (var client : promptListChangedSpec.clients()) { + this.promptListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(promptListChangedSpec.promptListChangeHandler()); + } + } + + var resourceListChangedSpecs = AsyncMcpAnnotationProviders + .resourceListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpResourceListChanged.class))); + for (var resourceListChangedSpec : resourceListChangedSpecs) { + for (var client : resourceListChangedSpec.clients()) { + this.resourceListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) + .add(resourceListChangedSpec.resourceListChangeHandler()); + } + } + + } + +} diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java new file mode 100644 index 00000000000..635a7d06605 --- /dev/null +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java @@ -0,0 +1,459 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.annotation.spring; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; + +class ClientMcpAsyncHandlersRegistryTests { + + @Test + void getCapabilitiesPerClient() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(ClientCapabilitiesConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-2").elicitation()).isNotNull(); + assertThat(registry.getCapabilities("client-3").elicitation()).isNotNull(); + + assertThat(registry.getCapabilities("client-1").sampling()).isNotNull(); + assertThat(registry.getCapabilities("client-2").sampling()).isNull(); + assertThat(registry.getCapabilities("client-3").sampling()).isNull(); + + assertThat(registry.getCapabilities("client-1").roots()).isNull(); + assertThat(registry.getCapabilities("client-2").roots()).isNull(); + assertThat(registry.getCapabilities("client-3").roots()).isNull(); + + assertThat(registry.getCapabilities("client-1").experimental()).isNull(); + assertThat(registry.getCapabilities("client-2").experimental()).isNull(); + assertThat(registry.getCapabilities("client-3").experimental()).isNull(); + + assertThat(registry.getCapabilities("client-unknown").sampling()).isNull(); + assertThat(registry.getCapabilities("client-unknown").elicitation()).isNull(); + assertThat(registry.getCapabilities("client-unknown").roots()).isNull(); + } + + @Test + void twoHandlersElicitation() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanElicitation() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("elicitationConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleElicitationHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 elicitation handlers for client [client-1], found in bean with names [elicitationConfig]. Only one @McpElicitation handler is allowed per client"); + } + + @Test + void twoHandlersSampling() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("firstConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.First.class) + .getBeanDefinition()); + beanFactory.registerBeanDefinition("secondConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.Second.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [firstConfig, secondConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void twoHandlersSameBeanSampling() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("samplingConfig", + BeanDefinitionBuilder.genericBeanDefinition(DoubleSamplingHandlerConfiguration.TwoHandlers.class) + .getBeanDefinition()); + assertThatThrownBy(() -> registry.postProcessBeanFactory(beanFactory)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Found 2 sampling handlers for client [client-1], found in bean with names [samplingConfig]. Only one @McpSampling handler is allowed per client"); + } + + @Test + void elicitation() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + var response = registry.handleElicitation("client-1", request).block(); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1).containsEntry("message", "Elicit request"); + assertThat(response.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + } + + @Test + void sampling() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + var response = registry.handleSampling("client-1", request).block(); + + assertThat(response.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(response.model()).isEqualTo("testgpt-42.5"); + McpSchema.TextContent content = (McpSchema.TextContent) response.content(); + assertThat(content.text()).isEqualTo("Tell a joke"); + } + + @Test + void logging() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var logRequest = McpSchema.LoggingMessageNotification.builder() + .data("Hello world") + .logger("log-me") + .level(McpSchema.LoggingLevel.INFO) + .build(); + + registry.handleLogging("client-1", logRequest).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleLoggingMessage", logRequest), + new HandlersConfiguration.Call("handleLoggingMessageAgain", logRequest)); + } + + @Test + void progress() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + var progressRequest = new McpSchema.ProgressNotification("progress-12345", 13.37, 100., "progressing ..."); + + registry.handleProgress("client-1", progressRequest).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleProgress", progressRequest), + new HandlersConfiguration.Call("handleProgressAgain", progressRequest)); + } + + @Test + void toolListChanged() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + + registry.handleToolListChanged("client-1", updatedTools).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleToolListChanged", updatedTools), + new HandlersConfiguration.Call("handleToolListChangedAgain", updatedTools)); + } + + @Test + void promptListChanged() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedTools = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + + registry.handlePromptListChanged("client-1", updatedTools).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handlePromptListChanged", updatedTools), + new HandlersConfiguration.Call("handlePromptListChangedAgain", updatedTools)); + } + + @Test + void resourceListChanged() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + var handlers = beanFactory.getBean(HandlersConfiguration.class); + + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleResourceListChanged("client-1", updatedResources).block(); + assertThat(handlers.getCalls()).hasSize(2) + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handleResourceListChanged", updatedResources), + new HandlersConfiguration.Call("handleResourceListChangedAgain", updatedResources)); + } + + @Test + @Disabled + void missingHandler() { + fail("TODO"); + } + + static class ClientCapabilitiesConfiguration { + + @McpElicitation(clients = { "client-1", "client-2" }) + public Mono elicitationHandler1(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + @McpElicitation(clients = { "client-3" }) + public Mono elicitationHandler2(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + static class DoubleElicitationHandlerConfiguration { + + static class First { + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler1(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + } + + static class Second { + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler2(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + } + + static class TwoHandlers { + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler1(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + @McpElicitation(clients = { "client-1" }) + public Mono elicitationHandler2(McpSchema.ElicitRequest request) { + return Mono.empty(); + } + + } + + } + + static class DoubleSamplingHandlerConfiguration { + + static class First { + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler1(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + static class Second { + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler2(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + static class TwoHandlers { + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler1(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + @McpSampling(clients = { "client-1" }) + public Mono samplingHandler2(McpSchema.CreateMessageRequest request) { + return Mono.empty(); + } + + } + + } + + static class HandlersConfiguration { + + private final List calls = new ArrayList<>(); + + HandlersConfiguration() { + } + + List getCalls() { + return Collections.unmodifiableList(this.calls); + } + + @McpElicitation(clients = { "client-1" }) + Mono elicitationHandler(McpSchema.ElicitRequest request) { + return Mono.just(McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .content(Map.of("message", request.message())) + .build()); + } + + @McpSampling(clients = { "client-1" }) + Mono samplingHandler(McpSchema.CreateMessageRequest request) { + return Mono.just(McpSchema.CreateMessageResult.builder() + .message(((McpSchema.TextContent) request.messages().get(0).content()).text()) + .model("testgpt-42.5") + .build()); + } + + @McpLogging(clients = { "client-1" }) + Mono handleLoggingMessage(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessage", notification)); + return Mono.empty(); + } + + @McpLogging(clients = { "client-1" }) + Mono handleLoggingMessageAgain(McpSchema.LoggingMessageNotification notification) { + this.calls.add(new Call("handleLoggingMessageAgain", notification)); + return Mono.empty(); + } + + @McpProgress(clients = { "client-1" }) + Mono handleProgress(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgress", notification)); + return Mono.empty(); + } + + @McpProgress(clients = { "client-1" }) + Mono handleProgressAgain(McpSchema.ProgressNotification notification) { + this.calls.add(new Call("handleProgressAgain", notification)); + return Mono.empty(); + } + + @McpToolListChanged(clients = { "client-1" }) + Mono handleToolListChanged(List updatedTools) { + this.calls.add(new Call("handleToolListChanged", updatedTools)); + return Mono.empty(); + } + + @McpToolListChanged(clients = { "client-1" }) + Mono handleToolListChangedAgain(List updatedTools) { + this.calls.add(new Call("handleToolListChangedAgain", updatedTools)); + return Mono.empty(); + } + + @McpPromptListChanged(clients = { "client-1" }) + Mono handlePromptListChanged(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChanged", updatedPrompts)); + return Mono.empty(); + } + + @McpPromptListChanged(clients = { "client-1" }) + Mono handlePromptListChangedAgain(List updatedPrompts) { + this.calls.add(new Call("handlePromptListChangedAgain", updatedPrompts)); + return Mono.empty(); + } + + @McpResourceListChanged(clients = { "client-1" }) + Mono handleResourceListChanged(List updatedResources) { + this.calls.add(new Call("handleResourceListChanged", updatedResources)); + return Mono.empty(); + } + + @McpResourceListChanged(clients = { "client-1" }) + Mono handleResourceListChangedAgain(List updatedResources) { + this.calls.add(new Call("handleResourceListChangedAgain", updatedResources)); + return Mono.empty(); + } + + // Record calls made to this object + record Call(String name, Object callRequest) { + } + + } + +} From 23a05b08036e5dc7691567f5da7e5aa2f02ea564 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Wed, 5 Nov 2025 07:57:03 +0100 Subject: [PATCH 03/11] Simplify ToolCallbackResolver auto-configuration. - In febf86c, we broke a dependency cycle ChatClient -> McpClient - With the introduction of ClientMcpSyncHandlersRegistry and the async variant, there is no dependency McpClient -> MCP handlers anymore, breaking the cycle in a simpler way. - Here, we revert most of the changes of febf86c, but keep the tests. Signed-off-by: Daniel Garnier-Moiroux --- .../McpToolsConfigurationTests.java | 43 ++++----------- .../ToolCallingAutoConfiguration.java | 55 +------------------ 2 files changed, 14 insertions(+), 84 deletions(-) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java index 2dc90ca4d88..b9db603a0bc 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java @@ -120,26 +120,7 @@ void toolCallbacksRegistered() { assertThat(resolver.resolve("customToolCallbackProvider")).isNotNull(); // MCP toolcallback providers are never added to the resolver - - // Bean graph setup - var injectedProviders = (List) ctx.getBean( - "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded"); - // Beans exposed as non-MCP - var toolCallbackProvider = (ToolCallbackProvider) ctx.getBean("toolCallbackProvider"); - var customToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customToolCallbackProvider"); - // This is injected in the resolver bean, because it's exposed as a - // ToolCallbackProvider, but it's not added to the resolver - var genericMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("genericMcpToolCallbackProvider"); - - // beans exposed as MCP - var mcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("mcpToolCallbackProvider"); - var customMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customMcpToolCallbackProvider"); - - assertThat(injectedProviders) - .containsExactlyInAnyOrder(toolCallbackProvider, customToolCallbackProvider, - genericMcpToolCallbackProvider) - .doesNotContain(mcpToolCallbackProvider, customMcpToolCallbackProvider); - + // Otherwise, they would throw. }); } @@ -192,29 +173,27 @@ ToolCallbackProvider toolCallbackProvider() { return tcp; } - // This bean depends on the resolver, to ensure there are no cyclic dependencies @Bean - SyncMcpToolCallbackProvider mcpToolCallbackProvider(ToolCallbackResolver resolver) { + CustomToolCallbackProvider customToolCallbackProvider() { + return new CustomToolCallbackProvider("customToolCallbackProvider"); + } + + // Ignored by the resolver + @Bean + SyncMcpToolCallbackProvider mcpToolCallbackProvider() { var tcp = mock(SyncMcpToolCallbackProvider.class); when(tcp.getToolCallbacks()) .thenThrow(new RuntimeException("mcpToolCallbackProvider#getToolCallbacks should not be called")); return tcp; } + // Ignored by the resolver @Bean - CustomToolCallbackProvider customToolCallbackProvider() { - return new CustomToolCallbackProvider("customToolCallbackProvider"); - } - - // This bean depends on the resolver, to ensure there are no cyclic dependencies - @Bean - CustomMcpToolCallbackProvider customMcpToolCallbackProvider(ToolCallbackResolver resolver) { + CustomMcpToolCallbackProvider customMcpToolCallbackProvider() { return new CustomMcpToolCallbackProvider(); } - // This will be added to the resolver, because the visible type of the bean - // is ToolCallbackProvider ; we would need to actually instantiate the bean - // to find out that it is MCP-related + // Ignored by the resolver @Bean ToolCallbackProvider genericMcpToolCallbackProvider() { return new CustomMcpToolCallbackProvider(); diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java index e5d6699cd69..bcdad6b2bf5 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java @@ -17,7 +17,6 @@ package org.springframework.ai.model.tool.autoconfigure; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import io.micrometer.observation.ObservationRegistry; @@ -36,14 +35,7 @@ import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; -import org.springframework.beans.BeansException; import org.springframework.beans.factory.ObjectProvider; -import org.springframework.beans.factory.annotation.Qualifier; -import org.springframework.beans.factory.config.BeanDefinition; -import org.springframework.beans.factory.support.BeanDefinitionBuilder; -import org.springframework.beans.factory.support.BeanDefinitionRegistry; -import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; -import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -65,26 +57,20 @@ @AutoConfiguration @ConditionalOnClass(ChatModel.class) @EnableConfigurationProperties(ToolCallingProperties.class) -public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostProcessor { +public class ToolCallingAutoConfiguration { private static final Logger logger = LoggerFactory.getLogger(ToolCallingAutoConfiguration.class); - // Marker qualifier to exclude MCP-related ToolCallbackProviders - private static final String EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER = "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded"; - /** * The default {@link ToolCallbackResolver} resolves tools by name for methods, * functions, and {@link ToolCallbackProvider} beans. *

- * MCP providers should not be injected to avoid cyclic dependencies. If some MCP - * providers are injected, we filter them out to avoid eagerly calling - * #getToolCallbacks. + * MCP providers are excluded, to avoid initializing them early with #listTools(). */ @Bean @ConditionalOnMissingBean ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - List toolCallbacks, - @Qualifier(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER) List tcbProviders) { + List toolCallbacks, List tcbProviders) { List allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks); tcbProviders.stream() .filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr))) @@ -100,41 +86,6 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); } - /** - * Wrap {@link ToolCallbackProvider} beans that are not MCP-related into a named bean, - * which will be picked up by the - * {@link ToolCallingAutoConfiguration#toolCallbackResolver}. - *

- * MCP providers must be excluded, because they may depend on a {@code ChatClient} to - * do sampling. The chat client, in turn, depends on a {@link ToolCallbackResolver}. - * To do the detection, we depend on the exposed bean type. If a bean uses a factory - * method which returns a {@link ToolCallbackProvider}, which is an MCP provider under - * the hood, it will be included in the list. - */ - @Override - public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { - if (!(registry instanceof DefaultListableBeanFactory beanFactory)) { - return; - } - - var excludeMcpToolCallbackProviderBeanDefinition = BeanDefinitionBuilder - .genericBeanDefinition(List.class, () -> { - var providerNames = beanFactory.getBeanNamesForType(ToolCallbackProvider.class); - return Arrays.stream(providerNames) - .filter(name -> !isMcpToolCallbackProvider(beanFactory.getBeanDefinition(name).getResolvableType())) - .map(beanFactory::getBean) - .filter(ToolCallbackProvider.class::isInstance) - .map(ToolCallbackProvider.class::cast) - .toList(); - }) - .setScope(BeanDefinition.SCOPE_SINGLETON) - .setLazyInit(true) - .getBeanDefinition(); - - registry.registerBeanDefinition(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER, - excludeMcpToolCallbackProviderBeanDefinition); - } - private static boolean isMcpToolCallbackProvider(ResolvableType type) { if (type.getType().getTypeName().equals("org.springframework.ai.mcp.SyncMcpToolCallbackProvider") || type.getType().getTypeName().equals("org.springframework.ai.mcp.AsyncMcpToolCallbackProvider")) { From a3aedbbd80659fbcd93c5cfc1d4fcfacd250ddf0 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Wed, 5 Nov 2025 08:17:08 +0100 Subject: [PATCH 04/11] Remove unused MCP annotated beans auto-configuration Signed-off-by: Daniel Garnier-Moiroux --- ...entAnnotationScannerAutoConfiguration.java | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java index 47ad28101b0..449c2a9da4b 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientAnnotationScannerAutoConfiguration.java @@ -30,7 +30,6 @@ import org.springframework.ai.mcp.annotation.spring.ClientMcpAsyncHandlersRegistry; import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanFactoryInitializationAotProcessor; -import org.springframework.ai.mcp.annotation.spring.scan.AbstractAnnotatedMethodBeanPostProcessor; import org.springframework.ai.mcp.annotation.spring.scan.AbstractMcpAnnotatedBeans; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.aot.hint.MemberCategory; @@ -76,19 +75,6 @@ public ClientMcpAsyncHandlersRegistry clientMcpAsyncHandlersRegistry() { return new ClientMcpAsyncHandlersRegistry(); } - @Bean - @ConditionalOnMissingBean - public ClientMcpAnnotatedBeans clientAnnotatedBeans() { - return new ClientMcpAnnotatedBeans(); - } - - @Bean - @ConditionalOnMissingBean - public static ClientAnnotatedMethodBeanPostProcessor clientAnnotatedMethodBeanPostProcessor( - ClientMcpAnnotatedBeans clientMcpAnnotatedBeans, McpClientAnnotationScannerProperties properties) { - return new ClientAnnotatedMethodBeanPostProcessor(clientMcpAnnotatedBeans, CLIENT_MCP_ANNOTATIONS); - } - @Bean static ClientAnnotatedBeanFactoryInitializationAotProcessor clientAnnotatedBeanFactoryInitializationAotProcessor() { return new ClientAnnotatedBeanFactoryInitializationAotProcessor(CLIENT_MCP_ANNOTATIONS); @@ -108,15 +94,6 @@ public ClientAnnotatedBeanFactoryInitializationAotProcessor( } - public static class ClientAnnotatedMethodBeanPostProcessor extends AbstractAnnotatedMethodBeanPostProcessor { - - public ClientAnnotatedMethodBeanPostProcessor(ClientMcpAnnotatedBeans clientMcpAnnotatedBeans, - Set> targetAnnotations) { - super(clientMcpAnnotatedBeans, targetAnnotations); - } - - } - static class AnnotationHints implements RuntimeHintsRegistrar { @Override From 598d8fb39d2a8c692508411578756a2f94c005ef Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Wed, 5 Nov 2025 17:49:40 +0100 Subject: [PATCH 05/11] Introduce AbstractClientMcpHandlerRegistry Signed-off-by: Daniel Garnier-Moiroux --- .../AbstractClientMcpHandlerRegistry.java | 135 ++++++++++++++++++ .../ClientMcpAsyncHandlersRegistry.java | 111 +------------- .../spring/ClientMcpSyncHandlersRegistry.java | 111 +------------- 3 files changed, 143 insertions(+), 214 deletions(-) create mode 100644 mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java new file mode 100644 index 00000000000..c662536173c --- /dev/null +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java @@ -0,0 +1,135 @@ +package org.springframework.ai.mcp.annotation.spring; + +import java.lang.annotation.Annotation; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.modelcontextprotocol.spec.McpSchema; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpPromptListChanged; +import org.springaicommunity.mcp.annotation.McpResourceListChanged; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpToolListChanged; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.util.ReflectionUtils; + +/** + * Base class for sync and async ClientMcpHandlerRegistries. Not intended for public use. + * + * @see ClientMcpAsyncHandlersRegistry + * @see ClientMcpSyncHandlersRegistry + */ +abstract class AbstractClientMcpHandlerRegistry implements BeanFactoryPostProcessor { + + protected Map capabilitiesPerClient = new HashMap<>(); + + protected ConfigurableListableBeanFactory beanFactory; + + protected final Set allAnnotatedBeans = new HashSet<>(); + + static final Class[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class, + McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class, + McpPromptListChanged.class, McpResourceListChanged.class }; + + static final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null, + null); + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + this.beanFactory = beanFactory; + Map> elicitationClientToAnnotatedBeans = new HashMap<>(); + Map> samplingClientToAnnotatedBeans = new HashMap<>(); + for (var beanName : beanFactory.getBeanDefinitionNames()) { + var definition = beanFactory.getBeanDefinition(beanName); + var foundAnnotations = scan(definition.getResolvableType().toClass()); + if (!foundAnnotations.isEmpty()) { + this.allAnnotatedBeans.add(beanName); + } + for (var foundAnnotation : foundAnnotations) { + if (foundAnnotation instanceof McpSampling sampling) { + for (var client : sampling.clients()) { + samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + else if (foundAnnotation instanceof McpElicitation elicitation) { + for (var client : elicitation.clients()) { + elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); + } + } + } + } + + for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) { + if (elicitationEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client" + .formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue()))); + } + } + for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) { + if (samplingEntry.getValue().size() > 1) { + throw new IllegalArgumentException( + "Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client" + .formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue()))); + } + } + + Map capsPerClient = new HashMap<>(); + for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling(); + } + for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) { + capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder()) + .elicitation(); + } + + this.capabilitiesPerClient = capsPerClient.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); + } + + protected List scan(Class beanClass) { + List foundAnnotations = new ArrayList<>(); + + // Scan all methods in the bean class + ReflectionUtils.doWithMethods(beanClass, method -> { + for (var annotationType : CLIENT_MCP_ANNOTATIONS) { + Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType); + if (annotation != null) { + foundAnnotations.add(annotation); + } + } + }); + return foundAnnotations; + } + + protected Map, Set> getBeansByAnnotationType() { + // Use a set in case multiple handlers are registered in the same bean + Map, Set> beansByAnnotation = new HashMap<>(); + for (var annotation : CLIENT_MCP_ANNOTATIONS) { + beansByAnnotation.put(annotation, new HashSet<>()); + } + + for (var beanName : this.allAnnotatedBeans) { + var bean = this.beanFactory.getBean(beanName); + var annotations = scan(bean.getClass()); + for (var annotation : annotations) { + beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean); + } + } + return beansByAnnotation; + } + +} diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java index 540ad367edb..768acce647f 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java @@ -16,16 +16,11 @@ package org.springframework.ai.mcp.annotation.spring; -import java.lang.annotation.Annotation; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import io.modelcontextprotocol.spec.McpSchema; import org.springaicommunity.mcp.annotation.McpElicitation; @@ -38,12 +33,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.beans.BeansException; import org.springframework.beans.factory.SmartInitializingSingleton; -import org.springframework.beans.factory.config.BeanFactoryPostProcessor; -import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; -import org.springframework.core.annotation.AnnotationUtils; -import org.springframework.util.ReflectionUtils; /** * Registry of methods annotated with MCP Client annotations (sampling, logging, etc.). @@ -70,20 +60,8 @@ * @author Daniel Garnier-Moiroux * @since 1.1.0 */ -public class ClientMcpAsyncHandlersRegistry implements BeanFactoryPostProcessor, SmartInitializingSingleton { - - private static final Class[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class, - McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class, - McpPromptListChanged.class, McpResourceListChanged.class }; - - private final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null, - null); - - private Map capabilitiesPerClient = new HashMap<>(); - - private ConfigurableListableBeanFactory beanFactory; - - private final Set allAnnotatedBeans = new HashSet<>(); +public class ClientMcpAsyncHandlersRegistry extends AbstractClientMcpHandlerRegistry + implements SmartInitializingSingleton { private final Map>> samplingHandlers = new HashMap<>(); @@ -104,7 +82,7 @@ public class ClientMcpAsyncHandlersRegistry implements BeanFactoryPostProcessor, * registered with the {@link McpSampling} and {@link McpElicitation} annotations. */ public McpSchema.ClientCapabilities getCapabilities(String clientName) { - return this.capabilitiesPerClient.getOrDefault(clientName, this.EMPTY_CAPABILITIES); + return this.capabilitiesPerClient.getOrDefault(clientName, EMPTY_CAPABILITIES); } /** @@ -206,90 +184,9 @@ public Mono handleResourceListChanged(String name, List c.apply(updatedResources)).then(); } - @Override - public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { - this.beanFactory = beanFactory; - Map> elicitationClientToAnnotatedBeans = new HashMap<>(); - Map> samplingClientToAnnotatedBeans = new HashMap<>(); - for (var beanName : beanFactory.getBeanDefinitionNames()) { - var definition = beanFactory.getBeanDefinition(beanName); - var foundAnnotations = scan(definition.getResolvableType().toClass()); - if (!foundAnnotations.isEmpty()) { - this.allAnnotatedBeans.add(beanName); - } - for (var foundAnnotation : foundAnnotations) { - if (foundAnnotation instanceof McpSampling sampling) { - for (var client : sampling.clients()) { - samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); - } - } - else if (foundAnnotation instanceof McpElicitation elicitation) { - for (var client : elicitation.clients()) { - elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); - } - } - } - } - - for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) { - if (elicitationEntry.getValue().size() > 1) { - throw new IllegalArgumentException( - "Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client" - .formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue()))); - } - } - for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) { - if (samplingEntry.getValue().size() > 1) { - throw new IllegalArgumentException( - "Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client" - .formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue()))); - } - } - - Map capsPerClient = new HashMap<>(); - for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) { - capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling(); - } - for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) { - capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder()) - .elicitation(); - } - - this.capabilitiesPerClient = capsPerClient.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); - } - - private List scan(Class beanClass) { - List foundAnnotations = new ArrayList<>(); - - // Scan all methods in the bean class - ReflectionUtils.doWithMethods(beanClass, method -> { - for (var annotationType : CLIENT_MCP_ANNOTATIONS) { - Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType); - if (annotation != null) { - foundAnnotations.add(annotation); - } - } - }); - return foundAnnotations; - } - @Override public void afterSingletonsInstantiated() { - // Use a set in case multiple handlers are registered in the same bean - Map, Set> beansByAnnotation = new HashMap<>(); - for (var annotation : CLIENT_MCP_ANNOTATIONS) { - beansByAnnotation.put(annotation, new HashSet<>()); - } - - for (var beanName : this.allAnnotatedBeans) { - var bean = this.beanFactory.getBean(beanName); - var annotations = scan(bean.getClass()); - for (var annotation : annotations) { - beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean); - } - } + var beansByAnnotation = getBeansByAnnotationType(); var samplingSpecs = AsyncMcpAnnotationProviders .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java index b686d8aa180..1f2e0cfa559 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java @@ -16,17 +16,12 @@ package org.springframework.ai.mcp.annotation.spring; -import java.lang.annotation.Annotation; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; -import java.util.stream.Collectors; import io.modelcontextprotocol.spec.McpSchema; import org.springaicommunity.mcp.annotation.McpElicitation; @@ -37,12 +32,7 @@ import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpToolListChanged; -import org.springframework.beans.BeansException; import org.springframework.beans.factory.SmartInitializingSingleton; -import org.springframework.beans.factory.config.BeanFactoryPostProcessor; -import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; -import org.springframework.core.annotation.AnnotationUtils; -import org.springframework.util.ReflectionUtils; /** * Registry of methods annotated with MCP Client annotations (sampling, logging, etc.). @@ -69,20 +59,8 @@ * @author Daniel Garnier-Moiroux * @since 1.1.0 */ -public class ClientMcpSyncHandlersRegistry implements BeanFactoryPostProcessor, SmartInitializingSingleton { - - private static final Class[] CLIENT_MCP_ANNOTATIONS = new Class[] { McpSampling.class, - McpElicitation.class, McpLogging.class, McpProgress.class, McpToolListChanged.class, - McpPromptListChanged.class, McpResourceListChanged.class }; - - private final McpSchema.ClientCapabilities EMPTY_CAPABILITIES = new McpSchema.ClientCapabilities(null, null, null, - null); - - private Map capabilitiesPerClient = new HashMap<>(); - - private ConfigurableListableBeanFactory beanFactory; - - private final Set allAnnotatedBeans = new HashSet<>(); +public class ClientMcpSyncHandlersRegistry extends AbstractClientMcpHandlerRegistry + implements SmartInitializingSingleton { private final Map> samplingHandlers = new HashMap<>(); @@ -103,7 +81,7 @@ public class ClientMcpSyncHandlersRegistry implements BeanFactoryPostProcessor, * registered with the {@link McpSampling} and {@link McpElicitation} annotations. */ public McpSchema.ClientCapabilities getCapabilities(String clientName) { - return this.capabilitiesPerClient.getOrDefault(clientName, this.EMPTY_CAPABILITIES); + return this.capabilitiesPerClient.getOrDefault(clientName, EMPTY_CAPABILITIES); } /** @@ -214,90 +192,9 @@ public void handleResourceListChanged(String name, List upda } } - @Override - public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { - this.beanFactory = beanFactory; - Map> elicitationClientToAnnotatedBeans = new HashMap<>(); - Map> samplingClientToAnnotatedBeans = new HashMap<>(); - for (var beanName : beanFactory.getBeanDefinitionNames()) { - var definition = beanFactory.getBeanDefinition(beanName); - var foundAnnotations = scan(definition.getResolvableType().toClass()); - if (!foundAnnotations.isEmpty()) { - this.allAnnotatedBeans.add(beanName); - } - for (var foundAnnotation : foundAnnotations) { - if (foundAnnotation instanceof McpSampling sampling) { - for (var client : sampling.clients()) { - samplingClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); - } - } - else if (foundAnnotation instanceof McpElicitation elicitation) { - for (var client : elicitation.clients()) { - elicitationClientToAnnotatedBeans.computeIfAbsent(client, c -> new ArrayList<>()).add(beanName); - } - } - } - } - - for (var elicitationEntry : elicitationClientToAnnotatedBeans.entrySet()) { - if (elicitationEntry.getValue().size() > 1) { - throw new IllegalArgumentException( - "Found 2 elicitation handlers for client [%s], found in bean with names %s. Only one @McpElicitation handler is allowed per client" - .formatted(elicitationEntry.getKey(), new LinkedHashSet<>(elicitationEntry.getValue()))); - } - } - for (var samplingEntry : samplingClientToAnnotatedBeans.entrySet()) { - if (samplingEntry.getValue().size() > 1) { - throw new IllegalArgumentException( - "Found 2 sampling handlers for client [%s], found in bean with names %s. Only one @McpSampling handler is allowed per client" - .formatted(samplingEntry.getKey(), new LinkedHashSet<>(samplingEntry.getValue()))); - } - } - - Map capsPerClient = new HashMap<>(); - for (var samplingClient : samplingClientToAnnotatedBeans.keySet()) { - capsPerClient.computeIfAbsent(samplingClient, ignored -> McpSchema.ClientCapabilities.builder()).sampling(); - } - for (var elicitationClient : elicitationClientToAnnotatedBeans.keySet()) { - capsPerClient.computeIfAbsent(elicitationClient, ignored -> McpSchema.ClientCapabilities.builder()) - .elicitation(); - } - - this.capabilitiesPerClient = capsPerClient.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); - } - - private List scan(Class beanClass) { - List foundAnnotations = new ArrayList<>(); - - // Scan all methods in the bean class - ReflectionUtils.doWithMethods(beanClass, method -> { - for (var annotationType : CLIENT_MCP_ANNOTATIONS) { - Annotation annotation = AnnotationUtils.findAnnotation(method, annotationType); - if (annotation != null) { - foundAnnotations.add(annotation); - } - } - }); - return foundAnnotations; - } - @Override public void afterSingletonsInstantiated() { - // Use a set in case multiple handlers are registered in the same bean - Map, Set> beansByAnnotation = new HashMap<>(); - for (var annotation : CLIENT_MCP_ANNOTATIONS) { - beansByAnnotation.put(annotation, new HashSet<>()); - } - - for (var beanName : this.allAnnotatedBeans) { - var bean = this.beanFactory.getBean(beanName); - var annotations = scan(bean.getClass()); - for (var annotation : annotations) { - beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean); - } - } + var beansByAnnotation = getBeansByAnnotationType(); var samplingSpecs = SyncMcpAnnotationProviders .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); From fb9f4a004ab598f4b8b9678cd6a5800f6f86a164 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Wed, 5 Nov 2025 17:56:25 +0100 Subject: [PATCH 06/11] Find MCP Client annotations on @Component beans Signed-off-by: Daniel Garnier-Moiroux --- .../StreamableMcpAnnotationsWithLLMIT.java | 40 ++---------- .../capabilities/McpHandlerService.java | 64 +++++++++++++++++++ .../AbstractClientMcpHandlerRegistry.java | 37 ++++++++++- .../ClientMcpAsyncHandlersRegistryTests.java | 14 ++++ .../ClientMcpSyncHandlersRegistryTests.java | 12 ++++ 5 files changed, 132 insertions(+), 35 deletions(-) create mode 100644 auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java index 3799a52ddb9..7358d772ded 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java @@ -36,14 +36,11 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpLogging; import org.springaicommunity.mcp.annotation.McpProgress; -import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpTool; import org.springaicommunity.mcp.annotation.McpToolParam; import org.springaicommunity.mcp.context.McpSyncRequestContext; -import org.springaicommunity.mcp.context.StructuredElicitResult; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; @@ -52,6 +49,7 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; +import org.springframework.ai.mcp.server.autoconfigure.capabilities.McpHandlerService; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.ToolCallbackConverterAutoConfiguration; @@ -70,6 +68,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.test.util.TestSocketUtils; @@ -220,9 +219,6 @@ private static void stopHttpServer(DisposableServer server) { } } - record ElicitInput(String message) { - } - public static class TestMcpServerConfiguration { @Bean @@ -244,7 +240,7 @@ public String weather(McpSyncRequestContext ctx, @McpToolParam String cityName) ctx.ping(); // call client ping // call elicitation - var elicitationResult = ctx.elicit(e -> e.message("Test message"), ElicitInput.class); + var elicitationResult = ctx.elicit(e -> e.message("Test message"), McpHandlerService.ElicitInput.class); ctx.progress(p -> p.progress(0.50).total(1.0).message("elicitation completed")); @@ -285,18 +281,16 @@ public static class TestContext { } + // We also include scanned beans, because those are registered differently. + @ComponentScan(basePackageClasses = McpHandlerService.class) public static class TestMcpClientHandlers { private static final Logger logger = LoggerFactory.getLogger(TestMcpClientHandlers.class); - private final ChatClient client; - private TestMcpClientConfiguration.TestContext testContext; - public TestMcpClientHandlers(TestMcpClientConfiguration.TestContext testContext, - ChatClient.Builder clientBuilder) { + public TestMcpClientHandlers(TestMcpClientConfiguration.TestContext testContext) { this.testContext = testContext; - this.client = clientBuilder.build(); } @McpProgress(clients = "server1") @@ -313,28 +307,6 @@ public void loggingHandler(McpSchema.LoggingMessageNotification loggingMessage) logger.info("MCP LOGGING: [{}] {}", loggingMessage.level(), loggingMessage.data()); } - @McpSampling(clients = "server1") - public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest llmRequest) { - logger.info("MCP SAMPLING: {}", llmRequest); - - String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); - String modelHint = llmRequest.modelPreferences().hints().get(0).name(); - // In a real use-case, we would use the chat client to call the LLM again - logger.info("MCP SAMPLING: simulating using chat client {}", this.client); - - return McpSchema.CreateMessageResult.builder() - .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) - .build(); - } - - @McpElicitation(clients = "server1") - public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { - logger.info("MCP ELICITATION: {}", request); - StreamableMcpAnnotationsWithLLMIT.ElicitInput elicitData = new StreamableMcpAnnotationsWithLLMIT.ElicitInput( - request.message()); - return StructuredElicitResult.builder().structuredContent(elicitData).build(); - } - } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java new file mode 100644 index 00000000000..987d29ec317 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java @@ -0,0 +1,64 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.autoconfigure.capabilities; + +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.context.StructuredElicitResult; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.stereotype.Service; + +@Service +public class McpHandlerService { + + private static final Logger logger = LoggerFactory.getLogger(McpHandlerService.class); + + private final ChatClient client; + + public McpHandlerService(ChatClient.Builder chatClientBuilder) { + this.client = chatClientBuilder.build(); + } + + @McpSampling(clients = "server1") + public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest llmRequest) { + logger.info("MCP SAMPLING: {}", llmRequest); + + String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); + String modelHint = llmRequest.modelPreferences().hints().get(0).name(); + // In a real use-case, we would use the chat client to call the LLM again + logger.info("MCP SAMPLING: simulating using chat client {}", this.client); + + return McpSchema.CreateMessageResult.builder() + .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) + .build(); + } + + @McpElicitation(clients = "server1") + public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + logger.info("MCP ELICITATION: {}", request); + ElicitInput elicitData = new ElicitInput(request.message()); + return StructuredElicitResult.builder().structuredContent(elicitData).build(); + } + + public record ElicitInput(String message) { + } + +} diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java index c662536173c..97126531f50 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java @@ -1,3 +1,19 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.mcp.annotation.spring; import java.lang.annotation.Annotation; @@ -20,6 +36,7 @@ import org.springaicommunity.mcp.annotation.McpToolListChanged; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.core.annotation.AnnotationUtils; @@ -28,6 +45,7 @@ /** * Base class for sync and async ClientMcpHandlerRegistries. Not intended for public use. * + * @author Daniel Garnier-Moiroux * @see ClientMcpAsyncHandlersRegistry * @see ClientMcpSyncHandlersRegistry */ @@ -53,7 +71,7 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) Map> samplingClientToAnnotatedBeans = new HashMap<>(); for (var beanName : beanFactory.getBeanDefinitionNames()) { var definition = beanFactory.getBeanDefinition(beanName); - var foundAnnotations = scan(definition.getResolvableType().toClass()); + var foundAnnotations = scan(getBeanClass(definition, beanFactory.getBeanClassLoader())); if (!foundAnnotations.isEmpty()) { this.allAnnotatedBeans.add(beanName); } @@ -100,6 +118,23 @@ else if (foundAnnotation instanceof McpElicitation elicitation) { .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); } + private static Class getBeanClass(BeanDefinition definition, ClassLoader beanClassLoader) { + if (definition.getResolvableType().resolve() != null) { + return definition.getResolvableType().resolve(); + } + // @Component beans registered by component scanning do not have a resolvable type + // We try to resolve them using the beanClassName (which might be null) + if (beanClassLoader != null && definition.getBeanClassName() != null) { + try { + return Class.forName(definition.getBeanClassName(), false, beanClassLoader); + } + catch (ClassNotFoundException ignored) { + + } + } + return null; + } + protected List scan(Class beanClass) { List foundAnnotations = new ArrayList<>(); diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java index 635a7d06605..64c9e5b8489 100644 --- a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java @@ -266,6 +266,20 @@ void resourceListChanged() { new HandlersConfiguration.Call("handleResourceListChangedAgain", updatedResources)); } + @Test + void supportsNonResolvableTypes() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder + .genericBeanDefinition( + ClientMcpSyncHandlersRegistryTests.ClientCapabilitiesConfiguration.class.getName()) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + @Test @Disabled void missingHandler() { diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java index 8986434b3fb..e7e9527adfd 100644 --- a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java @@ -264,6 +264,18 @@ void resourceListChanged() { new HandlersConfiguration.Call("handleResourceListChangedAgain", updatedResources)); } + @Test + void supportsNonResolvableTypes() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(ClientCapabilitiesConfiguration.class.getName()) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + @Test @Disabled void missingHandler() { From 3bbe4571f163a968edf2d005bcae07c63bdd2907 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Thu, 6 Nov 2025 10:06:24 +0100 Subject: [PATCH 07/11] AbstractClientMcpHandlerRegistry also discovers proxied beans - Remove custom class resolution method and use AutoProxyUtils instead Signed-off-by: Daniel Garnier-Moiroux --- .../StreamableMcpAnnotationsWithLLMIT.java | 4 +- .../capabilities/McpHandlerConfiguration.java | 77 +++++++++++++++++++ .../capabilities/McpHandlerService.java | 12 --- .../AbstractClientMcpHandlerRegistry.java | 28 ++----- .../ClientMcpAsyncHandlersRegistry.java | 2 +- .../spring/ClientMcpSyncHandlersRegistry.java | 2 +- .../ClientMcpAsyncHandlersRegistryTests.java | 15 ++++ .../ClientMcpSyncHandlersRegistryTests.java | 15 ++++ 8 files changed, 119 insertions(+), 36 deletions(-) create mode 100644 auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerConfiguration.java diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java index 7358d772ded..5b224a154b7 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java @@ -49,6 +49,7 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; +import org.springframework.ai.mcp.server.autoconfigure.capabilities.McpHandlerConfiguration; import org.springframework.ai.mcp.server.autoconfigure.capabilities.McpHandlerService; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -240,7 +241,8 @@ public String weather(McpSyncRequestContext ctx, @McpToolParam String cityName) ctx.ping(); // call client ping // call elicitation - var elicitationResult = ctx.elicit(e -> e.message("Test message"), McpHandlerService.ElicitInput.class); + var elicitationResult = ctx.elicit(e -> e.message("Test message"), + McpHandlerConfiguration.ElicitInput.class); ctx.progress(p -> p.progress(0.50).total(1.0).message("elicitation completed")); diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerConfiguration.java new file mode 100644 index 00000000000..1b85779c9a3 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerConfiguration.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.autoconfigure.capabilities; + +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.context.StructuredElicitResult; + +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; +import org.springframework.web.context.annotation.RequestScope; + +@Configuration +public class McpHandlerConfiguration { + + private static final Logger logger = LoggerFactory.getLogger(McpHandlerConfiguration.class); + + @Bean + ElicitationHandler elicitationHandler() { + return new ElicitationHandler(); + } + + // Ensure that we don't blow up on non-singleton beans + @Bean + @Scope(scopeName = ConfigurableBeanFactory.SCOPE_PROTOTYPE) + Foo foo() { + return new Foo(); + } + + // Ensure that we don't blow up on non-singleton beans + @Bean + @RequestScope + Bar bar(Foo foo) { + return new Bar(); + } + + record ElicitationHandler() { + + @McpElicitation(clients = "server1") + public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + logger.info("MCP ELICITATION: {}", request); + ElicitInput elicitData = new ElicitInput(request.message()); + return StructuredElicitResult.builder().structuredContent(elicitData).build(); + } + + } + + public record ElicitInput(String message) { + } + + public static class Foo { + + } + + public static class Bar { + + } + +} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java index 987d29ec317..d9e872a60b2 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/capabilities/McpHandlerService.java @@ -19,9 +19,7 @@ import io.modelcontextprotocol.spec.McpSchema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpSampling; -import org.springaicommunity.mcp.context.StructuredElicitResult; import org.springframework.ai.chat.client.ChatClient; import org.springframework.stereotype.Service; @@ -51,14 +49,4 @@ public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequ .build(); } - @McpElicitation(clients = "server1") - public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { - logger.info("MCP ELICITATION: {}", request); - ElicitInput elicitData = new ElicitInput(request.message()); - return StructuredElicitResult.builder().structuredContent(elicitData).build(); - } - - public record ElicitInput(String message) { - } - } diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java index 97126531f50..c1d5c84017e 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/AbstractClientMcpHandlerRegistry.java @@ -35,8 +35,8 @@ import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpToolListChanged; +import org.springframework.aop.framework.autoproxy.AutoProxyUtils; import org.springframework.beans.BeansException; -import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.core.annotation.AnnotationUtils; @@ -70,8 +70,11 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) Map> elicitationClientToAnnotatedBeans = new HashMap<>(); Map> samplingClientToAnnotatedBeans = new HashMap<>(); for (var beanName : beanFactory.getBeanDefinitionNames()) { - var definition = beanFactory.getBeanDefinition(beanName); - var foundAnnotations = scan(getBeanClass(definition, beanFactory.getBeanClassLoader())); + if (!beanFactory.getBeanDefinition(beanName).isSingleton()) { + // Only process singleton beans, not scoped beans + continue; + } + var foundAnnotations = this.scan(AutoProxyUtils.determineTargetClass(beanFactory, beanName)); if (!foundAnnotations.isEmpty()) { this.allAnnotatedBeans.add(beanName); } @@ -118,23 +121,6 @@ else if (foundAnnotation instanceof McpElicitation elicitation) { .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().build())); } - private static Class getBeanClass(BeanDefinition definition, ClassLoader beanClassLoader) { - if (definition.getResolvableType().resolve() != null) { - return definition.getResolvableType().resolve(); - } - // @Component beans registered by component scanning do not have a resolvable type - // We try to resolve them using the beanClassName (which might be null) - if (beanClassLoader != null && definition.getBeanClassName() != null) { - try { - return Class.forName(definition.getBeanClassName(), false, beanClassLoader); - } - catch (ClassNotFoundException ignored) { - - } - } - return null; - } - protected List scan(Class beanClass) { List foundAnnotations = new ArrayList<>(); @@ -159,7 +145,7 @@ protected Map, Set> getBeansByAnnotationType for (var beanName : this.allAnnotatedBeans) { var bean = this.beanFactory.getBean(beanName); - var annotations = scan(bean.getClass()); + var annotations = this.scan(bean.getClass()); for (var annotation : annotations) { beansByAnnotation.computeIfAbsent(annotation.annotationType(), k -> new HashSet<>()).add(bean); } diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java index 768acce647f..d4c62c26979 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java @@ -186,7 +186,7 @@ public Mono handleResourceListChanged(String name, List(beansByAnnotation.get(McpSampling.class))); diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java index 1f2e0cfa559..623b2f9cb16 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java @@ -194,7 +194,7 @@ public void handleResourceListChanged(String name, List upda @Override public void afterSingletonsInstantiated() { - var beansByAnnotation = getBeansByAnnotationType(); + var beansByAnnotation = this.getBeansByAnnotationType(); var samplingSpecs = SyncMcpAnnotationProviders .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java index 64c9e5b8489..e6a93caa64e 100644 --- a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java @@ -33,6 +33,7 @@ import org.springaicommunity.mcp.annotation.McpToolListChanged; import reactor.core.publisher.Mono; +import org.springframework.aop.framework.autoproxy.AutoProxyUtils; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.DefaultListableBeanFactory; @@ -280,6 +281,20 @@ void supportsNonResolvableTypes() { assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); } + @Test + void supportsProxiedClass() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + var beanDefinition = BeanDefinitionBuilder.genericBeanDefinition(Object.class).getBeanDefinition(); + beanDefinition.setAttribute(AutoProxyUtils.ORIGINAL_TARGET_CLASS_ATTRIBUTE, + ClientMcpSyncHandlersRegistryTests.ClientCapabilitiesConfiguration.class); + beanFactory.registerBeanDefinition("myConfig", beanDefinition); + + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + @Test @Disabled void missingHandler() { diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java index e7e9527adfd..967249bc4fc 100644 --- a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java @@ -32,6 +32,7 @@ import org.springaicommunity.mcp.annotation.McpSampling; import org.springaicommunity.mcp.annotation.McpToolListChanged; +import org.springframework.aop.framework.autoproxy.AutoProxyUtils; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.DefaultListableBeanFactory; @@ -276,6 +277,20 @@ void supportsNonResolvableTypes() { assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); } + @Test + void supportsProxiedClass() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + var beanDefinition = BeanDefinitionBuilder.genericBeanDefinition(Object.class).getBeanDefinition(); + beanDefinition.setAttribute(AutoProxyUtils.ORIGINAL_TARGET_CLASS_ATTRIBUTE, + ClientCapabilitiesConfiguration.class); + beanFactory.registerBeanDefinition("myConfig", beanDefinition); + + registry.postProcessBeanFactory(beanFactory); + + assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); + } + @Test @Disabled void missingHandler() { From 95865c83f2f7e82cbc520483235f1e6ea986b6bd Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Thu, 6 Nov 2025 11:54:08 +0100 Subject: [PATCH 08/11] Add logging to MCP handlers registry Signed-off-by: Daniel Garnier-Moiroux --- .../ClientMcpAsyncHandlersRegistry.java | 18 +++++++++++ .../spring/ClientMcpSyncHandlersRegistry.java | 30 +++++++++++++++---- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java index d4c62c26979..bb63a5d6732 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java @@ -23,6 +23,8 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpLogging; import org.springaicommunity.mcp.annotation.McpProgress; @@ -63,6 +65,8 @@ public class ClientMcpAsyncHandlersRegistry extends AbstractClientMcpHandlerRegistry implements SmartInitializingSingleton { + private static final Logger logger = LoggerFactory.getLogger(ClientMcpAsyncHandlersRegistry.class); + private final Map>> samplingHandlers = new HashMap<>(); private final Map>> elicitationHandlers = new HashMap<>(); @@ -92,6 +96,7 @@ public McpSchema.ClientCapabilities getCapabilities(String clientName) { */ public Mono handleSampling(String name, McpSchema.CreateMessageRequest samplingRequest) { + logger.debug("Handling sampling request for client {}", name); var handler = this.samplingHandlers.get(name); if (handler != null) { return handler.apply(samplingRequest); @@ -106,6 +111,7 @@ public Mono handleSampling(String name, * @see McpElicitation */ public Mono handleElicitation(String name, McpSchema.ElicitRequest elicitationRequest) { + logger.debug("Handling elicitation request for client {}", name); var handler = this.elicitationHandlers.get(name); if (handler != null) { return handler.apply(elicitationRequest); @@ -120,6 +126,7 @@ public Mono handleElicitation(String name, McpSchema.Eli * @see McpLogging */ public Mono handleLogging(String name, McpSchema.LoggingMessageNotification loggingMessageNotification) { + logger.debug("Handling logging notification for client {}", name); var consumers = this.loggingHandlers.get(name); if (consumers == null) { // TODO handle @@ -134,6 +141,7 @@ public Mono handleLogging(String name, McpSchema.LoggingMessageNotificatio * @see McpProgress */ public Mono handleProgress(String name, McpSchema.ProgressNotification progressNotification) { + logger.debug("Handling progress notification for client {}", name); var consumers = this.progressHandlers.get(name); if (consumers == null) { // TODO handle @@ -148,6 +156,7 @@ public Mono handleProgress(String name, McpSchema.ProgressNotification pro * @see McpToolListChanged */ public Mono handleToolListChanged(String name, List updatedTools) { + logger.debug("Handling tool list changed notification for client {}", name); var consumers = this.toolListChangedHandlers.get(name); if (consumers == null) { // TODO handle @@ -162,6 +171,7 @@ public Mono handleToolListChanged(String name, List update * @see McpPromptListChanged */ public Mono handlePromptListChanged(String name, List updatedPrompts) { + logger.debug("Handling prompt list changed notification for client {}", name); var consumers = this.promptListChangedHandlers.get(name); if (consumers == null) { // TODO handle @@ -176,6 +186,7 @@ public Mono handlePromptListChanged(String name, List up * @see McpResourceListChanged */ public Mono handleResourceListChanged(String name, List updatedResources) { + logger.debug("Handling resource list changed notification for client {}", name); var consumers = this.resourceListChangedHandlers.get(name); if (consumers == null) { // TODO handle @@ -192,6 +203,7 @@ public void afterSingletonsInstantiated() { .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); for (var samplingSpec : samplingSpecs) { for (var client : samplingSpec.clients()) { + logger.debug("Registering sampling handler for {}", client); this.samplingHandlers.put(client, samplingSpec.samplingHandler()); } } @@ -200,6 +212,7 @@ public void afterSingletonsInstantiated() { .elicitationSpecifications(new ArrayList<>(beansByAnnotation.get(McpElicitation.class))); for (var elicitationSpec : elicitationSpecs) { for (var client : elicitationSpec.clients()) { + logger.debug("Registering elicitation handler for {}", client); this.elicitationHandlers.put(client, elicitationSpec.elicitationHandler()); } } @@ -208,6 +221,7 @@ public void afterSingletonsInstantiated() { .loggingSpecifications(new ArrayList<>(beansByAnnotation.get(McpLogging.class))); for (var loggingSpec : loggingSpecs) { for (var client : loggingSpec.clients()) { + logger.debug("Registering logging handler for {}", client); this.loggingHandlers.computeIfAbsent(client, k -> new ArrayList<>()).add(loggingSpec.loggingHandler()); } } @@ -216,6 +230,7 @@ public void afterSingletonsInstantiated() { .progressSpecifications(new ArrayList<>(beansByAnnotation.get(McpProgress.class))); for (var progressSpec : progressSpecs) { for (var client : progressSpec.clients()) { + logger.debug("Registering progress handler for {}", client); this.progressHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(progressSpec.progressHandler()); } @@ -225,6 +240,7 @@ public void afterSingletonsInstantiated() { .toolListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpToolListChanged.class))); for (var toolsListChangedSpec : toolsListChangedSpecs) { for (var client : toolsListChangedSpec.clients()) { + logger.debug("Registering tool list changed handler for {}", client); this.toolListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(toolsListChangedSpec.toolListChangeHandler()); } @@ -234,6 +250,7 @@ public void afterSingletonsInstantiated() { .promptListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpPromptListChanged.class))); for (var promptListChangedSpec : promptListChangedSpecs) { for (var client : promptListChangedSpec.clients()) { + logger.debug("Registering prompt list changed handler for {}", client); this.promptListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(promptListChangedSpec.promptListChangeHandler()); } @@ -243,6 +260,7 @@ public void afterSingletonsInstantiated() { .resourceListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpResourceListChanged.class))); for (var resourceListChangedSpec : resourceListChangedSpecs) { for (var client : resourceListChangedSpec.clients()) { + logger.debug("Registering resource list changed handler for {}", client); this.resourceListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(resourceListChangedSpec.resourceListChangeHandler()); } diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java index 623b2f9cb16..aec73a88a86 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java @@ -24,6 +24,8 @@ import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpLogging; import org.springaicommunity.mcp.annotation.McpProgress; @@ -62,6 +64,8 @@ public class ClientMcpSyncHandlersRegistry extends AbstractClientMcpHandlerRegistry implements SmartInitializingSingleton { + private static final Logger logger = LoggerFactory.getLogger(ClientMcpSyncHandlersRegistry.class); + private final Map> samplingHandlers = new HashMap<>(); private final Map> elicitationHandlers = new HashMap<>(); @@ -90,6 +94,8 @@ public McpSchema.ClientCapabilities getCapabilities(String clientName) { * @see McpSampling */ public McpSchema.CreateMessageResult handleSampling(String name, McpSchema.CreateMessageRequest samplingRequest) { + logger.debug("Handling sampling request for client {}", name); + var handler = this.samplingHandlers.get(name); if (handler != null) { return handler.apply(samplingRequest); @@ -104,6 +110,8 @@ public McpSchema.CreateMessageResult handleSampling(String name, McpSchema.Creat * @see McpElicitation */ public McpSchema.ElicitResult handleElicitation(String name, McpSchema.ElicitRequest elicitationRequest) { + logger.debug("Handling elicitation request for client {}", name); + var handler = this.elicitationHandlers.get(name); if (handler != null) { return handler.apply(elicitationRequest); @@ -118,9 +126,10 @@ public McpSchema.ElicitResult handleElicitation(String name, McpSchema.ElicitReq * @see McpLogging */ public void handleLogging(String name, McpSchema.LoggingMessageNotification loggingMessageNotification) { + logger.debug("Handling logging notification for client {}", name); + var consumers = this.loggingHandlers.get(name); if (consumers == null) { - // TODO handle return; } for (var consumer : consumers) { @@ -134,9 +143,10 @@ public void handleLogging(String name, McpSchema.LoggingMessageNotification logg * @see McpProgress */ public void handleProgress(String name, McpSchema.ProgressNotification progressNotification) { + logger.debug("Handling progress notification for client {}", name); + var consumers = this.progressHandlers.get(name); if (consumers == null) { - // TODO handle return; } for (var consumer : consumers) { @@ -150,9 +160,10 @@ public void handleProgress(String name, McpSchema.ProgressNotification progressN * @see McpToolListChanged */ public void handleToolListChanged(String name, List updatedTools) { + logger.debug("Handling tool list changed notification for client {}", name); + var consumers = this.toolListChangedHandlers.get(name); if (consumers == null) { - // TODO handle return; } for (var consumer : consumers) { @@ -166,9 +177,10 @@ public void handleToolListChanged(String name, List updatedTools * @see McpPromptListChanged */ public void handlePromptListChanged(String name, List updatedPrompts) { + logger.debug("Handling prompt list changed notification for client {}", name); + var consumers = this.promptListChangedHandlers.get(name); if (consumers == null) { - // TODO handle return; } for (var consumer : consumers) { @@ -182,9 +194,10 @@ public void handlePromptListChanged(String name, List updatedP * @see McpResourceListChanged */ public void handleResourceListChanged(String name, List updatedResources) { + logger.debug("Handling resource list changed notification for client {}", name); + var consumers = this.resourceListChangedHandlers.get(name); if (consumers == null) { - // TODO handle return; } for (var consumer : consumers) { @@ -200,6 +213,7 @@ public void afterSingletonsInstantiated() { .samplingSpecifications(new ArrayList<>(beansByAnnotation.get(McpSampling.class))); for (var samplingSpec : samplingSpecs) { for (var client : samplingSpec.clients()) { + logger.debug("Registering sampling handler for {}", client); this.samplingHandlers.put(client, samplingSpec.samplingHandler()); } } @@ -208,6 +222,7 @@ public void afterSingletonsInstantiated() { .elicitationSpecifications(new ArrayList<>(beansByAnnotation.get(McpElicitation.class))); for (var elicitationSpec : elicitationSpecs) { for (var client : elicitationSpec.clients()) { + logger.debug("Registering elicitation handler for {}", client); this.elicitationHandlers.put(client, elicitationSpec.elicitationHandler()); } } @@ -216,6 +231,7 @@ public void afterSingletonsInstantiated() { .loggingSpecifications(new ArrayList<>(beansByAnnotation.get(McpLogging.class))); for (var loggingSpec : loggingSpecs) { for (var client : loggingSpec.clients()) { + logger.debug("Registering logging handler for {}", client); this.loggingHandlers.computeIfAbsent(client, k -> new ArrayList<>()).add(loggingSpec.loggingHandler()); } } @@ -224,6 +240,7 @@ public void afterSingletonsInstantiated() { .progressSpecifications(new ArrayList<>(beansByAnnotation.get(McpProgress.class))); for (var progressSpec : progressSpecs) { for (var client : progressSpec.clients()) { + logger.debug("Registering progress handler for {}", client); this.progressHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(progressSpec.progressHandler()); } @@ -233,6 +250,7 @@ public void afterSingletonsInstantiated() { .toolListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpToolListChanged.class))); for (var toolsListChangedSpec : toolsListChangedSpecs) { for (var client : toolsListChangedSpec.clients()) { + logger.debug("Registering tool list changed handler for {}", client); this.toolListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(toolsListChangedSpec.toolListChangeHandler()); } @@ -242,6 +260,7 @@ public void afterSingletonsInstantiated() { .promptListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpPromptListChanged.class))); for (var promptListChangedSpec : promptListChangedSpecs) { for (var client : promptListChangedSpec.clients()) { + logger.debug("Registering prompt list changed handler for {}", client); this.promptListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(promptListChangedSpec.promptListChangeHandler()); } @@ -251,6 +270,7 @@ public void afterSingletonsInstantiated() { .resourceListChangedSpecifications(new ArrayList<>(beansByAnnotation.get(McpResourceListChanged.class))); for (var resourceListChangedSpec : resourceListChangedSpecs) { for (var client : resourceListChangedSpec.clients()) { + logger.debug("Registering resource list changed handler for {}", client); this.resourceListChangedHandlers.computeIfAbsent(client, k -> new ArrayList<>()) .add(resourceListChangedSpec.resourceListChangeHandler()); } From fa6a1f47b5aa60c0e5e3310b3b82f32f4657282c Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Thu, 6 Nov 2025 12:23:57 +0100 Subject: [PATCH 09/11] Throw MCP Error on missing sampling and elicitation handlers in a client Signed-off-by: Daniel Garnier-Moiroux --- .../ClientMcpAsyncHandlersRegistry.java | 14 ++--- .../spring/ClientMcpSyncHandlersRegistry.java | 9 +-- .../ClientMcpAsyncHandlersRegistryTests.java | 55 ++++++++++++++++--- .../ClientMcpSyncHandlersRegistryTests.java | 51 ++++++++++++++--- 4 files changed, 100 insertions(+), 29 deletions(-) diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java index bb63a5d6732..536f2eec7c1 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistry.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.function.Function; +import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -101,8 +102,8 @@ public Mono handleSampling(String name, if (handler != null) { return handler.apply(samplingRequest); } - // TODO: handle null - return Mono.empty(); + return Mono.error(new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Sampling not supported", Map.of("reason", "Client does not have sampling capability")))); } /** @@ -116,8 +117,8 @@ public Mono handleElicitation(String name, McpSchema.Eli if (handler != null) { return handler.apply(elicitationRequest); } - // TODO: handle null - return Mono.empty(); + return Mono.error(new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Elicitation not supported", Map.of("reason", "Client does not have elicitation capability")))); } /** @@ -129,7 +130,6 @@ public Mono handleLogging(String name, McpSchema.LoggingMessageNotificatio logger.debug("Handling logging notification for client {}", name); var consumers = this.loggingHandlers.get(name); if (consumers == null) { - // TODO handle return Mono.empty(); } return Flux.fromIterable(consumers).flatMap(c -> c.apply(loggingMessageNotification)).then(); @@ -144,7 +144,6 @@ public Mono handleProgress(String name, McpSchema.ProgressNotification pro logger.debug("Handling progress notification for client {}", name); var consumers = this.progressHandlers.get(name); if (consumers == null) { - // TODO handle return Mono.empty(); } return Flux.fromIterable(consumers).flatMap(c -> c.apply(progressNotification)).then(); @@ -159,7 +158,6 @@ public Mono handleToolListChanged(String name, List update logger.debug("Handling tool list changed notification for client {}", name); var consumers = this.toolListChangedHandlers.get(name); if (consumers == null) { - // TODO handle return Mono.empty(); } return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedTools)).then(); @@ -174,7 +172,6 @@ public Mono handlePromptListChanged(String name, List up logger.debug("Handling prompt list changed notification for client {}", name); var consumers = this.promptListChangedHandlers.get(name); if (consumers == null) { - // TODO handle return Mono.empty(); } return Flux.fromIterable(consumers).flatMap(c -> c.apply(updatedPrompts)).then(); @@ -189,7 +186,6 @@ public Mono handleResourceListChanged(String name, List c.apply(updatedResources)).then(); diff --git a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java index aec73a88a86..36a4d63fa14 100644 --- a/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java +++ b/mcp/mcp-annotations-spring/src/main/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistry.java @@ -23,6 +23,7 @@ import java.util.function.Consumer; import java.util.function.Function; +import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,8 +101,8 @@ public McpSchema.CreateMessageResult handleSampling(String name, McpSchema.Creat if (handler != null) { return handler.apply(samplingRequest); } - // TODO: handle null - return null; + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Sampling not supported", Map.of("reason", "Client does not have sampling capability"))); } /** @@ -116,8 +117,8 @@ public McpSchema.ElicitResult handleElicitation(String name, McpSchema.ElicitReq if (handler != null) { return handler.apply(elicitationRequest); } - // TODO: handle null - return null; + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Elicitation not supported", Map.of("reason", "Client does not have elicitation capability"))); } /** diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java index e6a93caa64e..6e7300bdd6f 100644 --- a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpAsyncHandlersRegistryTests.java @@ -21,8 +21,8 @@ import java.util.List; import java.util.Map; +import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpLogging; @@ -39,7 +39,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; +import static org.assertj.core.api.InstanceOfAssertFactories.type; class ClientMcpAsyncHandlersRegistryTests { @@ -147,6 +147,27 @@ void elicitation() { assertThat(response.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); } + @Test + void missingElicitationHandler() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder + .genericBeanDefinition(ClientMcpAsyncHandlersRegistryTests.HandlersConfiguration.class) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + assertThatThrownBy(() -> registry.handleElicitation("client-unknown", request).block()) + .hasMessage("Elicitation not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have elicitation capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + @Test void sampling() { var registry = new ClientMcpAsyncHandlersRegistry(); @@ -168,6 +189,30 @@ void sampling() { assertThat(content.text()).isEqualTo("Tell a joke"); } + @Test + void missingSamplingHandler() { + var registry = new ClientMcpAsyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder + .genericBeanDefinition(ClientMcpAsyncHandlersRegistryTests.HandlersConfiguration.class) + .getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + assertThatThrownBy(() -> registry.handleSampling("client-unknown", request).block()) + .hasMessage("Sampling not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have sampling capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + @Test void logging() { var registry = new ClientMcpAsyncHandlersRegistry(); @@ -295,12 +340,6 @@ void supportsProxiedClass() { assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); } - @Test - @Disabled - void missingHandler() { - fail("TODO"); - } - static class ClientCapabilitiesConfiguration { @McpElicitation(clients = { "client-1", "client-2" }) diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java index 967249bc4fc..796eed78fba 100644 --- a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java @@ -21,8 +21,8 @@ import java.util.List; import java.util.Map; +import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springaicommunity.mcp.annotation.McpElicitation; import org.springaicommunity.mcp.annotation.McpLogging; @@ -38,7 +38,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; +import static org.assertj.core.api.InstanceOfAssertFactories.type; class ClientMcpSyncHandlersRegistryTests { @@ -145,6 +145,25 @@ void elicitation() { assertThat(response.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); } + @Test + void missingElicitationHandler() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.ElicitRequest.builder().message("Elicit request").progressToken("token-12345").build(); + assertThatThrownBy(() -> registry.handleElicitation("client-unknown", request)) + .hasMessage("Elicitation not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have elicitation capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + @Test void sampling() { var registry = new ClientMcpSyncHandlersRegistry(); @@ -166,6 +185,28 @@ void sampling() { assertThat(content.text()).isEqualTo("Tell a joke"); } + @Test + void missingSamplingHandler() { + var registry = new ClientMcpSyncHandlersRegistry(); + var beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition("myConfig", + BeanDefinitionBuilder.genericBeanDefinition(HandlersConfiguration.class).getBeanDefinition()); + registry.postProcessBeanFactory(beanFactory); + registry.afterSingletonsInstantiated(); + + var request = McpSchema.CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(McpSchema.Role.USER, new McpSchema.TextContent("Tell a joke")))) + .build(); + assertThatThrownBy(() -> registry.handleSampling("client-unknown", request)) + .hasMessage("Sampling not supported") + .asInstanceOf(type(McpError.class)) + .extracting(McpError::getJsonRpcError) + .satisfies(error -> assertThat(error.data()) + .isEqualTo(Map.of("reason", "Client does not have sampling capability"))) + .satisfies(error -> assertThat(error.code()).isEqualTo(McpSchema.ErrorCodes.METHOD_NOT_FOUND)); + } + @Test void logging() { var registry = new ClientMcpSyncHandlersRegistry(); @@ -291,12 +332,6 @@ void supportsProxiedClass() { assertThat(registry.getCapabilities("client-1").elicitation()).isNotNull(); } - @Test - @Disabled - void missingHandler() { - fail("TODO"); - } - static class ClientCapabilitiesConfiguration { @McpElicitation(clients = { "client-1", "client-2" }) From f4eb6fe4ba7a1c3f5161fc378e1ae7f0e45a9dbf Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Fri, 7 Nov 2025 10:20:27 +0100 Subject: [PATCH 10/11] Fix missing auto-configurations McpClientAutoConfigurationIT Signed-off-by: Daniel Garnier-Moiroux --- .../common/autoconfigure/McpClientAutoConfigurationIT.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java index 1d1fbb92ae4..b4a72354db7 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpClientAutoConfigurationIT.java @@ -29,6 +29,7 @@ import org.mockito.Mockito; import reactor.core.publisher.Mono; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.configurer.McpSyncClientConfigurer; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; @@ -84,8 +85,9 @@ */ public class McpClientAutoConfigurationIT { - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withConfiguration( - AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class)); + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, + McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class)); /** * Tests the default MCP client auto-configuration. From 3929649d195e002a78c19bc369ac5cd9ed92c7e6 Mon Sep 17 00:00:00 2001 From: Daniel Garnier-Moiroux Date: Fri, 7 Nov 2025 13:45:06 +0100 Subject: [PATCH 11/11] Fix integration tests Signed-off-by: Daniel Garnier-Moiroux --- ...lientListChangedAnnotationsScanningIT.java | 116 ++++++++++++++---- ...ttpClientTransportAutoConfigurationIT.java | 5 +- ...ttpClientTransportAutoConfigurationIT.java | 2 + ...seWebFluxTransportAutoConfigurationIT.java | 5 +- ...ttpClientTransportAutoConfigurationIT.java | 2 + .../SseWebClientWebFluxServerIT.java | 9 +- .../StatelessWebClientWebFluxServerIT.java | 4 +- .../StreamableWebClientWebFluxServerIT.java | 5 +- .../ClientMcpSyncHandlersRegistryTests.java | 8 +- 9 files changed, 121 insertions(+), 35 deletions(-) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java index 04bd0cf7e73..d406da99dd3 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/annotations/McpClientListChangedAnnotationsScanningIT.java @@ -16,15 +16,21 @@ package org.springframework.ai.mcp.client.common.autoconfigure.annotations; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import io.modelcontextprotocol.spec.McpSchema; +import org.junit.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.springaicommunity.mcp.annotation.McpPromptListChanged; import org.springaicommunity.mcp.annotation.McpResourceListChanged; import org.springaicommunity.mcp.annotation.McpToolListChanged; +import reactor.core.publisher.Mono; +import org.springframework.ai.mcp.annotation.spring.ClientMcpAsyncHandlersRegistry; +import org.springframework.ai.mcp.annotation.spring.ClientMcpSyncHandlersRegistry; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -47,25 +53,65 @@ public class McpClientListChangedAnnotationsScanningIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpClientAnnotationScannerAutoConfiguration.class)); - @ParameterizedTest - @ValueSource(strings = { "SYNC", "ASYNC" }) - void shouldScanAllThreeListChangedAnnotations(String clientType) { - String prefix = clientType.toLowerCase(); + @Test + public void shouldScanAllThreeListChangedAnnotationsSync() { + this.contextRunner.withUserConfiguration(AllListChangedConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.type=SYNC") + .run(context -> { + // Verify all three annotations were scanned + var registry = context.getBean(ClientMcpSyncHandlersRegistry.class); + var handlers = context.getBean(TestListChangedHandlers.class); + assertThat(registry).isNotNull(); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + List updatedPrompts = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleToolListChanged("test-client", updatedTools); + registry.handleResourceListChanged("test-client", updatedResources); + registry.handlePromptListChanged("test-client", updatedPrompts); + + assertThat(handlers.getCalls()).hasSize(3) + .containsExactlyInAnyOrder( + new TestListChangedHandlers.Call("resource-list-changed", updatedResources), + new TestListChangedHandlers.Call("prompt-list-changed", updatedPrompts), + new TestListChangedHandlers.Call("tool-list-changed", updatedTools)); + }); + } + @Test + public void shouldScanAllThreeListChangedAnnotationsAsync() { this.contextRunner.withUserConfiguration(AllListChangedConfiguration.class) - .withPropertyValues("spring.ai.mcp.client.type=" + clientType) + .withPropertyValues("spring.ai.mcp.client.type=ASYNC") .run(context -> { // Verify all three annotations were scanned - McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans annotatedBeans = context - .getBean(McpClientAnnotationScannerAutoConfiguration.ClientMcpAnnotatedBeans.class); - assertThat(annotatedBeans.getBeansByAnnotation(McpToolListChanged.class)).hasSize(1); - assertThat(annotatedBeans.getBeansByAnnotation(McpResourceListChanged.class)).hasSize(1); - assertThat(annotatedBeans.getBeansByAnnotation(McpPromptListChanged.class)).hasSize(1); - - // Verify all three specification beans were created - assertThat(context).hasBean(prefix + "ToolListChangedSpecs"); - assertThat(context).hasBean(prefix + "ResourceListChangedSpecs"); - assertThat(context).hasBean(prefix + "PromptListChangedSpecs"); + var registry = context.getBean(ClientMcpAsyncHandlersRegistry.class); + var handlers = context.getBean(TestListChangedHandlers.class); + assertThat(registry).isNotNull(); + + List updatedTools = List.of(McpSchema.Tool.builder().name("tool-1").build(), + McpSchema.Tool.builder().name("tool-2").build()); + List updatedPrompts = List.of( + new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), + new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); + List updatedResources = List.of( + McpSchema.Resource.builder().name("resource-1").uri("file:///resource/1").build(), + McpSchema.Resource.builder().name("resource-2").uri("file:///resource/2").build()); + + registry.handleToolListChanged("test-client", updatedTools).block(); + registry.handleResourceListChanged("test-client", updatedResources).block(); + registry.handlePromptListChanged("test-client", updatedPrompts).block(); + + assertThat(handlers.getCalls()).hasSize(3) + .containsExactlyInAnyOrder( + new TestListChangedHandlers.Call("resource-list-changed", updatedResources), + new TestListChangedHandlers.Call("prompt-list-changed", updatedPrompts), + new TestListChangedHandlers.Call("tool-list-changed", updatedTools)); }); } @@ -79,10 +125,8 @@ void shouldNotScanAnnotationsWhenScannerDisabled(String clientType) { "spring.ai.mcp.client.annotation-scanner.enabled=false") .run(context -> { // Verify scanner beans were not created - assertThat(context).doesNotHaveBean(McpClientAnnotationScannerAutoConfiguration.class); - assertThat(context).doesNotHaveBean(prefix + "ToolListChangedSpecs"); - assertThat(context).doesNotHaveBean(prefix + "ResourceListChangedSpecs"); - assertThat(context).doesNotHaveBean(prefix + "PromptListChangedSpecs"); + assertThat(context).doesNotHaveBean(ClientMcpSyncHandlersRegistry.class); + assertThat(context).doesNotHaveBean(ClientMcpAsyncHandlersRegistry.class); }); } @@ -98,19 +142,47 @@ TestListChangedHandlers testHandlers() { static class TestListChangedHandlers { + private final List calls = new ArrayList<>(); + + public List getCalls() { + return this.calls; + } + @McpToolListChanged(clients = "test-client") public void onToolListChanged(List updatedTools) { - // Test handler for tool list changes + this.calls.add(new Call("tool-list-changed", updatedTools)); } @McpResourceListChanged(clients = "test-client") public void onResourceListChanged(List updatedResources) { - // Test handler for resource list changes + this.calls.add(new Call("resource-list-changed", updatedResources)); } @McpPromptListChanged(clients = "test-client") public void onPromptListChanged(List updatedPrompts) { - // Test handler for prompt list changes + this.calls.add(new Call("prompt-list-changed", updatedPrompts)); + } + + @McpToolListChanged(clients = "test-client") + public Mono onToolListChangedReactive(List updatedTools) { + this.calls.add(new Call("tool-list-changed", updatedTools)); + return Mono.empty(); + } + + @McpResourceListChanged(clients = "test-client") + public Mono onResourceListChangedReactive(List updatedResources) { + this.calls.add(new Call("resource-list-changed", updatedResources)); + return Mono.empty(); + } + + @McpPromptListChanged(clients = "test-client") + public Mono onPromptListChangedReactive(List updatedPrompts) { + this.calls.add(new Call("prompt-list-changed", updatedPrompts)); + return Mono.empty(); + } + + // Record calls made to this object + record Call(String name, Object callRequest) { } } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java index 7dae305197e..8d3fbf94e5e 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/SseHttpClientTransportAutoConfigurationIT.java @@ -33,6 +33,7 @@ import reactor.core.publisher.Mono; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.httpclient.autoconfigure.SseHttpClientTransportAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.context.annotation.UserConfigurations; @@ -56,8 +57,8 @@ public class SseHttpClientTransportAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.sse.connections.server1.url=" + host) - .withConfiguration( - AutoConfigurations.of(McpClientAutoConfiguration.class, SseHttpClientTransportAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, SseHttpClientTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java index 230605fdc46..0b52cc49ecb 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient/src/test/java/org/springframework/ai/mcp/client/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java @@ -33,6 +33,7 @@ import reactor.core.publisher.Mono; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.httpclient.autoconfigure.StreamableHttpHttpClientTransportAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.context.annotation.UserConfigurations; @@ -58,6 +59,7 @@ public class StreamableHttpHttpClientTransportAutoConfigurationIT { .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.streamable-http.connections.server1.url=" + host) .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, StreamableHttpHttpClientTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java index 744b494cf4d..15fa7c3cb33 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java @@ -30,6 +30,7 @@ import org.testcontainers.containers.wait.strategy.Wait; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -43,8 +44,8 @@ public class SseWebFluxTransportAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.sse.connections.server1.url=" + host) - .withConfiguration( - AutoConfigurations.of(McpClientAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java index 257df12a97b..83da4876bd1 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java @@ -30,6 +30,7 @@ import org.testcontainers.containers.wait.strategy.Wait; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -45,6 +46,7 @@ public class StreamableHttpHttpClientTransportAutoConfigurationIT { .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.streamable-http.connections.server1.url=" + host) .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java index a833402c1d1..57961417cf8 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/SseWebClientWebFluxServerIT.java @@ -62,6 +62,7 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.SseWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; @@ -93,9 +94,9 @@ public class SseWebClientWebFluxServerIT { AutoConfigurations.of(McpServerAutoConfiguration.class, McpServerObjectMapperAutoConfiguration.class, ToolCallbackConverterAutoConfiguration.class, McpServerSseWebFluxAutoConfiguration.class)); - private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, - McpClientAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); + private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner().withConfiguration( + AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); @Test void clientServerCapabilities() { @@ -518,6 +519,8 @@ McpSyncClientCustomizer clientCustomizer(TestContext testContext) { assertThat(progressNotification.total()).isEqualTo(1.0); // assertThat(progressNotification.message()).isEqualTo("processing"); }); + + mcpClientSpec.capabilities(McpSchema.ClientCapabilities.builder().elicitation().sampling().build()); }; } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java index a4eb89181cd..f584a3ad91b 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StatelessWebClientWebFluxServerIT.java @@ -49,6 +49,7 @@ import org.springframework.ai.mcp.McpToolUtils; import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; @@ -84,7 +85,8 @@ public class StatelessWebClientWebFluxServerIT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, - McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class)); + McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, + StreamableHttpWebFluxTransportAutoConfiguration.class)); @Test void clientServerCapabilities() { diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java index 8ee43cf8d07..96108790929 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableWebClientWebFluxServerIT.java @@ -63,6 +63,7 @@ import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; import org.springframework.ai.mcp.customizer.McpSyncClientCustomizer; import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; @@ -99,7 +100,8 @@ public class StreamableWebClientWebFluxServerIT { private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(McpToolCallbackAutoConfiguration.class, - McpClientAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class)); + McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, + StreamableHttpWebFluxTransportAutoConfiguration.class)); @Test void clientServerCapabilities() { @@ -521,6 +523,7 @@ McpSyncClientCustomizer clientCustomizer(TestContext testContext) { testContext.progressNotifications.add(progressNotification); testContext.progressLatch.countDown(); }); + mcpClientSpec.capabilities(McpSchema.ClientCapabilities.builder().sampling().elicitation().build()); }; } diff --git a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java index 796eed78fba..9b75acf8aa6 100644 --- a/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java +++ b/mcp/mcp-annotations-spring/src/test/java/org/springframework/ai/mcp/annotation/spring/ClientMcpSyncHandlersRegistryTests.java @@ -276,14 +276,14 @@ void promptListChanged() { registry.afterSingletonsInstantiated(); var handlers = beanFactory.getBean(HandlersConfiguration.class); - List updatedTools = List.of( + List updatedPrompts = List.of( new McpSchema.Prompt("prompt-1", "a test prompt", Collections.emptyList()), new McpSchema.Prompt("prompt-2", "another test prompt", Collections.emptyList())); - registry.handlePromptListChanged("client-1", updatedTools); + registry.handlePromptListChanged("client-1", updatedPrompts); assertThat(handlers.getCalls()).hasSize(2) - .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handlePromptListChanged", updatedTools), - new HandlersConfiguration.Call("handlePromptListChangedAgain", updatedTools)); + .containsExactlyInAnyOrder(new HandlersConfiguration.Call("handlePromptListChanged", updatedPrompts), + new HandlersConfiguration.Call("handlePromptListChangedAgain", updatedPrompts)); } @Test