This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Fix VertexAI parameters (LangStream#491)
nicoloboschi authored Sep 27, 2023
1 parent c4a84a3 commit 00c4a9f
Showing 22 changed files with 128 additions and 90 deletions.
====================================================================
@@ -34,5 +34,6 @@ pipeline:
       completion-field: "value.answer"
       # we are also logging the prompt we sent to the LLM
       log-field: "value.prompt"
+      max-tokens: 20
       prompt:
         - "{{% value.question}}"
====================================================================
@@ -231,10 +231,11 @@ public CompletableFuture<String> getTextCompletions(
 
         // this is the default behavior, as it is async
         // it works even if the streamingChunksConsumer is null
+        final String model = (String) options.get("model");
         if (completionsOptions.isStream()) {
             CompletableFuture<?> finished = new CompletableFuture<>();
             Flux<com.azure.ai.openai.models.Completions> flux =
-                    client.getCompletionsStream((String) options.get("model"), completionsOptions);
+                    client.getCompletionsStream(model, completionsOptions);
 
             TextCompletionsConsumer textCompletionsConsumer =
                     new TextCompletionsConsumer(
@@ -253,8 +254,7 @@ public CompletableFuture<String> getTextCompletions(
             return finished.thenApply(___ -> textCompletionsConsumer.totalAnswer.toString());
         } else {
             com.azure.ai.openai.models.Completions completions =
-                    client.getCompletions((String) options.get("model"), completionsOptions)
-                            .block();
+                    client.getCompletions(model, completionsOptions).block();
             final String text = completions.getChoices().get(0).getText();
             return CompletableFuture.completedFuture(text);
         }
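A minimal sketch of the stream-vs-block decision above, using plain Reactor types. The Flux stands in for the Azure client's streaming response, and the accumulating consumer is reduced to a StringBuilder; this illustrates the pattern, not the project's actual TextCompletionsConsumer:

import java.util.concurrent.CompletableFuture;
import reactor.core.publisher.Flux;

public class StreamOrBlockSketch {

    static CompletableFuture<String> consume(Flux<String> chunks, boolean stream) {
        if (stream) {
            // async path: accumulate chunks, complete the future when the flux ends
            StringBuilder totalAnswer = new StringBuilder();
            CompletableFuture<String> finished = new CompletableFuture<>();
            chunks.subscribe(
                    totalAnswer::append,
                    finished::completeExceptionally,
                    () -> finished.complete(totalAnswer.toString()));
            return finished;
        } else {
            // blocking path: collapse the whole stream into a single value
            String text = chunks.reduce("", String::concat).block();
            return CompletableFuture.completedFuture(text);
        }
    }

    public static void main(String[] args) {
        System.out.println(consume(Flux.just("Hello", ", ", "world"), true).join());  // Hello, world
        System.out.println(consume(Flux.just("Hello", ", ", "world"), false).join()); // Hello, world
    }
}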
====================================================================
@@ -16,6 +16,7 @@
 package ai.langstream.ai.agents.services.impl;
 
 import ai.langstream.ai.agents.services.ServiceProviderProvider;
+import ai.langstream.api.util.ConfigurationUtils;
 import com.datastax.oss.streaming.ai.completions.ChatChoice;
 import com.datastax.oss.streaming.ai.completions.ChatCompletions;
 import com.datastax.oss.streaming.ai.completions.ChatMessage;
@@ -315,19 +316,33 @@ private void appendRequestParameters(
             Map<String, Object> additionalConfiguration, CompletionRequest request) {
         request.parameters = new HashMap<>();
 
-        if (additionalConfiguration.containsKey("temperature")) {
-            request.parameters.put(
-                    "temperature", additionalConfiguration.get("temperature"));
-        }
-        if (additionalConfiguration.containsKey("max-tokens")) {
-            request.parameters.put(
-                    "maxOutputTokens", additionalConfiguration.get("max-tokens"));
-        }
-        if (additionalConfiguration.containsKey("topP")) {
-            request.parameters.put("topP", additionalConfiguration.get("topP"));
-        }
-        if (additionalConfiguration.containsKey("topK")) {
-            request.parameters.put("topK", additionalConfiguration.get("topK"));
-        }
+        appendDoubleValue("temperature", "temperature", additionalConfiguration, request);
+        appendIntValue("max-tokens", "maxOutputTokens", additionalConfiguration, request);
+        appendDoubleValue("topP", "topP", additionalConfiguration, request);
+        appendIntValue("topK", "topK", additionalConfiguration, request);
     }
 
+    private void appendDoubleValue(
+            String key,
+            String toKey,
+            Map<String, Object> additionalConfiguration,
+            CompletionRequest request) {
+        final Double typedValue =
+                ConfigurationUtils.getDouble(key, null, additionalConfiguration);
+        if (typedValue != null) {
+            request.parameters.put(toKey, typedValue);
+        }
+    }
+
+    private void appendIntValue(
+            String key,
+            String toKey,
+            Map<String, Object> additionalConfiguration,
+            CompletionRequest request) {
+        final Integer typedValue =
+                ConfigurationUtils.getInteger(key, null, additionalConfiguration);
+        if (typedValue != null) {
+            request.parameters.put(toKey, typedValue);
+        }
+    }
 
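Context for the change above: additionalConfiguration is populated from parsed YAML and secret templating, so numeric options can arrive as strings (for example "0.2") even though the request body needs real JSON numbers. A minimal sketch of the coercion pattern, assuming ConfigurationUtils.getDouble accepts either a Number or a numeric String; the helper below is an illustration, not the actual ConfigurationUtils code:

import java.util.HashMap;
import java.util.Map;

public class ParamCoercionSketch {

    // Returns the value under `key` as a Double, accepting Number or numeric String;
    // falls back to `defaultValue` when the key is absent.
    static Double getDouble(String key, Double defaultValue, Map<String, Object> config) {
        Object raw = config.get(key);
        if (raw == null) {
            return defaultValue;
        }
        if (raw instanceof Number) {
            return ((Number) raw).doubleValue();
        }
        return Double.parseDouble(raw.toString());
    }

    public static void main(String[] args) {
        Map<String, Object> config = new HashMap<>();
        config.put("temperature", "0.2"); // a string, e.g. from templated YAML
        config.put("topP", 0.9);          // already a number

        System.out.println(getDouble("temperature", null, config)); // 0.2
        System.out.println(getDouble("topP", null, config));        // 0.9
        System.out.println(getDouble("topK", null, config));        // null
    }
}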
====================================================================
@@ -17,7 +17,6 @@
 
 import static com.datastax.oss.streaming.ai.util.TransformFunctionUtil.convertToMap;
 
-import com.azure.ai.openai.models.ChatCompletionsOptions;
 import com.datastax.oss.streaming.ai.completions.ChatChoice;
 import com.datastax.oss.streaming.ai.completions.ChatCompletions;
 import com.datastax.oss.streaming.ai.completions.ChatMessage;
@@ -121,21 +120,8 @@ public CompletableFuture<?> processAsync(TransformContext transformContext) {
                                                 .execute(jsonRecord)))
                         .collect(Collectors.toList());
 
-        ChatCompletionsOptions chatCompletionsOptions =
-                new ChatCompletionsOptions(List.of())
-                        .setMaxTokens(config.getMaxTokens())
-                        .setTemperature(config.getTemperature())
-                        .setTopP(config.getTopP())
-                        .setLogitBias(config.getLogitBias())
-                        .setStream(config.isStream())
-                        .setUser(config.getUser())
-                        .setStop(config.getStop())
-                        .setPresencePenalty(config.getPresencePenalty())
-                        .setFrequencyPenalty(config.getFrequencyPenalty());
-        Map<String, Object> options = convertToMap(chatCompletionsOptions);
-        options.put("model", config.getModel());
+        Map<String, Object> options = convertToMap(config);
         options.put("min-chunks-per-message", config.getMinChunksPerMessage());
-        options.remove("messages");
 
         CompletableFuture<ChatCompletions> chatCompletionsHandle =
                 completionsService.getChatCompletions(
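convertToMap(config) replaces the hand-built ChatCompletionsOptions: the step configuration itself is serialized to a map, which is why kebab-case keys such as "max-tokens" appear in the logged options in the test further down. A sketch of that conversion, assuming a Jackson-based implementation (the real TransformFunctionUtil.convertToMap may differ):

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;

public class ConvertToMapSketch {

    // Simplified stand-in for the agent's configuration POJO
    public static class ChatCompletionsConfig {
        @JsonProperty("model") public String model = "test-model";
        @JsonProperty("max-tokens") public Integer maxTokens;   // kebab-case key, as in the logged options
        @JsonProperty("temperature") public Double temperature;
        @JsonProperty("stream") public boolean stream = true;
    }

    static Map<String, Object> convertToMap(Object pojo) {
        return new ObjectMapper().convertValue(pojo, new TypeReference<Map<String, Object>>() {});
    }

    public static void main(String[] args) {
        Map<String, Object> options = convertToMap(new ChatCompletionsConfig());
        options.put("min-chunks-per-message", 20);
        System.out.println(options);
        // {model=test-model, max-tokens=null, temperature=null, stream=true, min-chunks-per-message=20}
    }
}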
====================================================================
@@ -17,7 +17,6 @@
 
 import static com.datastax.oss.streaming.ai.util.TransformFunctionUtil.convertToMap;
 
-import com.azure.ai.openai.models.CompletionsOptions;
 import com.datastax.oss.streaming.ai.completions.Chunk;
 import com.datastax.oss.streaming.ai.completions.CompletionsService;
 import com.datastax.oss.streaming.ai.model.JsonRecord;
@@ -87,21 +86,8 @@ public CompletableFuture<?> processAsync(TransformContext transformContext) {
                         .map(p -> messageTemplates.get(p).execute(jsonRecord))
                         .collect(Collectors.toList());
 
-        CompletionsOptions completionsOptions =
-                new CompletionsOptions(List.of())
-                        .setMaxTokens(config.getMaxTokens())
-                        .setTemperature(config.getTemperature())
-                        .setTopP(config.getTopP())
-                        .setLogitBias(config.getLogitBias())
-                        .setStream(config.isStream())
-                        .setUser(config.getUser())
-                        .setStop(config.getStop())
-                        .setPresencePenalty(config.getPresencePenalty())
-                        .setFrequencyPenalty(config.getFrequencyPenalty());
-        Map<String, Object> options = convertToMap(completionsOptions);
-        options.put("model", config.getModel());
+        final Map<String, Object> options = convertToMap(config);
         options.put("min-chunks-per-message", config.getMinChunksPerMessage());
-        options.remove("messages");
 
         CompletableFuture<String> chatCompletionsHandle =
                 completionsService.getTextCompletions(
====================================================================
@@ -30,7 +30,7 @@ EmbeddingsService getEmbeddingsService(Map<String, Object> additionalConfiguration
 
     void close();
 
-    public static class NoopServiceProvider implements ServiceProvider {
+    class NoopServiceProvider implements ServiceProvider {
         @Override
         public CompletionsService getCompletionsService(
                 Map<String, Object> additionalConfiguration) {
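The dropped modifiers are redundant by the Java language specification: every member of an interface is implicitly public, and a member class of an interface is implicitly static. A small illustration:

public interface Container {
    class Nested {} // identical to: public static class Nested {}
}

class ContainerDemo {
    public static void main(String[] args) {
        // no enclosing instance is needed, confirming Nested is implicitly static
        System.out.println(new Container.Nested());
    }
}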
====================================================================
@@ -284,8 +284,14 @@ void testChatCompletionsWithLogField() throws Exception {
                 Utils.getRecord(messageSchema.getValueSchema(), (byte[]) messageValue.getValue());
         assertEquals("result", valueAvroRecord.get("completion").toString());
         assertEquals(
-                valueAvroRecord.get("log").toString(),
-                "{\"options\":{\"max_tokens\":null,\"temperature\":null,\"top_p\":null,\"logit_bias\":null,\"user\":null,\"n\":null,\"stop\":null,\"presence_penalty\":null,\"frequency_penalty\":null,\"stream\":true,\"model\":\"test-model\",\"functions\":null,\"function_call\":null,\"dataSources\":null,\"min-chunks-per-message\":20},\"messages\":[{\"role\":\"user\",\"content\":\"value1 key2\"}],\"model\":\"test-model\"}");
+                "{\"options\":{\"type\":\"ai-chat-completions\",\"when\":null,\"model\":\"test-model\","
+                        + "\"messages\":[{\"role\":\"user\",\"content\":\"{{ value.valueField1 }} {{ key.keyField2 }}\"}],"
+                        + "\"stream-to-topic\":null,\"stream-response-completion-field\":null,\"min-chunks-per-message\":20,"
+                        + "\"completion-field\":\"value.completion\",\"stream\":true,\"log-field\":\"value.log\","
+                        + "\"max-tokens\":null,\"temperature\":null,\"top-p\":null,\"logit-bias\":null,\"user\":null,"
+                        + "\"stop\":null,\"presence-penalty\":null,\"frequency-penalty\":null},"
+                        + "\"messages\":[{\"role\":\"user\",\"content\":\"value1 key2\"}],\"model\":\"test-model\"}",
+                valueAvroRecord.get("log").toString());
     }
 
     @Test
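Besides updating the expected JSON for the new map-based options, the hunk above swaps the assertEquals arguments into the conventional order. Assuming JUnit 5, the signature is assertEquals(expected, actual), so the old order would mislabel the two values in failure messages:

import static org.junit.jupiter.api.Assertions.assertEquals;

class AssertOrderSketch {
    void check(String actualLog) {
        // expected value first, actual value second
        assertEquals("{\"options\":{}}", actualLog);
    }
}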
====================================================================
@@ -403,8 +403,6 @@ private void generateAIProvidersConfiguration(
             }
         } else {
             for (Resource resource : applicationInstance.getResources().values()) {
-                Map<String, Object> configurationCopy =
-                        clusterRuntime.getResourceImplementation(resource, pluginsRegistry);
                 final String configKey =
                         switch (resource.type()) {
                             case SERVICE_VERTEX -> "vertex";
@@ -413,6 +411,8 @@
                             default -> null;
                         };
                 if (configKey != null) {
+                    Map<String, Object> configurationCopy =
+                            clusterRuntime.getResourceImplementation(resource, pluginsRegistry);
                     configuration.put(configKey, configurationCopy);
                 }
             }
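The two hunks above move the getResourceImplementation call inside the configKey null check, so resources whose type maps to no AI provider key are never resolved. A simplified sketch of the resulting control flow, with stand-in types:

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class LazyResolveSketch {

    enum ResourceType { SERVICE_VERTEX, SERVICE_OPEN_AI, DATASOURCE }

    record Resource(ResourceType type, Map<String, Object> config) {}

    static Map<String, Object> resolve(Resource r) {
        System.out.println("resolving " + r.type()); // pretend this is expensive
        return new HashMap<>(r.config());
    }

    public static void main(String[] args) {
        Map<String, Map<String, Object>> configuration = new HashMap<>();
        for (Resource resource : List.of(
                new Resource(ResourceType.SERVICE_VERTEX, Map.of("url", "https://example")),
                new Resource(ResourceType.DATASOURCE, Map.of()))) {
            final String configKey =
                    switch (resource.type()) {
                        case SERVICE_VERTEX -> "vertex";
                        case SERVICE_OPEN_AI -> "open-ai";
                        default -> null;
                    };
            if (configKey != null) {
                // resolved only here, mirroring the commit's reordering
                Map<String, Object> configurationCopy = resolve(resource);
                configuration.put(configKey, configurationCopy);
            }
        }
        System.out.println(configuration.keySet()); // only [vertex]
    }
}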
====================================================================
@@ -15,6 +15,8 @@
  */
 package ai.langstream.tests;
 
+import static ai.langstream.tests.TextCompletionsIT.getAppEnvForAIServiceProvider;
+
 import ai.langstream.tests.util.BaseEndToEndTest;
 import ai.langstream.tests.util.ConsumeGatewayMessage;
 import java.util.List;
@@ -36,13 +38,10 @@ public class ChatCompletionsIT extends BaseEndToEndTest {
 
     @BeforeAll
     public static void checkCredentials() {
-        appEnv =
-                getAppEnvMapFromSystem(
-                        List.of(
-                                "OPEN_AI_ACCESS_KEY",
-                                "OPEN_AI_URL",
-                                "OPEN_AI_CHAT_COMPLETIONS_MODEL",
-                                "OPEN_AI_PROVIDER"));
+        appEnv = getAppEnvForAIServiceProvider();
+        appEnv.putAll(
+                getAppEnvMapFromSystem(
+                        List.of("CHAT_COMPLETIONS_MODEL", "CHAT_COMPLETIONS_SERVICE")));
     }
 
     @Test
====================================================================
@@ -36,24 +36,25 @@ public class TextCompletionsIT extends BaseEndToEndTest {
 
     @BeforeAll
     public static void checkCredentials() {
-        try {
-            appEnv =
-                    getAppEnvMapFromSystem(
-                            List.of("OPEN_AI_ACCESS_KEY", "OPEN_AI_URL", "OPEN_AI_PROVIDER"));
-        } catch (Throwable t) {
-            // no openai - try vertex
-            appEnv =
-                    getAppEnvMapFromSystem(
-                            List.of(
-                                    "VERTEX_AI_URL",
-                                    "VERTEX_AI_TOKEN",
-                                    "VERTEX_AI_REGION",
-                                    "VERTEX_AI_PROJECT"));
-        }
-
+        appEnv = getAppEnvForAIServiceProvider();
         appEnv.putAll(
                 getAppEnvMapFromSystem(
                         List.of("TEXT_COMPLETIONS_MODEL", "TEXT_COMPLETIONS_SERVICE")));
     }
 
+    public static Map<String, String> getAppEnvForAIServiceProvider() {
+        try {
+            return getAppEnvMapFromSystem(
+                    List.of("OPEN_AI_ACCESS_KEY", "OPEN_AI_URL", "OPEN_AI_PROVIDER"));
+        } catch (Throwable t) {
+            // no openai - try vertex
+            return getAppEnvMapFromSystem(
+                    List.of(
+                            "VERTEX_AI_URL",
+                            "VERTEX_AI_TOKEN",
+                            "VERTEX_AI_REGION",
+                            "VERTEX_AI_PROJECT"));
+        }
+    }
 
     @Test
@@ -80,6 +81,6 @@ public void test() throws Exception {
                         .formatted(sessionId)
                         .split(" "));
         log.info("Output: {}", message);
-        Assertions.assertTrue(message.getAnswerFromChatCompletionsValue().contains("Bounjour"));
+        Assertions.assertTrue(message.getAnswerFromChatCompletionsValue().contains("Bonjour"));
     }
 }
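getAppEnvForAIServiceProvider centralizes the provider fallback now shared by ChatCompletionsIT, TextCompletionsIT and WebCrawlerToVectorIT: try the OpenAI variables first, then fall back to Vertex AI. A self-contained sketch, assuming getAppEnvMapFromSystem throws when a required variable is unset (the helper below is a stand-in, not the real BaseEndToEndTest method):

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class EnvFallbackSketch {

    static Map<String, String> getAppEnvMapFromSystem(List<String> names) {
        Map<String, String> result = new HashMap<>();
        for (String name : names) {
            String value = System.getenv(name);
            if (value == null) {
                throw new IllegalStateException("missing env variable: " + name);
            }
            result.put(name, value);
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, String> appEnv;
        try {
            appEnv = getAppEnvMapFromSystem(
                    List.of("OPEN_AI_ACCESS_KEY", "OPEN_AI_URL", "OPEN_AI_PROVIDER"));
        } catch (Throwable t) {
            // no openai - try vertex
            appEnv = getAppEnvMapFromSystem(
                    List.of("VERTEX_AI_URL", "VERTEX_AI_TOKEN", "VERTEX_AI_REGION", "VERTEX_AI_PROJECT"));
        }
        System.out.println("Using provider env: " + appEnv.keySet());
    }
}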
====================================================================
@@ -15,6 +15,8 @@
  */
 package ai.langstream.tests;
 
+import static ai.langstream.tests.TextCompletionsIT.getAppEnvForAIServiceProvider;
+
 import ai.langstream.tests.util.BaseEndToEndTest;
 import ai.langstream.tests.util.ConsumeGatewayMessage;
 import java.util.List;
@@ -36,20 +38,20 @@ public class WebCrawlerToVectorIT extends BaseEndToEndTest {
 
     @BeforeAll
     public static void checkCredentials() {
-        appEnv =
+        appEnv = getAppEnvForAIServiceProvider();
+        appEnv.putAll(
+                getAppEnvMapFromSystem(
+                        List.of("CHAT_COMPLETIONS_MODEL", "CHAT_COMPLETIONS_SERVICE")));
+        appEnv.putAll(getAppEnvMapFromSystem(List.of("EMBEDDINGS_MODEL", "EMBEDDINGS_SERVICE")));
+
+        appEnv.putAll(
                 getAppEnvMapFromSystem(
                         List.of(
-                                "OPEN_AI_ACCESS_KEY",
-                                "OPEN_AI_URL",
-                                "OPEN_AI_EMBEDDINGS_MODEL",
-                                "OPEN_AI_CHAT_COMPLETIONS_MODEL",
-                                "OPEN_AI_PROVIDER",
                                 "ASTRA_TOKEN",
                                 "ASTRA_CLIENT_ID",
                                 "ASTRA_SECRET",
                                 "ASTRA_SECURE_BUNDLE",
                                 "ASTRA_ENVIRONMENT",
-                                "ASTRA_DATABASE"));
+                                "ASTRA_DATABASE")));
     }
 
     @Test
====================================================================
@@ -526,6 +526,9 @@ private static KubeCluster getKubeCluster() {
     public void setupSingleTest() {
         // cleanup previous runs
        cleanupAllEndToEndTestsNamespaces();
+        codeStorageProvider.cleanup();
+        streamingClusterProvider.cleanup();
+
         namespace = "ls-test-" + UUID.randomUUID().toString().substring(0, 8);
 
         client.resource(
@@ -1132,7 +1135,6 @@ private static void deployLocalApplicationAndAwaitReady(
                 .pollInterval(5, TimeUnit.SECONDS)
                 .untilAsserted(
                         () -> {
-                            log.info("waiting new executors to be ready");
                             final List<Pod> pods =
                                     client.pods()
                                             .inNamespace(tenantNamespace)
@@ -1144,6 +1146,10 @@
                                                     "langstream-runtime"))
                                             .list()
                                             .getItems();
+                            log.info(
+                                    "waiting new executors to be ready, found {}, expected {}",
+                                    pods.size(),
+                                    expectedNumExecutors);
                             if (pods.size() != expectedNumExecutors) {
                                 fail("too many pods: " + pods.size());
                             }
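The new log line sits inside a poll-and-assert loop. Assuming an Awaitility-style untilAsserted (which matches the pollInterval/untilAsserted chain above), the pattern looks like this, with the pod lookup replaced by a counter:

import static org.awaitility.Awaitility.await;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class AwaitReadySketch {
    public static void main(String[] args) {
        AtomicInteger readyPods = new AtomicInteger(0);
        // something else brings the executors up in the background
        new Thread(() -> {
            try { Thread.sleep(2000); } catch (InterruptedException ignored) {}
            readyPods.set(2);
        }).start();

        int expectedNumExecutors = 2;
        await().atMost(30, TimeUnit.SECONDS)
                .pollInterval(1, TimeUnit.SECONDS)
                .untilAsserted(() -> {
                    int found = readyPods.get();
                    System.out.printf("waiting new executors to be ready, found %d, expected %d%n",
                            found, expectedNumExecutors);
                    if (found != expectedNumExecutors) {
                        throw new AssertionError("unexpected pod count: " + found);
                    }
                });
        System.out.println("executors ready");
    }
}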
====================================================================
@@ -18,8 +18,17 @@
 configuration:
   resources:
     - type: "open-ai-configuration"
+      id: "open-ai"
       name: "OpenAI Azure configuration"
       configuration:
         url: "{{ secrets.open-ai.url }}"
         access-key: "{{ secrets.open-ai.access-key }}"
         provider: "{{ secrets.open-ai.provider }}"
+    - type: "vertex-configuration"
+      name: "Google Vertex AI configuration"
+      id: "vertex"
+      configuration:
+        url: "{{ secrets.vertex-ai.url }}"
+        token: "{{ secrets.vertex-ai.token }}"
+        region: "{{ secrets.vertex-ai.region }}"
+        project: "{{ secrets.vertex-ai.project }}"
====================================================================
@@ -31,12 +31,14 @@ pipeline:
     type: "ai-chat-completions"
     output: "ls-test-history-topic"
     configuration:
-      model: "{{{secrets.open-ai.chat-completions-model}}}"
+      ai-service: "{{{secrets.chat-completions.service}}}"
+      model: "{{{secrets.chat-completions.model}}}"
       completion-field: "value.answer"
       log-field: "value.prompt"
       stream-to-topic: "ls-test-output-topic"
       stream-response-completion-field: "value"
       min-chunks-per-message: 20
+      max-tokens: 20
       messages:
         - role: user
           content: "You are an helpful assistant. Below you can fine a question from the user. Please try to help them the best way you can.\n\n{{% value.question}}"
====================================================================
@@ -18,6 +18,7 @@
 configuration:
   resources:
     - type: "open-ai-configuration"
+      id: "open-ai"
       name: "OpenAI Azure configuration"
       configuration:
         url: "{{ secrets.open-ai.url }}"
====================================================================
@@ -38,5 +38,6 @@ pipeline:
       stream-to-topic: "ls-test-output-topic"
       stream-response-completion-field: "value"
       min-chunks-per-message: 20
+      max-tokens: 20
       prompt:
         - "{{% value.question}}"
====================================================================
@@ -50,7 +50,8 @@ pipeline:
     type: "ai-chat-completions"
 
     configuration:
-      model: "{{{secrets.open-ai.chat-completions-model}}}" # This needs to be set to the model deployment name, not the base name
+      ai-service: "{{{secrets.chat-completions.service}}}"
+      model: "{{{secrets.chat-completions.model}}}" # This needs to be set to the model deployment name, not the base name
       # on the ls-test-log-topic we add a field with the answer
       completion-field: "value.answer"
       # we are also logging the prompt we sent to the LLM