diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/McpToolCallbackParameterlessToolIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/McpToolCallbackParameterlessToolIT.java new file mode 100644 index 00000000000..366f20df358 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/McpToolCallbackParameterlessToolIT.java @@ -0,0 +1,212 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.autoconfigure; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.McpServerObjectMapperAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.ToolCallbackConverterAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerSpecificationFactoryAutoConfiguration; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.ApplicationContext; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.test.util.TestSocketUtils; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +/** + * Integration test to reproduce the issue where MCP tools with no parameters (incomplete + * schemas) fail to create valid tool definitions. + * + * @author Ilayaperumal Gopinathan + */ +class McpToolCallbackParameterlessToolIT { + + private final ApplicationContextRunner syncServerContextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.mcp.server.protocol=STREAMABLE", "spring.ai.mcp.server.type=SYNC") + .withConfiguration(AutoConfigurations.of(McpServerAutoConfiguration.class, + McpServerObjectMapperAutoConfiguration.class, ToolCallbackConverterAutoConfiguration.class, + McpServerStreamableHttpWebFluxAutoConfiguration.class, + McpServerAnnotationScannerAutoConfiguration.class, + McpServerSpecificationFactoryAutoConfiguration.class)); + + private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() + .withConfiguration(baseAutoConfig(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, + StreamableHttpWebFluxTransportAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class)); + + private static AutoConfigurations baseAutoConfig(Class... additional) { + Class[] dependencies = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, + WebClientAutoConfiguration.class }; + Class[] all = Stream.concat(Arrays.stream(dependencies), Arrays.stream(additional)).toArray(Class[]::new); + return AutoConfigurations.of(all); + } + + @Test + void testMcpServerClientIntegrationWithIncompleteSchemaSyncTool() { + int serverPort = TestSocketUtils.findAvailableTcpPort(); + + this.syncServerContextRunner + .withPropertyValues(// @formatter:off + "spring.ai.mcp.server.name=test-incomplete-schema-server", + "spring.ai.mcp.server.version=1.0.0", + "spring.ai.mcp.server.streamable-http.keep-alive-interval=1s", + "spring.ai.mcp.server.streamable-http.mcp-endpoint=/mcp") // @formatter:on + .run(serverContext -> { + + McpSyncServer mcpSyncServer = serverContext.getBean(McpSyncServer.class); + + ObjectMapper objectMapper = serverContext.getBean(ObjectMapper.class); + + String incompleteSchemaJson = "{\"type\":\"object\",\"additionalProperties\":false}"; + McpSchema.JsonSchema incompleteSchema = objectMapper.readValue(incompleteSchemaJson, + McpSchema.JsonSchema.class); + + // Build the tool using the builder pattern + McpSchema.Tool parameterlessTool = McpSchema.Tool.builder() + .name("getCurrentTime") + .description("Get the current server time") + .inputSchema(incompleteSchema) + .build(); + + // Create a tool specification that returns a simple response + McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification( + parameterlessTool, (exchange, arguments) -> { + McpSchema.TextContent content = new McpSchema.TextContent( + "Current time: " + Instant.now().toString()); + return new McpSchema.CallToolResult(List.of(content), false, null); + }, (exchange, request) -> { + McpSchema.TextContent content = new McpSchema.TextContent( + "Current time: " + Instant.now().toString()); + return new McpSchema.CallToolResult(List.of(content), false, null); + }); + + // Add the tool with incomplete schema to the server + mcpSyncServer.addTool(toolSpec); + + var httpServer = startHttpServer(serverContext, serverPort); + + this.clientApplicationContext + .withPropertyValues(// @formatter:off + "spring.ai.mcp.client.type=SYNC", + "spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:" + serverPort, + "spring.ai.mcp.client.initialized=false") // @formatter:on + .run(clientContext -> { + + ToolCallbackProvider toolCallbackProvider = clientContext + .getBean(SyncMcpToolCallbackProvider.class); + + // Wait for the client to receive the tool from the server + await().atMost(Duration.ofSeconds(5)) + .pollInterval(Duration.ofMillis(100)) + .untilAsserted(() -> assertThat(toolCallbackProvider.getToolCallbacks()).isNotEmpty()); + + List toolCallbacks = Arrays.asList(toolCallbackProvider.getToolCallbacks()); + + // We expect 1 tool: getCurrentTime (parameterless with incomplete + // schema) + assertThat(toolCallbacks).hasSize(1); + + // Get the tool callback + ToolCallback toolCallback = toolCallbacks.get(0); + ToolDefinition toolDefinition = toolCallback.getToolDefinition(); + + // Verify the tool definition + assertThat(toolDefinition).isNotNull(); + assertThat(toolDefinition.name()).contains("getCurrentTime"); + assertThat(toolDefinition.description()).isEqualTo("Get the current server time"); + + // **THE KEY VERIFICATION**: The input schema should now have the + // "properties" field + // even though the server provided a schema without it + String inputSchema = toolDefinition.inputSchema(); + assertThat(inputSchema).isNotNull().isNotEmpty(); + + Map schemaMap = ModelOptionsUtils.jsonToMap(inputSchema); + assertThat(schemaMap).isNotNull(); + assertThat(schemaMap).containsKey("type"); + assertThat(schemaMap.get("type")).isEqualTo("object"); + + assertThat(schemaMap).containsKey("properties"); + assertThat(schemaMap.get("properties")).isInstanceOf(Map.class); + + // Verify the properties map is empty for a parameterless tool + Map properties = (Map) schemaMap.get("properties"); + assertThat(properties).isEmpty(); + + // Verify that additionalProperties is preserved after + // normalization + assertThat(schemaMap).containsKey("additionalProperties"); + assertThat(schemaMap.get("additionalProperties")).isEqualTo(false); + + // Test that the callback can be called successfully + String result = toolCallback.call("{}"); + assertThat(result).isNotNull().contains("Current time:"); + }); + + stopHttpServer(httpServer); + }); + } + + // Helper methods to start and stop the HTTP server + private static DisposableServer startHttpServer(ApplicationContext serverContext, int port) { + WebFluxStreamableServerTransportProvider mcpStreamableServerTransport = serverContext + .getBean(WebFluxStreamableServerTransportProvider.class); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + return HttpServer.create().port(port).handle(adapter).bindNow(); + } + + private static void stopHttpServer(DisposableServer server) { + if (server != null) { + server.disposeNow(); + } + } + +} diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 9e1b46b4c41..4e02edeb379 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -30,7 +30,6 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.lang.Nullable; @@ -45,6 +44,7 @@ * * @author Christian Tzolov * @author YunKui Lu + * @author Ilayaperumal Gopinathan */ public class AsyncMcpToolCallback implements ToolCallback { @@ -92,11 +92,7 @@ private AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool, String prefixe @Override public ToolDefinition getToolDefinition() { - return DefaultToolDefinition.builder() - .name(this.prefixedToolName) - .description(this.tool.description()) - .inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema())) - .build(); + return McpToolUtils.createToolDefinition(this.prefixedToolName, this.tool); } public String getOriginalToolName() { diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index 2aeff595d00..01c78d35516 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -39,6 +39,9 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.util.json.schema.JsonSchemaUtils; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; @@ -62,6 +65,7 @@ * * * @author Christian Tzolov + * @author Ilayaperumal Gopinathan */ public final class McpToolUtils { @@ -227,6 +231,20 @@ public static McpStatelessServerFeatures.SyncToolSpecification toStatelessSyncTo .build(); } + /** + * Creates a Spring AI ToolDefinition from an MCP Tool. + * @param prefixedToolName the prefixed name for the tool + * @param tool the MCP tool + * @return a ToolDefinition with normalized input schema + */ + public static ToolDefinition createToolDefinition(String prefixedToolName, McpSchema.Tool tool) { + return DefaultToolDefinition.builder() + .name(prefixedToolName) + .description(tool.description()) + .inputSchema(JsonSchemaUtils.ensureValidInputSchema(ModelOptionsUtils.toJsonString(tool.inputSchema()))) + .build(); + } + private static SharedSyncToolSpecification toSharedSyncToolSpecification(ToolCallback toolCallback, MimeType mimeType) { diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index 277401594c2..e678efffc95 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -29,7 +29,6 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.tool.ToolCallback; -import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.lang.Nullable; @@ -41,6 +40,7 @@ * * @author Christian Tzolov * @author YunKui Lu + * @author Ilayaperumal Gopinathan * @since 1.0.0 */ public class SyncMcpToolCallback implements ToolCallback { @@ -89,11 +89,7 @@ private SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool, String prefixedT @Override public ToolDefinition getToolDefinition() { - return DefaultToolDefinition.builder() - .name(this.prefixedToolName) - .description(this.tool.description()) - .inputSchema(ModelOptionsUtils.toJsonString(this.tool.inputSchema())) - .build(); + return McpToolUtils.createToolDefinition(this.prefixedToolName, this.tool); } /** diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml index 3e534615f2b..da1dd83e9e3 100644 --- a/models/spring-ai-openai/pom.xml +++ b/models/spring-ai-openai/pom.xml @@ -87,6 +87,19 @@ test + + org.springframework.ai + spring-ai-mcp + ${project.version} + test + + + + org.mockito + mockito-core + test + + io.micrometer micrometer-observation-test diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index afa0d89ec14..b7401ac5027 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -28,6 +28,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import io.modelcontextprotocol.spec.McpSchema; import org.assertj.core.data.Percentage; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -82,6 +83,11 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @SpringBootTest(classes = OpenAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") @@ -465,6 +471,128 @@ void streamFunctionCallUsageTest() { assertThat(usage.getTotalTokens()).isGreaterThan(680).isLessThan(960); } + @Test + void functionCallWithMcpParameterlessToolTest() { + UserMessage userMessage = new UserMessage("What is the current server time?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + // Mock a parameter-less JsonSchema (no properties field) + // This simulates what an external MCP server might provide + McpSchema.JsonSchema mockJsonSchema = mock(McpSchema.JsonSchema.class); + when(mockJsonSchema.type()).thenReturn("object"); + when(mockJsonSchema.additionalProperties()).thenReturn(false); + when(mockJsonSchema.properties()).thenReturn(null); // No properties field + + // Create a mock MCP tool + McpSchema.Tool mockMcpTool = mock(McpSchema.Tool.class); + when(mockMcpTool.name()).thenReturn("getCurrentTime"); + when(mockMcpTool.description()).thenReturn("Get the current server time"); + when(mockMcpTool.inputSchema()).thenReturn(mockJsonSchema); + + // Create a mock MCP client + io.modelcontextprotocol.client.McpSyncClient mockMcpClient = mock( + io.modelcontextprotocol.client.McpSyncClient.class); + + McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-mcp-client", "1.0.0"); + when(mockMcpClient.getClientInfo()).thenReturn(clientInfo); + + // Mock the tool call response + McpSchema.TextContent mockTextContent = mock(McpSchema.TextContent.class); + when(mockTextContent.type()).thenReturn("text"); + when(mockTextContent.text()).thenReturn("2025-11-11T12:00:00Z"); + + McpSchema.CallToolResult toolResult = mock(McpSchema.CallToolResult.class); + when(toolResult.content()).thenReturn(List.of(mockTextContent)); + when(toolResult.isError()).thenReturn(false); + + when(mockMcpClient.callTool(any())).thenReturn(toolResult); + + // Create the SyncMcpToolCallback + org.springframework.ai.mcp.SyncMcpToolCallback mcpToolCallback = org.springframework.ai.mcp.SyncMcpToolCallback + .builder() + .mcpClient(mockMcpClient) + .tool(mockMcpTool) + .prefixedToolName("test-mcp-client_getCurrentTime") + .build(); + + var promptOptions = OpenAiChatOptions.builder() + .model(OpenAiApi.ChatModel.GPT_4_O.getValue()) + .toolCallbacks(List.of(mcpToolCallback)) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + // Verify the response contains time-related content + assertThat(response.getResult().getOutput().getText()).isNotBlank(); + + // Verify the mock MCP client was called + verify(mockMcpClient, atLeastOnce()).callTool(any()); + } + + @Test + void functionCallWithMcpParameterlessAsyncToolTest() { + UserMessage userMessage = new UserMessage("What is the current server time?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + // Mock a parameter-less JsonSchema (no properties field) + // This simulates what an external MCP server might provide + McpSchema.JsonSchema mockJsonSchema = mock(McpSchema.JsonSchema.class); + when(mockJsonSchema.type()).thenReturn("object"); + when(mockJsonSchema.additionalProperties()).thenReturn(false); + when(mockJsonSchema.properties()).thenReturn(null); // No properties field + + // Create a mock MCP tool + McpSchema.Tool mockMcpTool = mock(McpSchema.Tool.class); + when(mockMcpTool.name()).thenReturn("getCurrentTime"); + when(mockMcpTool.description()).thenReturn("Get the current server time"); + when(mockMcpTool.inputSchema()).thenReturn(mockJsonSchema); + + // Create a mock async MCP client + io.modelcontextprotocol.client.McpAsyncClient mockMcpClient = mock( + io.modelcontextprotocol.client.McpAsyncClient.class); + + McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-mcp-async-client", "1.0.0"); + when(mockMcpClient.getClientInfo()).thenReturn(clientInfo); + + // Mock the tool call response + McpSchema.TextContent mockTextContent = mock(McpSchema.TextContent.class); + when(mockTextContent.type()).thenReturn("text"); + when(mockTextContent.text()).thenReturn("2025-11-11T12:00:00Z"); + + McpSchema.CallToolResult toolResult = mock(McpSchema.CallToolResult.class); + when(toolResult.content()).thenReturn(List.of(mockTextContent)); + when(toolResult.isError()).thenReturn(false); + + when(mockMcpClient.callTool(any())).thenReturn(reactor.core.publisher.Mono.just(toolResult)); + + // Create the AsyncMcpToolCallback + org.springframework.ai.mcp.AsyncMcpToolCallback mcpToolCallback = org.springframework.ai.mcp.AsyncMcpToolCallback + .builder() + .mcpClient(mockMcpClient) + .tool(mockMcpTool) + .prefixedToolName("test-mcp-async-client_getCurrentTime") + .build(); + + var promptOptions = OpenAiChatOptions.builder() + .model(OpenAiApi.ChatModel.GPT_4_O.getValue()) + .toolCallbacks(List.of(mcpToolCallback)) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + // Verify the response contains time-related content + assertThat(response.getResult().getOutput().getText()).isNotBlank(); + + // Verify the mock MCP client was called + verify(mockMcpClient, atLeastOnce()).callTool(any()); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "gpt-4o" }) void multiModalityEmbeddedImage(String modelName) throws IOException { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaUtils.java b/spring-ai-model/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaUtils.java new file mode 100644 index 00000000000..2cbd72e5b05 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaUtils.java @@ -0,0 +1,72 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util.json.schema; + +import java.util.Map; + +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.util.StringUtils; + +/** + * Utility methods for working with JSON schemas. + * + * @author Guangdong Liu + * @author Ilayaperumal Gopinathan + * @since 1.0.0 + */ +public final class JsonSchemaUtils { + + private JsonSchemaUtils() { + } + + /** + * Ensures that the input schema is valid for AI model APIs. Many AI models require + * that the parameters object must have a "properties" field, even if it's empty. This + * method normalizes schemas from external sources (like MCP tools) that may not + * include this field. + * @param inputSchema the input schema as a JSON string + * @return a valid input schema as a JSON string with required fields + */ + public static String ensureValidInputSchema(String inputSchema) { + if (!StringUtils.hasText(inputSchema)) { + return inputSchema; + } + + Map schemaMap = ModelOptionsUtils.jsonToMap(inputSchema); + + if (schemaMap == null || schemaMap.isEmpty()) { + // Create a minimal valid schema + schemaMap = new java.util.HashMap<>(); + schemaMap.put("type", "object"); + schemaMap.put("properties", new java.util.HashMap<>()); + return ModelOptionsUtils.toJsonString(schemaMap); + } + + // Ensure "type" field exists + if (!schemaMap.containsKey("type")) { + schemaMap.put("type", "object"); + } + + // Ensure "properties" field exists for object types + if ("object".equals(schemaMap.get("type")) && !schemaMap.containsKey("properties")) { + schemaMap.put("properties", new java.util.HashMap<>()); + } + + return ModelOptionsUtils.toJsonString(schemaMap); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/util/json/schema/JsonSchemaUtilsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/util/json/schema/JsonSchemaUtilsTests.java new file mode 100644 index 00000000000..aadbf057bd0 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/util/json/schema/JsonSchemaUtilsTests.java @@ -0,0 +1,160 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.util.json.schema; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.model.ModelOptionsUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link JsonSchemaUtils}. + * + * @author Ilayaperumal Gopinathan + */ +class JsonSchemaUtilsTests { + + /** + * Test that a schema with only "type": "object" and no "properties" field is + * normalized to include an empty "properties" field. + *

+ * This scenario occurs when external MCP servers (like Claude Desktop) provide tool + * schemas for parameterless tools that don't include the "properties" field. + */ + @Test + void testEnsureValidInputSchemaAddsPropertiesField() { + // Simulate a schema from an external MCP server without "properties" + String inputSchema = "{\"type\":\"object\",\"additionalProperties\":false}"; + + String normalizedSchema = JsonSchemaUtils.ensureValidInputSchema(inputSchema); + + Map schemaMap = ModelOptionsUtils.jsonToMap(normalizedSchema); + assertThat(schemaMap).isNotNull(); + assertThat(schemaMap).containsKey("type"); + assertThat(schemaMap.get("type")).isEqualTo("object"); + + // The key assertion: verify that "properties" field was added + assertThat(schemaMap).containsKey("properties"); + assertThat(schemaMap.get("properties")).isInstanceOf(Map.class); + + // For a parameterless tool, properties should be empty + Map properties = (Map) schemaMap.get("properties"); + assertThat(properties).isEmpty(); + } + + /** + * Test that a schema without a "type" field is normalized to include both "type" and + * "properties" fields. + */ + @Test + void testEnsureValidInputSchemaAddsTypeAndPropertiesFields() { + // Simulate a minimal schema without "type" + String inputSchema = "{\"additionalProperties\":false}"; + + String normalizedSchema = JsonSchemaUtils.ensureValidInputSchema(inputSchema); + + Map schemaMap = ModelOptionsUtils.jsonToMap(normalizedSchema); + assertThat(schemaMap).isNotNull(); + + // Verify both "type" and "properties" were added + assertThat(schemaMap).containsKey("type"); + assertThat(schemaMap.get("type")).isEqualTo("object"); + assertThat(schemaMap).containsKey("properties"); + assertThat(schemaMap.get("properties")).isInstanceOf(Map.class); + } + + /** + * Test that an empty or null schema is normalized to a minimal valid schema. + */ + @Test + void testEnsureValidInputSchemaWithEmptySchema() { + String inputSchema = "{}"; + + String normalizedSchema = JsonSchemaUtils.ensureValidInputSchema(inputSchema); + + Map schemaMap = ModelOptionsUtils.jsonToMap(normalizedSchema); + assertThat(schemaMap).isNotNull(); + assertThat(schemaMap).containsKey("type"); + assertThat(schemaMap.get("type")).isEqualTo("object"); + assertThat(schemaMap).containsKey("properties"); + assertThat(schemaMap.get("properties")).isInstanceOf(Map.class); + } + + /** + * Test that a schema with existing "properties" field is not modified. + */ + @Test + void testEnsureValidInputSchemaPreservesExistingProperties() { + // A properly formed schema with properties + String inputSchema = "{\"type\":\"object\",\"properties\":{\"cityName\":{\"type\":\"string\"}}}"; + + String normalizedSchema = JsonSchemaUtils.ensureValidInputSchema(inputSchema); + + Map schemaMap = ModelOptionsUtils.jsonToMap(normalizedSchema); + assertThat(schemaMap).isNotNull(); + assertThat(schemaMap).containsKey("type"); + assertThat(schemaMap).containsKey("properties"); + + // Verify existing properties are preserved + Map properties = (Map) schemaMap.get("properties"); + assertThat(properties).isNotEmpty(); + assertThat(properties).containsKey("cityName"); + } + + /** + * Test that a schema with "type": "string" (not "object") is not modified. + */ + @Test + void testEnsureValidInputSchemaWithNonObjectType() { + String inputSchema = "{\"type\":\"string\"}"; + + String normalizedSchema = JsonSchemaUtils.ensureValidInputSchema(inputSchema); + + Map schemaMap = ModelOptionsUtils.jsonToMap(normalizedSchema); + assertThat(schemaMap).isNotNull(); + assertThat(schemaMap).containsKey("type"); + assertThat(schemaMap.get("type")).isEqualTo("string"); + + // Properties field should not be added for non-object types + assertThat(schemaMap).doesNotContainKey("properties"); + } + + /** + * Test that null or empty input returns a valid minimal schema. + */ + @Test + void testEnsureValidInputSchemaWithNullInput() { + String normalizedSchema = JsonSchemaUtils.ensureValidInputSchema(null); + + // Null input should be handled gracefully + assertThat(normalizedSchema).isNull(); + } + + /** + * Test that blank input returns the input as-is. + */ + @Test + void testEnsureValidInputSchemaWithBlankInput() { + String normalizedSchema = JsonSchemaUtils.ensureValidInputSchema(""); + + assertThat(normalizedSchema).isEmpty(); + } + +}