|
28 | 28 | import org.junit.jupiter.api.BeforeEach; |
29 | 29 | import org.junit.jupiter.api.Test; |
30 | 30 | import org.junit.jupiter.api.extension.RegisterExtension; |
| 31 | +import org.junit.jupiter.params.ParameterizedTest; |
| 32 | +import org.junit.jupiter.params.provider.NullAndEmptySource; |
31 | 33 | import org.mockito.ArgumentCaptor; |
32 | 34 | import org.mockito.Mockito; |
33 | 35 | import org.springframework.ai.chat.client.ChatClient; |
34 | 36 | import org.springframework.ai.chat.messages.AssistantMessage; |
35 | 37 | import org.springframework.ai.chat.messages.SystemMessage; |
36 | 38 | import org.springframework.ai.chat.messages.UserMessage; |
| 39 | +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; |
37 | 40 | import org.springframework.ai.chat.model.ChatModel; |
38 | 41 | import org.springframework.ai.chat.model.ChatResponse; |
39 | 42 | import org.springframework.ai.chat.model.Generation; |
@@ -803,15 +806,192 @@ void setHistory_withEmptyAttachmentMap_behavesLikeTextOnly() { |
803 | 806 | Assertions.assertTrue(userMsg.getMedia().isEmpty()); |
804 | 807 | } |
805 | 808 |
|
| 809 | + // --- Streaming finish_reason / abnormal termination tests --- |
| 810 | + |
| 811 | + @ParameterizedTest |
| 812 | + @NullAndEmptySource |
| 813 | + void stream_streamingWithMissingFinishReason_throwsIllegalStateException( |
| 814 | + String reason) { |
| 815 | + // OpenAI-compatible backends emit "" for an unset finish_reason; |
| 816 | + // both "" and null must be treated as missing. |
| 817 | + var request = createSimpleRequest("Hello"); |
| 818 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 819 | + .thenReturn(Flux.just(mockChatResponse("", reason))); |
| 820 | + |
| 821 | + Assertions.assertThrows(IllegalStateException.class, |
| 822 | + () -> provider.stream(request).collectList().block()); |
| 823 | + } |
| 824 | + |
| 825 | + @Test |
| 826 | + void stream_streamingCompletesEmptyWithNoChunks_throwsIllegalStateException() { |
| 827 | + // Zero-chunk stream: doOnNext never fires; the concatWith tail raises. |
| 828 | + var request = createSimpleRequest("Hello"); |
| 829 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 830 | + .thenReturn(Flux.empty()); |
| 831 | + |
| 832 | + Assertions.assertThrows(IllegalStateException.class, |
| 833 | + () -> provider.stream(request).collectList().block()); |
| 834 | + } |
| 835 | + |
| 836 | + @Test |
| 837 | + void stream_streamingWithValidFinishReasonButEmptyContent_completesWithoutError() { |
| 838 | + // Tool-only turns and content-filter stops produce empty text but |
| 839 | + // always carry a finish_reason; not errors. |
| 840 | + var request = createSimpleRequest("Hello"); |
| 841 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 842 | + .thenReturn(Flux.just(mockChatResponse("", "STOP"))); |
| 843 | + |
| 844 | + var results = provider.stream(request).collectList().block(); |
| 845 | + |
| 846 | + Assertions.assertNotNull(results); |
| 847 | + Assertions.assertTrue(results.isEmpty()); |
| 848 | + } |
| 849 | + |
| 850 | + @Test |
| 851 | + void stream_streamingWithLengthFinishReason_emitsPartialContent() { |
| 852 | + var request = createSimpleRequest("Hello"); |
| 853 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 854 | + .thenReturn(Flux.just(mockChatResponse("partial", "LENGTH"))); |
| 855 | + |
| 856 | + var results = provider.stream(request).collectList().block(); |
| 857 | + |
| 858 | + Assertions.assertEquals(List.of("partial"), results); |
| 859 | + } |
| 860 | + |
| 861 | + @Test |
| 862 | + void stream_streamingWithFinishReasonOnlyOnLastChunk_completesNormally() { |
| 863 | + // Real OpenAI streams set finish_reason only on the terminal chunk. |
| 864 | + var request = createSimpleRequest("Hello"); |
| 865 | + var chunk1 = mockChatResponse("Hel", null); |
| 866 | + var chunk2 = mockChatResponse("lo", null); |
| 867 | + var terminal = mockChatResponse(" World", "STOP"); |
| 868 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 869 | + .thenReturn(Flux.just(chunk1, chunk2, terminal)); |
| 870 | + |
| 871 | + var results = provider.stream(request).collectList().block(); |
| 872 | + |
| 873 | + Assertions.assertEquals(List.of("Hel", "lo", " World"), results); |
| 874 | + } |
| 875 | + |
| 876 | + @Test |
| 877 | + void stream_streamingWithNullGeneration_throwsIllegalStateException() { |
| 878 | + // ChatResponse(emptyList()) yields getResult() == null and no |
| 879 | + // finish_reason: indistinguishable from an abort. |
| 880 | + var request = createSimpleRequest("Hello"); |
| 881 | + var responseWithNoResult = new ChatResponse(Collections.emptyList()); |
| 882 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 883 | + .thenReturn(Flux.just(responseWithNoResult)); |
| 884 | + |
| 885 | + Assertions.assertThrows(IllegalStateException.class, |
| 886 | + () -> provider.stream(request).collectList().block()); |
| 887 | + } |
| 888 | + |
| 889 | + @Test |
| 890 | + void stream_streamingWithNullGenerationButFollowedByFinish_completesNormally() { |
| 891 | + // A null-result chunk is tolerated as long as another chunk signs |
| 892 | + // the stream off with a finish_reason. |
| 893 | + var request = createSimpleRequest("Hello"); |
| 894 | + var empty = new ChatResponse(Collections.emptyList()); |
| 895 | + var terminal = mockChatResponse("ok", "STOP"); |
| 896 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 897 | + .thenReturn(Flux.just(empty, terminal)); |
| 898 | + |
| 899 | + var results = provider.stream(request).collectList().block(); |
| 900 | + |
| 901 | + Assertions.assertEquals(List.of("ok"), results); |
| 902 | + } |
| 903 | + |
| 904 | + @Test |
| 905 | + void stream_streamingWithNullTextInMessage_filtersOut() { |
| 906 | + // AssistantMessage.getText() is @Nullable; null text is filtered |
| 907 | + // rather than propagated as the empty string. |
| 908 | + var request = createSimpleRequest("Hello"); |
| 909 | + var nullTextMessage = new AssistantMessage((String) null); |
| 910 | + var response = new ChatResponse( |
| 911 | + List.of(new Generation(nullTextMessage, ChatGenerationMetadata |
| 912 | + .builder().finishReason("STOP").build()))); |
| 913 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 914 | + .thenReturn(Flux.just(response)); |
| 915 | + |
| 916 | + var results = provider.stream(request).collectList().block(); |
| 917 | + |
| 918 | + Assertions.assertNotNull(results); |
| 919 | + Assertions.assertTrue(results.isEmpty()); |
| 920 | + } |
| 921 | + |
| 922 | + @Test |
| 923 | + void stream_streamingWithMultipleChunksAndMixedEmptyContent_emitsOnlyNonEmpty() { |
| 924 | + var request = createSimpleRequest("Hello"); |
| 925 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 926 | + .thenReturn(Flux.just(mockChatResponse("", null), |
| 927 | + mockChatResponse("Hello", null), |
| 928 | + mockChatResponse("", null), |
| 929 | + mockChatResponse(" World", "STOP"))); |
| 930 | + |
| 931 | + var results = provider.stream(request).collectList().block(); |
| 932 | + |
| 933 | + Assertions.assertEquals(List.of("Hello", " World"), results); |
| 934 | + } |
| 935 | + |
| 936 | + @Test |
| 937 | + void stream_streamingUpstreamErrorsDuringStream_propagatesOriginalError() { |
| 938 | + var request = createSimpleRequest("Hello"); |
| 939 | + var originalError = new RuntimeException("network broken"); |
| 940 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 941 | + .thenReturn(Flux.error(originalError)); |
| 942 | + |
| 943 | + var thrown = Assertions.assertThrows(RuntimeException.class, |
| 944 | + () -> provider.stream(request).collectList().block()); |
| 945 | + Assertions.assertEquals(originalError, thrown); |
| 946 | + } |
| 947 | + |
| 948 | + @Test |
| 949 | + void stream_streamingUpstreamErrorsAfterFinishReason_propagatesOriginalError() { |
| 950 | + // finish_reason was already seen, yet an upstream error must still |
| 951 | + // win over our abort detector. |
| 952 | + var request = createSimpleRequest("Hello"); |
| 953 | + var chunk = mockChatResponse("data", "STOP"); |
| 954 | + var originalError = new RuntimeException("broken after chunk"); |
| 955 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 956 | + .thenReturn( |
| 957 | + Flux.just(chunk).concatWith(Flux.error(originalError))); |
| 958 | + |
| 959 | + var thrown = Assertions.assertThrows(RuntimeException.class, |
| 960 | + () -> provider.stream(request).collectList().block()); |
| 961 | + Assertions.assertEquals(originalError, thrown); |
| 962 | + } |
| 963 | + |
| 964 | + @Test |
| 965 | + void stream_streamingChatModelThrowsSynchronously_propagatesError() { |
| 966 | + var request = createSimpleRequest("Hello"); |
| 967 | + var originalError = new RuntimeException("stream API down"); |
| 968 | + Mockito.when(mockChatModel.stream(Mockito.any(Prompt.class))) |
| 969 | + .thenThrow(originalError); |
| 970 | + |
| 971 | + var thrown = Assertions.assertThrows(RuntimeException.class, |
| 972 | + () -> provider.stream(request).collectList().block()); |
| 973 | + Assertions.assertEquals(originalError, thrown); |
| 974 | + } |
| 975 | + |
806 | 976 | private void mockSimpleChat(String responseText) { |
807 | 977 | var response = mockSimpleChatResponse(responseText); |
808 | 978 | Mockito.when(mockChatModel.call(Mockito.any(Prompt.class))) |
809 | 979 | .thenReturn(response); |
810 | 980 | } |
811 | 981 |
|
812 | 982 | private ChatResponse mockSimpleChatResponse(String text) { |
| 983 | + // Single-chunk responses are always terminal; tag them with STOP so |
| 984 | + // the finish_reason gate is satisfied. |
| 985 | + return mockChatResponse(text, "STOP"); |
| 986 | + } |
| 987 | + |
| 988 | + private static ChatResponse mockChatResponse(String text, |
| 989 | + String finishReason) { |
813 | 990 | var assistantMessage = new AssistantMessage(text); |
814 | | - var generation = new Generation(assistantMessage); |
| 991 | + var metadata = finishReason == null ? ChatGenerationMetadata.NULL |
| 992 | + : ChatGenerationMetadata.builder().finishReason(finishReason) |
| 993 | + .build(); |
| 994 | + var generation = new Generation(assistantMessage, metadata); |
815 | 995 | return new ChatResponse(List.of(generation)); |
816 | 996 | } |
817 | 997 |
|
|
0 commit comments