Expose token usage on ZhiPuAI streaming chat completion chunks.

- ChatCompletionChunk gains a trailing "usage" component (JSON property "usage").
- The stream chunk merger keeps the current chunk's usage when it is non-null,
  otherwise the previous chunk's usage.
- Integration tests switch from "glm-3-turbo" to "glm-4-flash" and assert that
  usage is present (on the response body, and on the last streamed chunk).

diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
index c5a99f99eb3..f07d93a159c 100644
--- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
+++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java
@@ -284,7 +284,7 @@ public Flux chatCompletionStream(ChatCompletionRequest chat
 			})
 			.concatMapIterable(window -> {
 				Mono monoChunk = window
-					.reduce(new ChatCompletionChunk(null, null, null, null, null, null), this.chunkMerger::merge);
+					.reduce(new ChatCompletionChunk(null, null, null, null, null, null, null), this.chunkMerger::merge);
 				return List.of(monoChunk);
 			})
 			.flatMap(mono -> mono);
@@ -1110,7 +1110,8 @@ public record ChatCompletionChunk(// @formatter:off
 			@JsonProperty("created") Long created,
 			@JsonProperty("model") String model,
 			@JsonProperty("system_fingerprint") String systemFingerprint,
-			@JsonProperty("object") String object) { // @formatter:on
+			@JsonProperty("object") String object,
+			@JsonProperty("usage") Usage usage) { // @formatter:on
 
 		/**
 		 * Chat completion choice.
diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java
index e3afcf6407e..71714241e92 100644
--- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java
+++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java
@@ -58,13 +58,14 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu
 		String systemFingerprint = (current.systemFingerprint() != null ? current.systemFingerprint()
 				: previous.systemFingerprint());
 		String object = (current.object() != null ? current.object() : previous.object());
+		ZhiPuAiApi.Usage usage = (current.usage() != null ? current.usage() : previous.usage());
 
 		ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0));
 		ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0));
 
 		ChunkChoice choice = merge(previousChoice0, currentChoice0);
 		List chunkChoices = choice == null ? List.of() : List.of(choice);
-		return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object);
+		return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object, usage);
 	}
 
 	private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java
index 44958b9d157..9ee29d46cae 100644
--- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java
+++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiIT.java
@@ -57,21 +57,24 @@ void chatCompletionEntity() {
 	void chatCompletionEntityWithMoreParams() {
 		ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
 		ResponseEntity response = this.zhiPuAiApi
-			.chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 1024, null,
+			.chatCompletionEntity(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-flash", 1024, null,
 					false, 0.95, 0.7, null, null, null, "test_request_id", false, null, null));
 
 		assertThat(response).isNotNull();
 		assertThat(response.getBody()).isNotNull();
+		assertThat(response.getBody().usage()).isNotNull();
 	}
 
 	@Test
 	void chatCompletionStream() {
 		ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
 		Flux response = this.zhiPuAiApi
-			.chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-3-turbo", 0.7, true));
+			.chatCompletionStream(new ChatCompletionRequest(List.of(chatCompletionMessage), "glm-4-flash", 0.7, true));
 
 		assertThat(response).isNotNull();
-		assertThat(response.collectList().block()).isNotNull();
+		List chunks = response.collectList().block();
+		assertThat(chunks).isNotEmpty(); // also rejects null; guards the last-element lookup below
+		assertThat(chunks.get(chunks.size() - 1).usage()).isNotNull();
 	}
 
 	@Test
diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java
index b1b37bd37ba..327a7f45329 100644
--- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java
+++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java
@@ -133,7 +133,7 @@ public void zhiPuAiChatStreamTransientError() {
 		var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
 				new ChatCompletionMessage("Response", Role.ASSISTANT), null);
 		ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666L, "model", null,
-				null);
+				null, null);
 
 		given(this.zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
 			.willThrow(new TransientAiException("Transient Error 1"))