Skip to content

Commit ef3d43c

Browse files
sobychackomarkpollack
authored andcommitted
Use OpenAIAsyncClient for streaming in AzureOpenAiChatModel
- Switch to OpenAIAsyncClient for streaming operations - Modify AzureOpenAiChatModel constructor to accept OpenAIClientBuilder - Update getChatCompletionsStream to use non-blocking async client - Refactor related classes and tests to support OpenAIClientBuilder - Revise AzureOpenAiAutoConfiguration to provide OpenAIClientBuilder - Add AzureOpenAiChatClientTest to verify streaming functionality - Adjust existing tests for compatibility with OpenAIClientBuilder Resolves #981 This change improves support for asynchronous streaming operations in the AzureOpenAiChatModel, addressing potential issues in reactive environments.
1 parent e1d9bfc commit ef3d43c

File tree

9 files changed

+179
-88
lines changed

9 files changed

+179
-88
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 34 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,11 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19-
import java.util.ArrayList;
20-
import java.util.Base64;
21-
import java.util.Collections;
22-
import java.util.HashSet;
23-
import java.util.List;
24-
import java.util.Map;
25-
import java.util.Optional;
26-
import java.util.Set;
27-
import java.util.concurrent.atomic.AtomicBoolean;
28-
19+
import com.azure.ai.openai.OpenAIAsyncClient;
20+
import com.azure.ai.openai.OpenAIClient;
21+
import com.azure.ai.openai.OpenAIClientBuilder;
22+
import com.azure.ai.openai.models.*;
23+
import com.azure.core.util.BinaryData;
2924
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
3025
import org.springframework.ai.chat.messages.AssistantMessage;
3126
import org.springframework.ai.chat.messages.Message;
@@ -49,37 +44,19 @@
4944
import org.springframework.ai.model.function.FunctionCallbackContext;
5045
import org.springframework.util.Assert;
5146
import org.springframework.util.CollectionUtils;
52-
53-
import com.azure.ai.openai.OpenAIClient;
54-
import com.azure.ai.openai.models.ChatChoice;
55-
import com.azure.ai.openai.models.ChatCompletions;
56-
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
57-
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
58-
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
59-
import com.azure.ai.openai.models.ChatCompletionsOptions;
60-
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
61-
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
62-
import com.azure.ai.openai.models.ChatCompletionsToolCall;
63-
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
64-
import com.azure.ai.openai.models.ChatMessageContentItem;
65-
import com.azure.ai.openai.models.ChatMessageImageContentItem;
66-
import com.azure.ai.openai.models.ChatMessageImageUrl;
67-
import com.azure.ai.openai.models.ChatMessageTextContentItem;
68-
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
69-
import com.azure.ai.openai.models.ChatRequestMessage;
70-
import com.azure.ai.openai.models.ChatRequestSystemMessage;
71-
import com.azure.ai.openai.models.ChatRequestToolMessage;
72-
import com.azure.ai.openai.models.ChatRequestUserMessage;
73-
import com.azure.ai.openai.models.CompletionsFinishReason;
74-
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
75-
import com.azure.ai.openai.models.FunctionCall;
76-
import com.azure.ai.openai.models.FunctionDefinition;
77-
import com.azure.core.util.BinaryData;
78-
import com.azure.core.util.IterableStream;
79-
8047
import reactor.core.publisher.Flux;
8148
import reactor.core.publisher.Mono;
8249

50+
import java.util.ArrayList;
51+
import java.util.Base64;
52+
import java.util.Collections;
53+
import java.util.HashSet;
54+
import java.util.List;
55+
import java.util.Map;
56+
import java.util.Optional;
57+
import java.util.Set;
58+
import java.util.concurrent.atomic.AtomicBoolean;
59+
8360
/**
8461
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
8562
* {@link OpenAIClient}.
@@ -96,6 +73,7 @@
9673
* @author Soby Chacko
9774
* @see ChatModel
9875
* @see com.azure.ai.openai.OpenAIClient
76+
* @since 1.0.0
9977
*/
10078
public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel {
10179

@@ -108,34 +86,40 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
10886
*/
10987
private final OpenAIClient openAIClient;
11088

89+
/**
90+
* The {@link OpenAIAsyncClient} used for streaming async operations.
91+
*/
92+
private final OpenAIAsyncClient openAIAsyncClient;
93+
11194
/**
11295
* The configuration information for a chat completions request.
11396
*/
114-
private AzureOpenAiChatOptions defaultOptions;
97+
private final AzureOpenAiChatOptions defaultOptions;
11598

116-
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
99+
public AzureOpenAiChatModel(OpenAIClientBuilder microsoftOpenAiClient) {
117100
this(microsoftOpenAiClient,
118101
AzureOpenAiChatOptions.builder()
119102
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
120103
.withTemperature(DEFAULT_TEMPERATURE)
121104
.build());
122105
}
123106

124-
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
125-
this(microsoftOpenAiClient, options, null);
107+
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options) {
108+
this(openAIClientBuilder, options, null);
126109
}
127110

128-
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
111+
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
129112
FunctionCallbackContext functionCallbackContext) {
130-
this(microsoftOpenAiClient, options, functionCallbackContext, List.of());
113+
this(openAIClientBuilder, options, functionCallbackContext, List.of());
131114
}
132115

133-
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
116+
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
134117
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
135118
super(functionCallbackContext, options, toolFunctionCallbacks);
136-
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
119+
Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null");
137120
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
138-
this.openAIClient = microsoftOpenAiClient;
121+
this.openAIClient = openAIClientBuilder.buildClient();
122+
this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient();
139123
this.defaultOptions = options;
140124
}
141125

@@ -170,11 +154,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
170154
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
171155
options.setStream(true);
172156

173-
IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
157+
Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
174158
.getChatCompletionsStream(options.getModel(), options);
175159

176160
final var isFunctionCall = new AtomicBoolean(false);
177-
final Flux<ChatCompletions> accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream)
161+
final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
178162
// Note: the first chat completions can be ignored when using Azure OpenAI
179163
// service which is a known service bug.
180164
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
@@ -254,15 +238,13 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptM
254238
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
255239
String id = chatCompletions.getId();
256240
Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
257-
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
241+
return ChatResponseMetadata.builder()
258242
.withId(id)
259243
.withUsage(usage)
260244
.withModel(chatCompletions.getModel())
261245
.withPromptMetadata(promptFilterMetadata)
262246
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
263247
.build();
264-
265-
return chatResponseMetadata;
266248
}
267249

268250
/**

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.springframework.ai.image.ImageResponse;
2323
import org.springframework.ai.image.ImageResponseMetadata;
2424
import org.springframework.ai.model.ModelOptionsUtils;
25-
import org.springframework.beans.factory.annotation.Autowired;
2625
import org.springframework.util.Assert;
2726

2827
import java.util.List;
@@ -36,15 +35,14 @@
3635
* @author Benoit Moussaud
3736
* @see ImageModel
3837
* @see com.azure.ai.openai.OpenAIClient
39-
* @since 1.0.0 M1
38+
* @since 1.0.0
4039
*/
4140
public class AzureOpenAiImageModel implements ImageModel {
4241

4342
private static final String DEFAULT_DEPLOYMENT_NAME = AzureOpenAiImageOptions.DEFAULT_IMAGE_MODEL;
4443

4544
private final Logger logger = LoggerFactory.getLogger(getClass());
4645

47-
@Autowired
4846
private final OpenAIClient openAIClient;
4947

5048
private final AzureOpenAiImageOptions defaultOptions;

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.ai.azure.openai;
1818

1919
import com.azure.ai.openai.OpenAIClient;
20+
import com.azure.ai.openai.OpenAIClientBuilder;
2021
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
2122
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
2223
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
@@ -44,7 +45,7 @@ public class AzureChatCompletionsOptionsTests {
4445
@Test
4546
public void createRequestWithChatOptions() {
4647

47-
OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
48+
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);
4849

4950
AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito
5051
.mock(AzureChatEnhancementConfiguration.class);
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.azure.openai;
18+
19+
import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS;
20+
import static org.assertj.core.api.Assertions.assertThat;
21+
22+
import java.util.Arrays;
23+
import java.util.stream.Collectors;
24+
25+
import org.junit.jupiter.api.Test;
26+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
28+
import org.springframework.ai.chat.client.ChatClient;
29+
import org.springframework.beans.factory.annotation.Autowired;
30+
import org.springframework.boot.SpringBootConfiguration;
31+
import org.springframework.boot.test.context.SpringBootTest;
32+
import org.springframework.context.annotation.Bean;
33+
34+
import com.azure.ai.openai.OpenAIClientBuilder;
35+
import com.azure.ai.openai.OpenAIServiceVersion;
36+
import com.azure.core.credential.AzureKeyCredential;
37+
import com.azure.core.http.policy.HttpLogOptions;
38+
39+
/**
40+
* @author Soby Chacko
41+
*/
42+
@SpringBootTest(classes = AzureOpenAiChatClientTest.TestConfiguration.class)
43+
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
44+
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
45+
public class AzureOpenAiChatClientTest {
46+
47+
@Autowired
48+
private ChatClient chatClient;
49+
50+
@Test
51+
void streamingAndImperativeResponsesContainIdenticalRelevantResults() {
52+
String prompt = "Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two. "
53+
+ "List them with a numerical index. Do not use any abbreviations in state or capitals.";
54+
55+
// Imperative call
56+
String rawDataFromImperativeCall = chatClient.prompt(prompt).call().content();
57+
String imperativeStatesData = extractStatesData(rawDataFromImperativeCall);
58+
String formattedImperativeResponse = formatResponse(imperativeStatesData);
59+
60+
// Streaming call
61+
String stitchedResponseFromStream = chatClient.prompt(prompt)
62+
.stream()
63+
.content()
64+
.collectList()
65+
.block()
66+
.stream()
67+
.collect(Collectors.joining());
68+
String streamingStatesData = extractStatesData(stitchedResponseFromStream);
69+
String formattedStreamingResponse = formatResponse(streamingStatesData);
70+
71+
// Assertions
72+
assertThat(formattedStreamingResponse).isEqualTo(formattedImperativeResponse);
73+
assertThat(formattedStreamingResponse).contains("1. Alabama - Montgomery");
74+
assertThat(formattedStreamingResponse).contains("50. Wyoming - Cheyenne");
75+
assertThat(formattedStreamingResponse.lines().count()).isEqualTo(50);
76+
}
77+
78+
private String extractStatesData(String rawData) {
79+
int firstStateIndex = rawData.indexOf("1. Alabama - Montgomery");
80+
String lastAlphabeticalState = "50. Wyoming - Cheyenne";
81+
int lastStateIndex = rawData.indexOf(lastAlphabeticalState) + lastAlphabeticalState.length();
82+
return rawData.substring(firstStateIndex, lastStateIndex);
83+
}
84+
85+
private String formatResponse(String response) {
86+
return String.join("\n", Arrays.stream(response.split("\n")).map(String::strip).toArray(String[]::new));
87+
}
88+
89+
@SpringBootConfiguration
90+
public static class TestConfiguration {
91+
92+
@Bean
93+
public OpenAIClientBuilder openAIClient() {
94+
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
95+
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
96+
.serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW)
97+
.httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS));
98+
}
99+
100+
@Bean
101+
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
102+
return new AzureOpenAiChatModel(openAIClientBuilder,
103+
AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build());
104+
105+
}
106+
107+
@Bean
108+
public ChatClient chatClient(AzureOpenAiChatModel azureOpenAiChatModel) {
109+
return ChatClient.builder(azureOpenAiChatModel).build();
110+
}
111+
112+
}
113+
114+
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616
package org.springframework.ai.azure.openai;
1717

18-
import com.azure.ai.openai.OpenAIClient;
1918
import com.azure.ai.openai.OpenAIClientBuilder;
2019
import com.azure.ai.openai.OpenAIServiceVersion;
2120
import com.azure.core.credential.AzureKeyCredential;
@@ -262,17 +261,16 @@ record ActorsFilmsRecord(String actor, List<String> movies) {
262261
public static class TestConfiguration {
263262

264263
@Bean
265-
public OpenAIClient openAIClient() {
264+
public OpenAIClientBuilder openAIClientBuilder() {
266265
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
267266
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
268267
.serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW)
269-
.httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS))
270-
.buildClient();
268+
.httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS));
271269
}
272270

273271
@Bean
274-
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) {
275-
return new AzureOpenAiChatModel(openAIClient,
272+
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
273+
return new AzureOpenAiChatModel(openAIClientBuilder,
276274
AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build());
277275

278276
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
*
4242
* @author John Blum
4343
* @see org.springframework.boot.SpringBootConfiguration
44-
* @see org.springframework.ai.test.config.MockAiTestConfiguration
44+
* @see org.springframework.ai.azure.openai.MockAiTestConfiguration
4545
* @since 0.7.0
4646
*/
4747
@SpringBootConfiguration
@@ -51,15 +51,15 @@
5151
public class MockAzureOpenAiTestConfiguration {
5252

5353
@Bean
54-
OpenAIClient microsoftAzureOpenAiClient(MockWebServer webServer) {
54+
OpenAIClientBuilder microsoftAzureOpenAiClient(MockWebServer webServer) {
5555

5656
HttpUrl baseUrl = webServer.url(MockAiTestConfiguration.SPRING_AI_API_PATH);
5757

58-
return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildClient();
58+
return new OpenAIClientBuilder().endpoint(baseUrl.toString());
5959
}
6060

6161
@Bean
62-
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) {
62+
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder microsoftAzureOpenAiClient) {
6363
return new AzureOpenAiChatModel(microsoftAzureOpenAiClient);
6464
}
6565

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,13 @@ void functionCallSequentialAndStreamTest() {
183183
public static class TestConfiguration {
184184

185185
@Bean
186-
public OpenAIClient openAIClient() {
186+
public OpenAIClientBuilder openAIClient() {
187187
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
188-
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
189-
.buildClient();
188+
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"));
190189
}
191190

192191
@Bean
193-
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) {
192+
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClient, String selectedModel) {
194193
return new AzureOpenAiChatModel(openAIClient,
195194
AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build());
196195
}

0 commit comments

Comments
 (0)