From d2057f3d5e33b8c5b5ff15c2f103459b5ecfb927 Mon Sep 17 00:00:00 2001 From: ultramancode Date: Tue, 4 Nov 2025 15:17:02 +0900 Subject: [PATCH 1/2] Fix tool call merging for streaming APIs without IDs - Update MessageAggregator to handle tool calls without IDs - When tool call has no ID, merge with last tool call - Add comprehensive tests for streaming patterns Signed-off-by: ultramancode --- .../ai/chat/model/MessageAggregator.java | 73 +++++++- .../ai/chat/model/MessageAggregatorTests.java | 166 ++++++++++++++++++ 2 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 59f7db03d57..9bd267bd529 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -38,8 +38,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import static org.springframework.ai.chat.messages.AssistantMessage.ToolCall; - /** * Helper that for streaming chat responses, aggregate the chat response messages into a * single AssistantMessage. Job is performed in parallel to the chat response processing. 
@@ -48,6 +46,7 @@ * @author Alexandros Pappas * @author Thomas Vitale * @author Heonwoo Kim + * @author Taewoong Kim * @since 1.0.0 */ public class MessageAggregator { @@ -104,7 +103,7 @@ public Flux aggregate(Flux fluxChatResponse, } AssistantMessage outputMessage = chatResponse.getResult().getOutput(); if (!CollectionUtils.isEmpty(outputMessage.getToolCalls())) { - toolCallsRef.get().addAll(outputMessage.getToolCalls()); + mergeToolCalls(toolCallsRef.get(), outputMessage.getToolCalls()); } } @@ -188,6 +187,74 @@ public Flux aggregate(Flux fluxChatResponse, }).doOnError(e -> logger.error("Aggregation Error", e)); } + /** + * Merge tool calls by id to handle streaming responses where tool call data is split + * across multiple chunks. This is common in OpenAI-compatible APIs like Qwen, where + * the first chunk contains the function name and subsequent chunks contain only arguments. + * if a tool call has an ID, it's matched by ID. + * if it has no ID (empty or null), it's merged with the last tool call in the list. + * @param existingToolCalls the list of existing tool calls to merge into + * @param newToolCalls the new tool calls to merge + */ + private void mergeToolCalls(List existingToolCalls, List newToolCalls) { + for (ToolCall newCall : newToolCalls) { + if (StringUtils.hasText(newCall.id())) { + // ID present: match by ID or add as new + ToolCall existingMatch = existingToolCalls.stream() + .filter(existing -> newCall.id().equals(existing.id())) + .findFirst() + .orElse(null); + + if (existingMatch != null) { + // Merge with existing tool call with same ID + int index = existingToolCalls.indexOf(existingMatch); + ToolCall merged = mergeToolCall(existingMatch, newCall); + existingToolCalls.set(index, merged); + } else { + // New tool call with ID + existingToolCalls.add(newCall); + } + } else { + // No ID: merge with last tool call + ToolCall lastToolCall = existingToolCalls.isEmpty() ? 
null : existingToolCalls.get(existingToolCalls.size() - 1); + ToolCall merged = mergeToolCall(lastToolCall, newCall); + + if (lastToolCall != null) { + existingToolCalls.set(existingToolCalls.size() - 1, merged); + } else { + existingToolCalls.add(merged); + } + } + } + } + + /** + * Merge two tool calls into one, combining their properties. + * @param existing the existing tool call + * @param current the current tool call to merge + * @return the merged tool call + */ + private ToolCall mergeToolCall(ToolCall existing, ToolCall current) { + if (existing == null) { + return current; + } + + // Use non-empty ID, prefer existing if both present (for consistency) + String mergedId = StringUtils.hasText(existing.id()) ? existing.id() : current.id(); + + // Use non-empty name, prefer new if both present + String mergedName = StringUtils.hasText(current.name()) ? current.name() : existing.name(); + + // Use non-empty type, prefer new if both present + String mergedType = StringUtils.hasText(current.type()) ? current.type() : existing.type(); + + // Concatenate arguments + String mergedArgs = (existing.arguments() != null ? existing.arguments() : "") + + (current.arguments() != null ? current.arguments() : ""); + + return new ToolCall(mergedId, mergedType, mergedName, mergedArgs); + } + public record DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens) implements Usage { @Override diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java new file mode 100644 index 00000000000..158a524e322 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2023-2025 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.model; + +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.AssistantMessage; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link MessageAggregator} with streaming tool calls that lack IDs in subsequent chunks. + * This pattern is common in OpenAI-compatible APIs. + * @author Taewoong Kim + */ +class MessageAggregatorTests { + + private final MessageAggregator messageAggregator = new MessageAggregator(); + + /** + * Test merging of tool calls when subsequent chunks have no ID. + * First chunk contains the tool name and ID, subsequent chunks contain only arguments. 
+ */ + @Test + void shouldMergeToolCallsWithoutIds() { + // Chunk 1: ID and name present + ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "getCurrentWeather", ""))) + .build()))); + + // Chunk 2-5: No ID, only arguments (common streaming pattern) + ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"location\": \""))) + .build()))); + + ChatResponse chunk3 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "Se"))) + .build()))); + + ChatResponse chunk4 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "oul"))) + .build()))); + + ChatResponse chunk5 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "\"}"))) + .build()))); + + Flux flux = Flux.just(chunk1, chunk2, chunk3, chunk4, chunk5); + + // When: Aggregate the streaming responses + AtomicReference finalResponse = new AtomicReference<>(); + this.messageAggregator.aggregate(flux, finalResponse::set).blockLast(); + + // Then: Verify the tool call was properly merged + assertThat(finalResponse.get()).isNotNull(); + List toolCalls = finalResponse.get().getResult().getOutput().getToolCalls(); + + assertThat(toolCalls).hasSize(1); + AssistantMessage.ToolCall mergedToolCall = toolCalls.get(0); + + assertThat(mergedToolCall.id()).isEqualTo("chatcmpl-tool-123"); + assertThat(mergedToolCall.name()).isEqualTo("getCurrentWeather"); + assertThat(mergedToolCall.arguments()).isEqualTo("{\"location\": \"Seoul\"}"); + } + + /** + * Test multiple tool calls being streamed simultaneously. 
Each tool call has its own ID in the first chunk, + * and subsequent chunks have no ID but are merged with the last tool call. + */ + @Test + void shouldMergeMultipleToolCallsWithMixedIds() { + // Given: Multiple tool calls being streamed + // Chunk 1: First tool call starts with ID + ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", ""))) + .build()))); + + // Chunk 2: Argument for first tool call (no ID) + ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"city\":\"Tokyo\"}"))) + .build()))); + + // Chunk 3: Second tool call starts with ID + ChatResponse chunk3 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-2", "function", "getTime", ""))) + .build()))); + + // Chunk 4: Argument for second tool call (no ID) + ChatResponse chunk4 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"timezone\":\"JST\"}"))) + .build()))); + + Flux flux = Flux.just(chunk1, chunk2, chunk3, chunk4); + + // When: Aggregate the streaming responses + AtomicReference finalResponse = new AtomicReference<>(); + this.messageAggregator.aggregate(flux, finalResponse::set).blockLast(); + + // Then: Verify both tool calls were properly merged + assertThat(finalResponse.get()).isNotNull(); + List toolCalls = finalResponse.get().getResult().getOutput().getToolCalls(); + + assertThat(toolCalls).hasSize(2); + + AssistantMessage.ToolCall firstToolCall = toolCalls.get(0); + assertThat(firstToolCall.id()).isEqualTo("tool-1"); + assertThat(firstToolCall.name()).isEqualTo("getWeather"); + assertThat(firstToolCall.arguments()).isEqualTo("{\"city\":\"Tokyo\"}"); + + AssistantMessage.ToolCall secondToolCall 
= toolCalls.get(1); + assertThat(secondToolCall.id()).isEqualTo("tool-2"); + assertThat(secondToolCall.name()).isEqualTo("getTime"); + assertThat(secondToolCall.arguments()).isEqualTo("{\"timezone\":\"JST\"}"); + } + + /** + * Test that tool calls with IDs are still matched correctly by ID, even when they arrive in different chunks. + */ + @Test + void shouldMergeToolCallsById() { + // Given: Chunks with same ID arriving separately + ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", "{\"ci"))) + .build()))); + + ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "", "ty\":\"Paris\"}"))) + .build()))); + + Flux flux = Flux.just(chunk1, chunk2); + + // When: Aggregate the streaming responses + AtomicReference finalResponse = new AtomicReference<>(); + this.messageAggregator.aggregate(flux, finalResponse::set).blockLast(); + + // Then: Verify the tool call was merged by ID + assertThat(finalResponse.get()).isNotNull(); + List toolCalls = finalResponse.get().getResult().getOutput().getToolCalls(); + + assertThat(toolCalls).hasSize(1); + AssistantMessage.ToolCall mergedToolCall = toolCalls.get(0); + assertThat(mergedToolCall.id()).isEqualTo("tool-1"); + assertThat(mergedToolCall.name()).isEqualTo("getWeather"); + assertThat(mergedToolCall.arguments()).isEqualTo("{\"city\":\"Paris\"}"); + } + +} + From 77801edf14c6155be9f85d98b99f0cc1182d960b Mon Sep 17 00:00:00 2001 From: ultramancode Date: Tue, 4 Nov 2025 19:54:39 +0900 Subject: [PATCH 2/2] Format: apply Spring Java Format Signed-off-by: ultramancode --- .../ai/chat/model/MessageAggregator.java | 30 +++++---- .../ai/chat/model/MessageAggregatorTests.java | 62 ++++++++++--------- 2 files changed, 49 insertions(+), 43 deletions(-) diff --git 
a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 9bd267bd529..bae682239de 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -190,9 +190,9 @@ public Flux aggregate(Flux fluxChatResponse, /** * Merge tool calls by id to handle streaming responses where tool call data is split * across multiple chunks. This is common in OpenAI-compatible APIs like Qwen, where - * the first chunk contains the function name and subsequent chunks contain only arguments. - * if a tool call has an ID, it's matched by ID. - * if it has no ID (empty or null), it's merged with the last tool call in the list. + * the first chunk contains the function name and subsequent chunks contain only + * arguments. If a tool call has an ID, it's matched by ID. If it has no ID (empty or + * null), it's merged with the last tool call in the list.
* @param existingToolCalls the list of existing tool calls to merge into * @param newToolCalls the new tool calls to merge */ @@ -201,27 +201,31 @@ private void mergeToolCalls(List existingToolCalls, List new if (StringUtils.hasText(newCall.id())) { // ID present: match by ID or add as new ToolCall existingMatch = existingToolCalls.stream() - .filter(existing -> newCall.id().equals(existing.id())) - .findFirst() - .orElse(null); + .filter(existing -> newCall.id().equals(existing.id())) + .findFirst() + .orElse(null); if (existingMatch != null) { // Merge with existing tool call with same ID int index = existingToolCalls.indexOf(existingMatch); ToolCall merged = mergeToolCall(existingMatch, newCall); existingToolCalls.set(index, merged); - } else { + } + else { // New tool call with ID existingToolCalls.add(newCall); } - } else { + } + else { // No ID: merge with last tool call - ToolCall lastToolCall = existingToolCalls.isEmpty() ? null : existingToolCalls.get(existingToolCalls.size() - 1); + ToolCall lastToolCall = existingToolCalls.isEmpty() ? null + : existingToolCalls.get(existingToolCalls.size() - 1); ToolCall merged = mergeToolCall(lastToolCall, newCall); - + if (lastToolCall != null) { existingToolCalls.set(existingToolCalls.size() - 1, merged); - } else { + } + else { existingToolCalls.add(merged); } } @@ -231,14 +235,14 @@ private void mergeToolCalls(List existingToolCalls, List new /** * Merge two tool calls into one, combining their properties. * @param existing the existing tool call - * @param current the current tool call to merge + * @param current the current tool call to merge * @return the merged tool call */ private ToolCall mergeToolCall(ToolCall existing, ToolCall current) { if (existing == null) { return current; } - + // Use non-empty ID, prefer existing if both present (for consistency) String mergedId = StringUtils.hasText(existing.id()) ? 
existing.id() : current.id(); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java index 158a524e322..448160881cd 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/MessageAggregatorTests.java @@ -27,8 +27,9 @@ import static org.assertj.core.api.Assertions.assertThat; /** - * Tests for {@link MessageAggregator} with streaming tool calls that lack IDs in subsequent chunks. - * This pattern is common in OpenAI-compatible APIs. + * Tests for {@link MessageAggregator} with streaming tool calls that lack IDs in + * subsequent chunks. This pattern is common in OpenAI-compatible APIs. + * * @author Taewoong Kim */ class MessageAggregatorTests { @@ -36,32 +37,32 @@ class MessageAggregatorTests { private final MessageAggregator messageAggregator = new MessageAggregator(); /** - * Test merging of tool calls when subsequent chunks have no ID. - * First chunk contains the tool name and ID, subsequent chunks contain only arguments. + * Test merging of tool calls when subsequent chunks have no ID. First chunk contains + * the tool name and ID, subsequent chunks contain only arguments. 
*/ @Test void shouldMergeToolCallsWithoutIds() { // Chunk 1: ID and name present ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "getCurrentWeather", ""))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "getCurrentWeather", ""))) + .build()))); // Chunk 2-5: No ID, only arguments (common streaming pattern) ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"location\": \""))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"location\": \""))) + .build()))); ChatResponse chunk3 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "Se"))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "Se"))) + .build()))); ChatResponse chunk4 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "oul"))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "oul"))) + .build()))); ChatResponse chunk5 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "\"}"))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "\"}"))) + .build()))); Flux flux = Flux.just(chunk1, chunk2, chunk3, chunk4, chunk5); @@ -82,31 +83,32 @@ void shouldMergeToolCallsWithoutIds() { } /** - * Test multiple tool calls being streamed simultaneously. Each tool call has its own ID in the first chunk, - * and subsequent chunks have no ID but are merged with the last tool call. 
+ * Test multiple tool calls being streamed simultaneously. Each tool call has its own + * ID in the first chunk, and subsequent chunks have no ID but are merged with the + * last tool call. */ @Test void shouldMergeMultipleToolCallsWithMixedIds() { // Given: Multiple tool calls being streamed // Chunk 1: First tool call starts with ID ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", ""))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", ""))) + .build()))); // Chunk 2: Argument for first tool call (no ID) ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"city\":\"Tokyo\"}"))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"city\":\"Tokyo\"}"))) + .build()))); // Chunk 3: Second tool call starts with ID ChatResponse chunk3 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("tool-2", "function", "getTime", ""))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-2", "function", "getTime", ""))) + .build()))); // Chunk 4: Argument for second tool call (no ID) ChatResponse chunk4 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"timezone\":\"JST\"}"))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("", "function", "", "{\"timezone\":\"JST\"}"))) + .build()))); Flux flux = Flux.just(chunk1, chunk2, chunk3, chunk4); @@ -132,18 +134,19 @@ void shouldMergeMultipleToolCallsWithMixedIds() { } /** - * Test that tool calls with IDs are still matched correctly by ID, even when they arrive in different chunks. 
+ * Test that tool calls with IDs are still matched correctly by ID, even when they + * arrive in different chunks. */ @Test void shouldMergeToolCallsById() { // Given: Chunks with same ID arriving separately ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", "{\"ci"))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "getWeather", "{\"ci"))) + .build()))); ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder() - .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "", "ty\":\"Paris\"}"))) - .build()))); + .toolCalls(List.of(new AssistantMessage.ToolCall("tool-1", "function", "", "ty\":\"Paris\"}"))) + .build()))); Flux flux = Flux.just(chunk1, chunk2); @@ -163,4 +166,3 @@ void shouldMergeToolCallsById() { } } -