diff --git a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java index 82489c2d2de..f60f6f61320 100644 --- a/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/model/chat/memory/repository/neo4j/autoconfigure/Neo4jChatMemoryRepositoryAutoConfigurationIT.java @@ -115,10 +115,11 @@ void addAndGet() { assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator() .isEqualTo(media); memory.deleteByConversationId(sessionId); - ToolResponseMessage toolResponseMessage = new ToolResponseMessage( - List.of(new ToolResponse("id", "name", "responseData"), - new ToolResponse("id2", "name2", "responseData2")), - Map.of("id", "id", "metadataKey", "metadata")); + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("id", "name", "responseData"), + new ToolResponse("id2", "name2", "responseData2"))) + .metadata(Map.of("id", "id", "metadataKey", "metadata")) + .build(); memory.saveAll(sessionId, List.of(toolResponseMessage)); messages = memory.findByConversationId(sessionId); assertThat(messages.size()).isEqualTo(1); diff --git a/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java index fd337c50278..aae301fe30f 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java @@ -216,7 +216,7 @@ private Message getMessage(UdtValue udt) { return SystemMessage.builder().text(content).metadata(props).build(); case TOOL: // todo – persist ToolResponse somehow - return new ToolResponseMessage(List.of(), props); + return ToolResponseMessage.builder().responses(List.of()).metadata(props).build(); default: throw new IllegalStateException( String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn))); diff --git a/memory/repository/spring-ai-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/chat/memory/repository/cosmosdb/CosmosDBChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/chat/memory/repository/cosmosdb/CosmosDBChatMemoryRepository.java index a399608861b..4bdb1534ca8 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/chat/memory/repository/cosmosdb/CosmosDBChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-cosmos-db/src/main/java/org/springframework/ai/chat/memory/repository/cosmosdb/CosmosDBChatMemoryRepository.java @@ -235,7 +235,7 @@ private Message mapToMessage(Map doc) { case ASSISTANT -> new AssistantMessage(content, metadata); case USER -> UserMessage.builder().text(content).metadata(metadata).build(); case SYSTEM -> SystemMessage.builder().text(content).metadata(metadata).build(); - case TOOL -> new ToolResponseMessage(List.of(), metadata); + case TOOL -> ToolResponseMessage.builder().responses(List.of()).metadata(metadata).build(); default -> throw new IllegalStateException(String.format("Unknown message type: %s", messageTypeStr)); }; } diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java index 3ae6b0d81a8..7d1785747f0 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java @@ -148,7 +148,7 @@ public Message mapRow(ResultSet rs, int i) throws SQLException { // The content is always stored empty for ToolResponseMessages. // If we want to capture the actual content, we need to extend // AddBatchPreparedStatement to support it. - case TOOL -> new ToolResponseMessage(List.of()); + case TOOL -> ToolResponseMessage.builder().responses(List.of()).build(); }; } diff --git a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java index 2d9a2099906..2309ad4dd81 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java @@ -35,6 +35,7 @@ import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.content.Media; import org.springframework.ai.content.MediaContent; @@ -172,12 +173,12 @@ public Neo4jChatMemoryRepositoryConfig getConfig() { private Message buildToolMessage(org.neo4j.driver.Record record) { Message message; - message = new ToolResponseMessage(record.get("toolResponses").asList(v -> { + message = ToolResponseMessage.builder().responses(record.get("toolResponses").asList(v -> { Map trMap = v.asMap(); - return new ToolResponseMessage.ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()), + return new ToolResponse((String) trMap.get(ToolResponseAttributes.ID.getValue()), (String) trMap.get(ToolResponseAttributes.NAME.getValue()), (String) trMap.get(ToolResponseAttributes.RESPONSE_DATA.getValue())); - }), record.get("metadata").asMap()); + })).metadata(record.get("metadata").asMap()).build(); return message; } diff --git a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java index acb06ede872..bee7a5d5888 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java @@ -130,7 +130,9 @@ void saveAndFindMultipleMessages() { List messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), new UserMessage("Message from user - " + conversationId), new SystemMessage("Message from system - " + conversationId), - new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")))); + ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("id", "name", "responseData"))) + .build()); this.chatMemoryRepository.saveAll(conversationId, messages); List retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId); @@ -285,9 +287,11 @@ void handleAssistantMessageWithToolCalls() { void handleToolResponseMessage() { var conversationId = UUID.randomUUID().toString(); - ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List - .of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")), - Map.of("metadataKey", "metadataValue")); + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("id1", "name1", "responseData1"), + new ToolResponse("id2", "name2", "responseData2"))) + .metadata(Map.of("metadataKey", "metadataValue")) + .build(); this.chatMemoryRepository.saveAll(conversationId, List.of(toolResponseMessage)); @@ -408,7 +412,9 @@ private Message createMessageByType(String content, MessageType messageType) { case ASSISTANT -> new AssistantMessage(content); case USER -> new UserMessage(content); case SYSTEM -> new SystemMessage(content); - case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))); + case TOOL -> ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("id", "name", "responseData"))) + .build(); }; } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java index 163cf2a7119..2751b1388cb 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTests.java @@ -199,7 +199,9 @@ void createChatCompletionMessagesWithToolResponseMessage() { var toolResponse1 = createToolResponse(1); var toolResponse2 = createToolResponse(2); var toolResponse3 = createToolResponse(3); - var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse1, toolResponse2, toolResponse3)); + var toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(toolResponse1, toolResponse2, toolResponse3)) + .build(); var prompt = createPrompt(toolResponseMessage); var chatCompletionRequest = this.chatModel.createRequest(prompt, false); var chatCompletionMessages = chatCompletionRequest.messages(); @@ -212,7 +214,7 @@ void createChatCompletionMessagesWithToolResponseMessage() { @Test void createChatCompletionMessagesWithInvalidToolResponseMessage() { var toolResponse = new ToolResponseMessage.ToolResponse(null, null, null); - var toolResponseMessage = new ToolResponseMessage(List.of(toolResponse)); + var toolResponseMessage = ToolResponseMessage.builder().responses(List.of(toolResponse)).build(); var prompt = createPrompt(toolResponseMessage); assertThatThrownBy(() -> this.chatModel.createRequest(prompt, false)) .isInstanceOf(IllegalArgumentException.class) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 972fe635aae..a0dc2b514fc 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -25,6 +25,7 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; @@ -256,11 +257,10 @@ private static List createMessagesWithAllMessageTypes() { var systemMessage = new SystemMessage("Test system message"); var userMessage = new UserMessage("Test user message"); // @formatter:off - var toolResponseMessage = new ToolResponseMessage(List.of( - new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"), - new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"), - new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3")) - ); + var toolResponseMessage = ToolResponseMessage.builder().responses(List.of( + new ToolResponse("tool1", "Tool 1", "Test tool response 1"), + new ToolResponse("tool2", "Tool 2", "Test tool response 2"), + new ToolResponse("tool3", "Tool 3", "Test tool response 3"))).build(); // @formatter:on var assistantMessage = new AssistantMessage("Test assistant message"); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java index 47da252180f..2f44eb4ec75 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/ToolResponseMessage.java @@ -31,15 +31,27 @@ public class ToolResponseMessage extends AbstractMessage { protected final List responses; + /** + * @deprecated in favor of using {@link ToolResponseMessage.Builder} + */ + @Deprecated public ToolResponseMessage(List responses) { this(responses, Map.of()); } + /** + * @deprecated in favor of using {@link ToolResponseMessage.Builder} + */ + @Deprecated public ToolResponseMessage(List responses, Map metadata) { super(MessageType.TOOL, "", metadata); this.responses = responses; } + public static Builder builder() { + return new Builder(); + } + public List getResponses() { return this.responses; } @@ -73,4 +85,29 @@ public record ToolResponse(String id, String name, String responseData) { } + public static final class Builder { + + private List responses; + + private Map metadata = Map.of(); + + private Builder() { + } + + public Builder responses(List responses) { + this.responses = responses; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public ToolResponseMessage build() { + return new ToolResponseMessage(this.responses, this.metadata); + } + + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index b77421de9aa..8c6c66197a8 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -184,8 +184,10 @@ else if (message instanceof AssistantMessage assistantMessage) { .build()); } else if (message instanceof ToolResponseMessage toolResponseMessage) { - messagesCopy.add(new ToolResponseMessage(new ArrayList<>(toolResponseMessage.getResponses()), - new HashMap<>(toolResponseMessage.getMetadata()))); + messagesCopy.add(ToolResponseMessage.builder() + .responses(new ArrayList<>(toolResponseMessage.getResponses())) + .metadata(new HashMap<>(toolResponseMessage.getMetadata())) + .build()); } else { throw new IllegalArgumentException("Unsupported message type: " + message.getClass().getName()); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index 9811b8b38c7..02c35462857 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -251,7 +251,8 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess toolCallResult != null ? toolCallResult : "")); } - return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect); + return new InternalToolExecutionResult(ToolResponseMessage.builder().responses(toolResponses).build(), + returnDirect); } private List buildConversationHistoryAfterToolExecution(List previousMessages, diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java index 41c27409bde..54abeef0ffc 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java @@ -25,6 +25,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -169,8 +170,9 @@ void whenSingleToolCallInChatResponseThenExecute() { .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); @@ -194,8 +196,9 @@ void whenSingleToolCallWithReturnDirectInChatResponseThenExecute() { .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); @@ -223,9 +226,10 @@ void whenMultipleToolCallsInChatResponseThenExecute() { .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"), - new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!"))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"), + new ToolResponse("toolB", "toolB", "Mission accomplished!"))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); @@ -249,8 +253,9 @@ void whenDuplicateMixedToolCallsInChatResponseThenExecute() { .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); @@ -277,9 +282,10 @@ void whenMultipleToolCallsWithReturnDirectInChatResponseThenExecute() { .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"), - new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!"))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"), + new ToolResponse("toolB", "toolB", "Mission accomplished!"))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); @@ -307,9 +313,10 @@ void whenMultipleToolCallsWithMixedReturnDirectInChatResponseThenExecute() { .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"), - new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!"))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolA", "toolA", "Mission accomplished!"), + new ToolResponse("toolB", "toolB", "Mission accomplished!"))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); @@ -334,8 +341,9 @@ void whenToolCallWithExceptionThenReturnError() { .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolC", "toolC", "You failed this city!"))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolC", "toolC", "You failed this city!"))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); @@ -378,10 +386,10 @@ void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodExce .build()))) .build(); - ToolResponseMessage expectedToolResponse = new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON), - new ToolResponseMessage.ToolResponse("toolB", "toolB", - TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON))); + ToolResponseMessage expectedToolResponse = ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON), + new ToolResponse("toolB", "toolB", TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON))) + .build(); ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java index ca1862f985f..e1b3983907f 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java @@ -22,6 +22,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse; import org.springframework.ai.chat.messages.UserMessage; import static org.assertj.core.api.Assertions.assertThat; @@ -40,8 +41,10 @@ void whenSingleToolCallThenSingleGeneration() { .conversationHistory(List.of(new AssistantMessage("Hello, how can I help you?"), new UserMessage("I would like to know the weather in London"), new AssistantMessage("Call the weather tool"), - new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("42", "weather", - "The weather in London is 20 degrees Celsius"))))) + ToolResponseMessage.builder() + .responses(List + .of(new ToolResponse("42", "weather", "The weather in London is 20 degrees Celsius"))) + .build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); @@ -59,11 +62,11 @@ void whenMultipleToolCallsThenMultipleGenerations() { .conversationHistory(List.of(new AssistantMessage("Hello, how can I help you?"), new UserMessage("I would like to know the weather in London"), new AssistantMessage("Call the weather tool and the news tool"), - new ToolResponseMessage(List.of( - new ToolResponseMessage.ToolResponse("42", "weather", - "The weather in London is 20 degrees Celsius"), - new ToolResponseMessage.ToolResponse("21", "news", - "There is heavy traffic in the centre of London"))))) + ToolResponseMessage.builder() + .responses(List.of( + new ToolResponse("42", "weather", "The weather in London is 20 degrees Celsius"), + new ToolResponse("21", "news", "There is heavy traffic in the centre of London"))) + .build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); @@ -92,8 +95,8 @@ void whenEmptyConversationHistoryThenThrowsException() { @Test void whenToolResponseWithEmptyResponseListThenEmptyGenerations() { var toolExecutionResult = ToolExecutionResult.builder() - .conversationHistory( - List.of(new AssistantMessage("Processing request"), new ToolResponseMessage(List.of()))) + .conversationHistory(List.of(new AssistantMessage("Processing request"), + ToolResponseMessage.builder().responses(List.of()).build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); @@ -104,8 +107,8 @@ void whenToolResponseWithEmptyResponseListThenEmptyGenerations() { @Test void whenToolResponseWithNullContentThenGenerationWithNullText() { var toolExecutionResult = ToolExecutionResult.builder() - .conversationHistory( - List.of(new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("1", "tool", null))))) + .conversationHistory(List + .of(ToolResponseMessage.builder().responses(List.of(new ToolResponse("1", "tool", null))).build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); @@ -117,8 +120,8 @@ void whenToolResponseWithNullContentThenGenerationWithNullText() { @Test void whenToolResponseWithEmptyStringContentThenGenerationWithEmptyText() { var toolExecutionResult = ToolExecutionResult.builder() - .conversationHistory( - List.of(new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("1", "tool", ""))))) + .conversationHistory(List + .of(ToolResponseMessage.builder().responses(List.of(new ToolResponse("1", "tool", ""))).build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); @@ -144,11 +147,13 @@ void whenBuilderCalledWithoutConversationHistoryThenThrowsException() { void whenMultipleToolResponseMessagesOnlyLastOneIsProcessed() { var toolExecutionResult = ToolExecutionResult.builder() .conversationHistory(List.of(new AssistantMessage("First response"), - new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("1", "old_tool", "Old response"))), + ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("1", "old_tool", "Old response"))) + .build(), new AssistantMessage("Second response"), - new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("2", "new_tool", "New response"))))) + ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("2", "new_tool", "New response"))) + .build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); @@ -162,8 +167,9 @@ void whenMultipleToolResponseMessagesOnlyLastOneIsProcessed() { @Test void whenToolResponseWithEmptyToolNameThenMetadataContainsEmptyString() { var toolExecutionResult = ToolExecutionResult.builder() - .conversationHistory(List.of(new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse("1", "", "Response content"))))) + .conversationHistory(List.of(ToolResponseMessage.builder() + .responses(List.of(new ToolResponse("1", "", "Response content"))) + .build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); @@ -175,8 +181,9 @@ void whenToolResponseWithEmptyToolNameThenMetadataContainsEmptyString() { @Test void whenToolResponseWithNullToolIdThenGenerationStillCreated() { var toolExecutionResult = ToolExecutionResult.builder() - .conversationHistory(List.of(new ToolResponseMessage( - List.of(new ToolResponseMessage.ToolResponse(null, "tool", "Response content"))))) + .conversationHistory(List.of(ToolResponseMessage.builder() + .responses(List.of(new ToolResponse(null, "tool", "Response content"))) + .build())) .build(); var generations = ToolExecutionResult.buildGenerations(toolExecutionResult);