diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/pom.xml b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/pom.xml index c22051c49ed..04af6dd307e 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/pom.xml +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/pom.xml @@ -42,6 +42,12 @@ true + + org.springframework + spring-webflux + true + + org.springframework.boot spring-boot-configuration-processor @@ -73,6 +79,10 @@ mockito-core test - + + org.springframework + spring-webflux + + diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/WebClientFactory.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/WebClientFactory.java new file mode 100644 index 00000000000..3efcc01f0e5 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/WebClientFactory.java @@ -0,0 +1,52 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.client.common.autoconfigure; + +import org.springframework.web.reactive.function.client.WebClient; + +/** + * Factory interface for creating {@link WebClient.Builder} instances per connection name. + * + *

+ * This factory allows customization of WebClient configuration on a per-connection basis, + * enabling fine-grained control over HTTP client settings such as timeouts, SSL + * configurations, and base URLs for each MCP server connection. + * + *

+ * The default implementation returns a standard {@link WebClient.Builder}. Custom + * implementations can provide connection-specific configurations based on the connection + * name. + * + * @author limch02 + * @since 1.0.0 + */ +public interface WebClientFactory { + + /** + * Creates a {@link WebClient.Builder} for the given connection name. + *

+ * The default implementation returns a standard {@link WebClient.Builder}. Custom + * implementations can override this method to provide connection-specific + * configurations. + * @param connectionName the name of the MCP server connection + * @return a WebClient.Builder instance configured for the connection + */ + default WebClient.Builder create(String connectionName) { + return WebClient.builder(); + } + +} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/DefaultWebClientFactory.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/DefaultWebClientFactory.java new file mode 100644 index 00000000000..13453c1c117 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/DefaultWebClientFactory.java @@ -0,0 +1,63 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.client.webflux.autoconfigure; + +import org.springframework.ai.mcp.client.common.autoconfigure.WebClientFactory; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.context.annotation.Bean; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * Default configuration for {@link WebClientFactory}. + * + *

+ * Provides a default implementation of {@link WebClientFactory} that returns a standard + * {@link WebClient.Builder} for all connections. This bean is only created if no custom + * {@link WebClientFactory} bean is provided. + * + * @author limch02 + * @since 1.0.0 + */ +@AutoConfiguration +@ConditionalOnClass(WebClient.class) +public class DefaultWebClientFactory { + + /** + * Creates a default {@link WebClientFactory} implementation. + *

+ * This factory returns a standard {@link WebClient.Builder} for all connection names. + * Custom implementations can be provided by defining a {@link WebClientFactory} bean. + * @return the default WebClientFactory instance + */ + @Bean + @ConditionalOnMissingBean(WebClientFactory.class) + public WebClientFactory webClientFactory() { + return new DefaultWebClientFactoryImpl(); + } + + private static class DefaultWebClientFactoryImpl implements WebClientFactory { + + @Override + public WebClient.Builder create(String connectionName) { + return WebClient.builder(); + } + + } + +} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java index e3631cba982..2c3e0b1c265 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfiguration.java @@ -25,6 +25,7 @@ import io.modelcontextprotocol.json.jackson.JacksonMcpJsonMapper; import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport; +import org.springframework.ai.mcp.client.common.autoconfigure.WebClientFactory; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties; import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpStreamableHttpClientProperties.ConnectionParameters; @@ -58,7 +59,7 @@ * @see WebClientStreamableHttpTransport * @see McpStreamableHttpClientProperties */ -@AutoConfiguration +@AutoConfiguration(after = DefaultWebClientFactory.class) @ConditionalOnClass({ WebClientStreamableHttpTransport.class, WebClient.class }) @EnableConfigurationProperties({ McpStreamableHttpClientProperties.class, McpClientCommonProperties.class }) @ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", @@ -71,31 +72,31 @@ public class StreamableHttpWebFluxTransportAutoConfiguration { *

* Each transport is configured with: *

* @param streamableProperties the Streamable HTTP client properties containing server * configurations - * @param webClientBuilderProvider the provider for WebClient.Builder + * @param webClientFactory the factory for creating WebClient.Builder instances per + * connection name * @param objectMapperProvider the provider for ObjectMapper or a new instance if not * available * @return list of named MCP transports */ @Bean public List streamableHttpWebFluxClientTransports( - McpStreamableHttpClientProperties streamableProperties, - ObjectProvider webClientBuilderProvider, + McpStreamableHttpClientProperties streamableProperties, WebClientFactory webClientFactory, ObjectProvider objectMapperProvider) { List streamableHttpTransports = new ArrayList<>(); - var webClientBuilderTemplate = webClientBuilderProvider.getIfAvailable(WebClient::builder); var objectMapper = objectMapperProvider.getIfAvailable(ObjectMapper::new); for (Map.Entry serverParameters : streamableProperties.getConnections() .entrySet()) { - var webClientBuilder = webClientBuilderTemplate.clone().baseUrl(serverParameters.getValue().url()); + String connectionName = serverParameters.getKey(); + var webClientBuilder = webClientFactory.create(connectionName).baseUrl(serverParameters.getValue().url()); String streamableHttpEndpoint = serverParameters.getValue().endpoint() != null ? serverParameters.getValue().endpoint() : "/mcp"; @@ -104,7 +105,7 @@ public List streamableHttpWebFluxClientTransports( .jsonMapper(new JacksonMcpJsonMapper(objectMapper)) .build(); - streamableHttpTransports.add(new NamedClientMcpTransport(serverParameters.getKey(), transport)); + streamableHttpTransports.add(new NamedClientMcpTransport(connectionName, transport)); } return streamableHttpTransports; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index cd975659070..02742d11d82 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +org.springframework.ai.mcp.client.webflux.autoconfigure.DefaultWebClientFactory org.springframework.ai.mcp.client.webflux.autoconfigure.SseWebFluxTransportAutoConfiguration org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java index b9db603a0bc..5af7219b755 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java @@ -41,6 +41,7 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Lazy; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; @@ -67,6 +68,8 @@ void mcpClientSupportsSampling() { .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:0", "spring.ai.mcp.client.initialized=false") .withConfiguration(AutoConfigurations.of( + // WebClientFactory + DefaultWebClientFactory.class, // Transport StreamableHttpWebFluxTransportAutoConfiguration.class, // MCP clients @@ -180,16 +183,16 @@ CustomToolCallbackProvider customToolCallbackProvider() { // Ignored by the resolver @Bean - SyncMcpToolCallbackProvider mcpToolCallbackProvider() { + SyncMcpToolCallbackProvider mcpToolCallbackProvider(@Lazy ToolCallbackResolver resolver) { var tcp = mock(SyncMcpToolCallbackProvider.class); when(tcp.getToolCallbacks()) .thenThrow(new RuntimeException("mcpToolCallbackProvider#getToolCallbacks should not be called")); return tcp; } - // Ignored by the resolver + // This bean depends on the resolver, to ensure there are no cyclic dependencies @Bean - CustomMcpToolCallbackProvider customMcpToolCallbackProvider() { + CustomMcpToolCallbackProvider customMcpToolCallbackProvider(@Lazy ToolCallbackResolver resolver) { return new CustomMcpToolCallbackProvider(); } diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java index 15fa7c3cb33..437052d1e0b 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/SseWebFluxTransportAutoConfigurationIT.java @@ -44,7 +44,7 @@ public class SseWebFluxTransportAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.sse.connections.server1.url=" + host) - .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + .withConfiguration(AutoConfigurations.of(DefaultWebClientFactory.class, McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, SseWebFluxTransportAutoConfiguration.class)); static String host = "http://localhost:3001"; diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java index 83da4876bd1..cbd13a70b23 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpHttpClientTransportAutoConfigurationIT.java @@ -45,7 +45,7 @@ public class StreamableHttpHttpClientTransportAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mcp.client.initialized=false", "spring.ai.mcp.client.streamable-http.connections.server1.url=" + host) - .withConfiguration(AutoConfigurations.of(McpClientAutoConfiguration.class, + .withConfiguration(AutoConfigurations.of(DefaultWebClientFactory.class, McpClientAutoConfiguration.class, McpClientAnnotationScannerAutoConfiguration.class, StreamableHttpWebFluxTransportAutoConfiguration.class)); diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java index 9551b41b874..ac3421dea88 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/StreamableHttpWebFluxTransportAutoConfigurationTests.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.mcp.client.common.autoconfigure.NamedClientMcpTransport; +import org.springframework.ai.mcp.client.common.autoconfigure.WebClientFactory; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.FilteredClassLoader; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -42,7 +43,8 @@ public class StreamableHttpWebFluxTransportAutoConfigurationTests { private final ApplicationContextRunner applicationContext = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(StreamableHttpWebFluxTransportAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(DefaultWebClientFactory.class, + StreamableHttpWebFluxTransportAutoConfiguration.class)); @Test void webFluxClientTransportsPresentIfWebClientStreamableHttpTransportPresent() { @@ -65,6 +67,7 @@ void webFluxClientTransportsNotPresentIfMcpClientDisabled() { } @Test + @SuppressWarnings("unchecked") void noTransportsCreatedWithEmptyConnections() { this.applicationContext.run(context -> { List transports = context.getBean("streamableHttpWebFluxClientTransports", @@ -74,6 +77,7 @@ void noTransportsCreatedWithEmptyConnections() { } @Test + @SuppressWarnings("unchecked") void singleConnectionCreatesOneTransport() { this.applicationContext .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080") @@ -87,6 +91,7 @@ void singleConnectionCreatesOneTransport() { } @Test + @SuppressWarnings("unchecked") void multipleConnectionsCreateMultipleTransports() { this.applicationContext .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", @@ -107,6 +112,7 @@ void multipleConnectionsCreateMultipleTransports() { } @Test + @SuppressWarnings("unchecked") void customStreamableHttpEndpointIsRespected() { this.applicationContext .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", @@ -124,11 +130,12 @@ void customStreamableHttpEndpointIsRespected() { } @Test - void customWebClientBuilderIsUsed() { - this.applicationContext.withUserConfiguration(CustomWebClientConfiguration.class) + @SuppressWarnings("unchecked") + void customWebClientFactoryIsUsed() { + this.applicationContext.withUserConfiguration(CustomWebClientFactoryConfiguration.class) .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080") .run(context -> { - assertThat(context.getBean(WebClient.Builder.class)).isNotNull(); + assertThat(context.getBean(WebClientFactory.class)).isNotNull(); List transports = context.getBean("streamableHttpWebFluxClientTransports", List.class); assertThat(transports).hasSize(1); @@ -136,6 +143,77 @@ void customWebClientBuilderIsUsed() { } @Test + @SuppressWarnings("unchecked") + void customWebClientFactoryPerConnectionCustomization() { + this.applicationContext.withUserConfiguration(PerConnectionWebClientFactoryConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", + "spring.ai.mcp.client.streamable-http.connections.server2.url=http://otherserver:8081") + .run(context -> { + List transports = context.getBean("streamableHttpWebFluxClientTransports", + List.class); + assertThat(transports).hasSize(2); + // Verify that custom factory was called for each connection + PerConnectionWebClientFactory factory = context.getBean(PerConnectionWebClientFactory.class); + assertThat(factory.getCreatedConnections()).containsExactlyInAnyOrder("server1", "server2"); + }); + } + + @Test + @SuppressWarnings("unchecked") + void defaultWebClientFactoryReturnsBuilder() { + this.applicationContext + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080") + .run(context -> { + // Verify default factory is created + assertThat(context.getBean(WebClientFactory.class)).isNotNull(); + List transports = context.getBean("streamableHttpWebFluxClientTransports", + List.class); + assertThat(transports).hasSize(1); + assertThat(transports.get(0).transport()).isInstanceOf(WebClientStreamableHttpTransport.class); + }); + } + + @Test + @SuppressWarnings("unchecked") + void fallbackToDefaultFactoryWhenNoCustomFactoryProvided() { + this.applicationContext + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", + "spring.ai.mcp.client.streamable-http.connections.server2.url=http://otherserver:8081") + .run(context -> { + // Verify default factory is used when no custom factory is provided + WebClientFactory factory = context.getBean(WebClientFactory.class); + assertThat(factory).isNotNull(); + // Default factory should return a builder for any connection name + WebClient.Builder builder1 = factory.create("server1"); + WebClient.Builder builder2 = factory.create("server2"); + assertThat(builder1).isNotNull(); + assertThat(builder2).isNotNull(); + List transports = context.getBean("streamableHttpWebFluxClientTransports", + List.class); + assertThat(transports).hasSize(2); + }); + } + + @Test + @SuppressWarnings("unchecked") + void customWebClientFactoryTakesPrecedenceOverDefault() { + this.applicationContext.withUserConfiguration(CustomWebClientFactoryConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080") + .run(context -> { + // Verify custom factory is used, not the default + WebClientFactory factory = context.getBean(WebClientFactory.class); + assertThat(factory).isNotNull(); + assertThat(factory).isInstanceOf(CustomWebClientFactory.class); + // Verify only one factory bean exists (the custom one) + assertThat(context.getBeansOfType(WebClientFactory.class)).hasSize(1); + List transports = context.getBean("streamableHttpWebFluxClientTransports", + List.class); + assertThat(transports).hasSize(1); + }); + } + + @Test + @SuppressWarnings("unchecked") void customObjectMapperIsUsed() { this.applicationContext.withUserConfiguration(CustomObjectMapperConfiguration.class) .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080") @@ -148,6 +226,7 @@ void customObjectMapperIsUsed() { } @Test + @SuppressWarnings("unchecked") void defaultStreamableHttpEndpointIsUsedWhenNotSpecified() { this.applicationContext .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080") @@ -163,6 +242,7 @@ void defaultStreamableHttpEndpointIsUsedWhenNotSpecified() { } @Test + @SuppressWarnings("unchecked") void mixedConnectionsWithAndWithoutCustomStreamableHttpEndpoint() { this.applicationContext .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", @@ -189,18 +269,131 @@ void mixedConnectionsWithAndWithoutCustomStreamableHttpEndpoint() { }); } + @Test + @SuppressWarnings("unchecked") + void eachConnectionGetsSeparateWebClientInstance() { + this.applicationContext + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", + "spring.ai.mcp.client.streamable-http.connections.server2.url=http://otherserver:8081", + "spring.ai.mcp.client.streamable-http.connections.server3.url=http://thirdserver:8082") + .run(context -> { + List transports = context.getBean("streamableHttpWebFluxClientTransports", + List.class); + assertThat(transports).hasSize(3); + + // Extract WebClient instances from each transport + WebClient webClient1 = getWebClient((WebClientStreamableHttpTransport) transports.get(0).transport()); + WebClient webClient2 = getWebClient((WebClientStreamableHttpTransport) transports.get(1).transport()); + WebClient webClient3 = getWebClient((WebClientStreamableHttpTransport) transports.get(2).transport()); + + // Verify that each connection has a separate WebClient instance + // They should not be the same object reference + assertThat(webClient1).isNotNull(); + assertThat(webClient2).isNotNull(); + assertThat(webClient3).isNotNull(); + assertThat(webClient1).isNotSameAs(webClient2); + assertThat(webClient1).isNotSameAs(webClient3); + assertThat(webClient2).isNotSameAs(webClient3); + + // Verify that WebClientFactory.create() was called for each connection + WebClientFactory factory = context.getBean(WebClientFactory.class); + assertThat(factory).isNotNull(); + }); + } + + /** + * Verifies that WebClientFactory.create() is called separately for each connection, + * ensuring that each connection gets its own WebClient.Builder instance. This is the + * core functionality of the WebClientFactory pattern. + */ + @Test + @SuppressWarnings("unchecked") + void webClientFactoryIsCalledPerConnection() { + this.applicationContext.withUserConfiguration(PerConnectionWebClientFactoryConfiguration.class) + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:8080", + "spring.ai.mcp.client.streamable-http.connections.server2.url=http://otherserver:8081", + "spring.ai.mcp.client.streamable-http.connections.server3.url=http://thirdserver:8082") + .run(context -> { + List transports = context.getBean("streamableHttpWebFluxClientTransports", + List.class); + assertThat(transports).hasSize(3); + + // Verify that WebClientFactory.create() was called for each connection + // name + PerConnectionWebClientFactory factory = context.getBean(PerConnectionWebClientFactory.class); + assertThat(factory.getCreatedConnections()).hasSize(3); + assertThat(factory.getCreatedConnections()).containsExactlyInAnyOrder("server1", "server2", "server3"); + + // Verify that each connection has a separate WebClient instance + WebClient webClient1 = getWebClient((WebClientStreamableHttpTransport) transports.get(0).transport()); + WebClient webClient2 = getWebClient((WebClientStreamableHttpTransport) transports.get(1).transport()); + WebClient webClient3 = getWebClient((WebClientStreamableHttpTransport) transports.get(2).transport()); + + // Each WebClient should be a different instance + assertThat(webClient1).isNotSameAs(webClient2); + assertThat(webClient1).isNotSameAs(webClient3); + assertThat(webClient2).isNotSameAs(webClient3); + + // Verify that factory.create() was called exactly 3 times (once per + // connection) + assertThat(factory.getCreatedConnections()).hasSize(3); + }); + } + private String getStreamableHttpEndpoint(WebClientStreamableHttpTransport transport) { Field privateField = ReflectionUtils.findField(WebClientStreamableHttpTransport.class, "endpoint"); ReflectionUtils.makeAccessible(privateField); return (String) ReflectionUtils.getField(privateField, transport); } + private WebClient getWebClient(WebClientStreamableHttpTransport transport) { + // Try common field names for WebClient + String[] possibleFieldNames = { "webClient", "client", "httpClient" }; + for (String fieldName : possibleFieldNames) { + Field field = ReflectionUtils.findField(WebClientStreamableHttpTransport.class, fieldName); + if (field != null) { + ReflectionUtils.makeAccessible(field); + Object value = ReflectionUtils.getField(field, transport); + if (value instanceof WebClient) { + return (WebClient) value; + } + } + } + // If direct field access fails, try to find any WebClient field + // Check all declared fields including inherited ones + Class clazz = WebClientStreamableHttpTransport.class; + while (clazz != null) { + Field[] fields = clazz.getDeclaredFields(); + for (Field field : fields) { + if (WebClient.class.isAssignableFrom(field.getType())) { + ReflectionUtils.makeAccessible(field); + Object value = ReflectionUtils.getField(field, transport); + if (value instanceof WebClient) { + return (WebClient) value; + } + } + } + clazz = clazz.getSuperclass(); + } + throw new IllegalStateException("Could not find WebClient field in WebClientStreamableHttpTransport"); + } + @Configuration - static class CustomWebClientConfiguration { + static class CustomWebClientFactoryConfiguration { @Bean - WebClient.Builder webClientBuilder() { - return WebClient.builder().baseUrl("http://custom-base-url"); + WebClientFactory webClientFactory() { + return new CustomWebClientFactory(); + } + + } + + @Configuration + static class PerConnectionWebClientFactoryConfiguration { + + @Bean + PerConnectionWebClientFactory webClientFactory() { + return new PerConnectionWebClientFactory(); } } @@ -215,4 +408,29 @@ ObjectMapper objectMapper() { } + static class CustomWebClientFactory implements WebClientFactory { + + @Override + public WebClient.Builder create(String connectionName) { + return WebClient.builder().baseUrl("http://custom-base-url"); + } + + } + + static class PerConnectionWebClientFactory implements WebClientFactory { + + private final java.util.List createdConnections = new java.util.ArrayList<>(); + + @Override + public WebClient.Builder create(String connectionName) { + createdConnections.add(connectionName); + return WebClient.builder(); + } + + java.util.List getCreatedConnections() { + return createdConnections; + } + + } + } diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java index e55bc82cea5..34543c5844e 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,6 @@ * {@link AutoConfiguration Auto-configuration} for {@link CassandraChatMemoryRepository}. * * @author Mick Semb Wever - * @author Jihoon Kim * @since 1.0.0 */ @AutoConfiguration(after = CassandraAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class) @@ -45,20 +44,22 @@ public class CassandraChatMemoryRepositoryAutoConfiguration { public CassandraChatMemoryRepository cassandraChatMemoryRepository( CassandraChatMemoryRepositoryProperties properties, CqlSession cqlSession) { - var builder = CassandraChatMemoryRepositoryConfig.builder().withCqlSession(cqlSession); - - builder = builder.withKeyspaceName(properties.getKeyspace()) + var configBuilder = CassandraChatMemoryRepositoryConfig.builder() + .withCqlSession(cqlSession) + .withKeyspaceName(properties.getKeyspace()) .withTableName(properties.getTable()) .withMessagesColumnName(properties.getMessagesColumn()); - if (!properties.isInitializeSchema()) { - builder = builder.disallowSchemaChanges(); + if (properties.getTimeToLive() != null) { + configBuilder.withTimeToLive(properties.getTimeToLive()); } - if (null != properties.getTimeToLive()) { - builder = builder.withTimeToLive(properties.getTimeToLive()); + + if (!properties.isInitializeSchema()) { + configBuilder.disallowSchemaChanges(); } - return CassandraChatMemoryRepository.create(builder.build()); + return CassandraChatMemoryRepository.create(configBuilder.build()); } } + diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java index 7b7469dbb0b..f07d4032951 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,18 +18,13 @@ import java.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig; import org.springframework.boot.context.properties.ConfigurationProperties; -import org.springframework.lang.Nullable; /** - * Configuration properties for Cassandra chat memory. + * Configuration properties for Cassandra Chat Memory Repository. * * @author Mick Semb Wever - * @author Jihoon Kim * @since 1.0.0 */ @ConfigurationProperties(CassandraChatMemoryRepositoryProperties.CONFIG_PREFIX) @@ -37,25 +32,30 @@ public class CassandraChatMemoryRepositoryProperties { public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.cassandra"; - private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryRepositoryProperties.class); - + /** + * Cassandra keyspace name. + */ private String keyspace = CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME; + /** + * Cassandra table name. + */ private String table = CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME; + /** + * Cassandra column name for messages. + */ private String messagesColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME; - private boolean initializeSchema = true; - - public boolean isInitializeSchema() { - return this.initializeSchema; - } - - public void setInitializeSchema(boolean initializeSchema) { - this.initializeSchema = initializeSchema; - } + /** + * Time to live (TTL) for messages written in Cassandra. + */ + private Duration timeToLive; - private Duration timeToLive = null; + /** + * Whether to initialize the schema on startup. + */ + private boolean initializeSchema = true; public String getKeyspace() { return this.keyspace; @@ -81,7 +81,6 @@ public void setMessagesColumn(String messagesColumn) { this.messagesColumn = messagesColumn; } - @Nullable public Duration getTimeToLive() { return this.timeToLive; } @@ -90,4 +89,13 @@ public void setTimeToLive(Duration timeToLive) { this.timeToLive = timeToLive; } + public boolean isInitializeSchema() { + return this.initializeSchema; + } + + public void setInitializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + } + } + diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryAutoConfiguration.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryAutoConfiguration.java index 7330132802a..a008c2523f1 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryAutoConfiguration.java @@ -22,6 +22,7 @@ import org.springframework.ai.chat.memory.repository.cosmosdb.CosmosDBChatMemoryRepository; import org.springframework.ai.chat.memory.repository.cosmosdb.CosmosDBChatMemoryRepositoryConfig; +import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -35,25 +36,20 @@ * @author Theo van Kraay * @since 1.1.0 */ -@AutoConfiguration +@AutoConfiguration(before = ChatMemoryAutoConfiguration.class) @ConditionalOnClass({ CosmosDBChatMemoryRepository.class, CosmosAsyncClient.class }) @EnableConfigurationProperties(CosmosDBChatMemoryRepositoryProperties.class) +@ConditionalOnProperty(prefix = CosmosDBChatMemoryRepositoryProperties.CONFIG_PREFIX, name = "endpoint") public class CosmosDBChatMemoryRepositoryAutoConfiguration { - private final String agentSuffix = "SpringAI-CDBNoSQL-ChatMemoryRepository"; + private final String agentSuffix = "SpringAI-CDBNoSQL-ChatMemory"; @Bean @ConditionalOnMissingBean - @ConditionalOnProperty(prefix = "spring.ai.chat.memory.repository.cosmosdb", name = "endpoint") - public CosmosAsyncClient cosmosClient(CosmosDBChatMemoryRepositoryProperties properties) { - if (properties.getEndpoint() == null || properties.getEndpoint().isEmpty()) { - throw new IllegalArgumentException( - "Cosmos DB endpoint must be provided via spring.ai.chat.memory.repository.cosmosdb.endpoint property"); - } - + public CosmosAsyncClient cosmosAsyncClient(CosmosDBChatMemoryRepositoryProperties properties) { String mode = properties.getConnectionMode(); if (mode == null) { - properties.setConnectionMode("gateway"); + mode = "gateway"; } else if (!mode.equals("direct") && !mode.equals("gateway")) { throw new IllegalArgumentException("Connection mode must be either 'direct' or 'gateway'"); @@ -69,27 +65,22 @@ else if (!mode.equals("direct") && !mode.equals("gateway")) { builder.key(properties.getKey()); } - return ("direct".equals(properties.getConnectionMode()) ? builder.directMode() : builder.gatewayMode()) - .buildAsyncClient(); + return ("direct".equals(mode) ? builder.directMode() : builder.gatewayMode()).buildAsyncClient(); } @Bean @ConditionalOnMissingBean - public CosmosDBChatMemoryRepositoryConfig cosmosDBChatMemoryRepositoryConfig( + public CosmosDBChatMemoryRepository cosmosDBChatMemoryRepository( CosmosDBChatMemoryRepositoryProperties properties, CosmosAsyncClient cosmosAsyncClient) { - return CosmosDBChatMemoryRepositoryConfig.builder() + var configBuilder = CosmosDBChatMemoryRepositoryConfig.builder() .withCosmosClient(cosmosAsyncClient) .withDatabaseName(properties.getDatabaseName()) .withContainerName(properties.getContainerName()) - .withPartitionKeyPath(properties.getPartitionKeyPath()) - .build(); - } + .withPartitionKeyPath(properties.getPartitionKeyPath()); - @Bean - @ConditionalOnMissingBean - public CosmosDBChatMemoryRepository cosmosDBChatMemoryRepository(CosmosDBChatMemoryRepositoryConfig config) { - return CosmosDBChatMemoryRepository.create(config); + return CosmosDBChatMemoryRepository.create(configBuilder.build()); } } + diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryProperties.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryProperties.java index 61bea63cbfe..1c8dbd90a7c 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryProperties.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/model/chat/memory/repository/cosmosdb/autoconfigure/CosmosDBChatMemoryRepositoryProperties.java @@ -20,7 +20,7 @@ import org.springframework.boot.context.properties.ConfigurationProperties; /** - * Configuration properties for CosmosDB chat memory. + * Configuration properties for CosmosDB Chat Memory Repository. * * @author Theo van Kraay * @since 1.1.0 @@ -30,16 +30,35 @@ public class CosmosDBChatMemoryRepositoryProperties { public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.cosmosdb"; + /** + * Azure Cosmos DB endpoint URI. Required for auto-configuration. + */ private String endpoint; + /** + * Azure Cosmos DB primary or secondary key. If not provided, Azure Identity + * authentication will be used. + */ private String key; + /** + * Connection mode for Cosmos DB client (direct or gateway). + */ private String connectionMode = "gateway"; + /** + * Name of the Cosmos DB database. + */ private String databaseName = CosmosDBChatMemoryRepositoryConfig.DEFAULT_DATABASE_NAME; + /** + * Name of the Cosmos DB container. + */ private String containerName = CosmosDBChatMemoryRepositoryConfig.DEFAULT_CONTAINER_NAME; + /** + * Partition key path for the container. + */ private String partitionKeyPath = CosmosDBChatMemoryRepositoryConfig.DEFAULT_PARTITION_KEY_PATH; public String getEndpoint() { @@ -91,3 +110,4 @@ public void setPartitionKeyPath(String partitionKeyPath) { } } +