Skip to content

Commit

Permalink
Improve missing AI service validation (LangStream#597)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi committed Oct 16, 2023
1 parent 86c569e commit 8e40d2d
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

@Slf4j
Expand Down Expand Up @@ -155,7 +156,8 @@ protected void generateSteps(
application,
configuration,
computeClusterRuntime,
pluginsRegistry);
pluginsRegistry,
agentConfiguration);

STEP_TYPES
.get(agentConfiguration.getType())
Expand Down Expand Up @@ -185,7 +187,8 @@ private void generateAIServiceConfiguration(
Application applicationInstance,
Map<String, Object> configuration,
ComputeClusterRuntime clusterRuntime,
PluginsRegistry pluginsRegistry) {
PluginsRegistry pluginsRegistry,
AgentConfiguration agentConfiguration) {
if (resourceId != null) {
Resource resource = applicationInstance.getResources().get(resourceId);
log.info("Generating ai service configuration for {}", resourceId);
Expand All @@ -208,6 +211,7 @@ private void generateAIServiceConfiguration(
throw new IllegalArgumentException("Resource " + resourceId + " not found");
}
} else {
boolean found = false;
for (Resource resource : applicationInstance.getResources().values()) {
final String configKey =
switch (resource.type()) {
Expand All @@ -220,8 +224,18 @@ private void generateAIServiceConfiguration(
Map<String, Object> configurationCopy =
clusterRuntime.getResourceImplementation(resource, pluginsRegistry);
configuration.put(configKey, configurationCopy);
found = true;
}
}
if (!found) {
final String errString =
ClassConfigValidator.formatErrString(
new ClassConfigValidator.AgentEntityRef(agentConfiguration),
"No ai service resource found in application configuration. One of "
+ AI_SERVICES.stream().collect(Collectors.joining(", "))
+ " must be defined.");
throw new IllegalArgumentException(errString);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ public static String formatErrString(EntityRef entityRef, String property, Strin
return "Found error on %s. Property '%s' %s".formatted(entityRef.ref(), property, message);
}

public static String formatErrString(EntityRef entityRef, String message) {
return "Found error on %s. %s".formatted(entityRef.ref(), message);
}

private static void validateProperties(
EntityRef entityRef,
String parentProp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,35 @@
public class AgentValidationTestUtil {

public static void validate(String pipeline, String expectErrMessage) throws Exception {
validate(pipeline, null, expectErrMessage);
}

public static void validate(String pipeline, String configuration, String expectErrMessage)
throws Exception {
if (expectErrMessage != null && expectErrMessage.isBlank()) {
throw new IllegalArgumentException("expectErrMessage cannot be blank");
}
if (configuration == null) {
configuration =
"""
configuration:
resources:
- type: "datasource"
name: "cassandra"
configuration:
service: "cassandra"
contact-points: "xx"
loadBalancing-localDc: "xx"
port: 999
""";
}
Application applicationInstance =
ModelBuilder.buildApplicationInstance(
Map.of(
"module.yaml",
pipeline,
"configuration.yaml",
"""
configuration:
resources:
- type: "datasource"
name: "cassandra"
configuration:
service: "cassandra"
contact-points: "xx"
loadBalancing-localDc: "xx"
port: 999
"""),
configuration),
"""
instance:
streamingCluster:
Expand All @@ -70,7 +79,7 @@ public static void validate(String pipeline, String expectErrMessage) throws Exc
ExecutionPlan implementation =
deployer.createImplementation("app", applicationInstance);
if (expectErrMessage != null) {
fail("Expected error message: " + expectErrMessage);
fail("Expected error message instead no errors thrown: " + expectErrMessage);
}
} catch (IllegalArgumentException e) {
if (expectErrMessage != null && e.getMessage().contains(expectErrMessage)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,46 @@ public void testValidationDropFields() {
null);
}

@Test
@SneakyThrows
public void testValidationAiChatCompletions() {
validate(
"""
topics: []
pipeline:
- name: "chat"
type: "ai-chat-completions"
configuration:
model: my-model
messages:
- role: system
content: "Hello"
""",
"Found error on agent configuration (agent: 'chat', type: 'ai-chat-completions'). No ai service resource found in application configuration. One of vertex-configuration, hugging-face-configuration, open-ai-configuration must be defined.");

AgentValidationTestUtil.validate(
"""
topics: []
pipeline:
- name: "chat"
type: "ai-chat-completions"
configuration:
model: my-model
messages:
- role: system
content: "Hello"
""",
"""
configuration:
resources:
- type: "open-ai-configuration"
name: "OpenAI Azure configuration"
configuration:
access-key: "yy"
""",
null);
}

private void validate(String pipeline, String expectErrMessage) throws Exception {
AgentValidationTestUtil.validate(pipeline, expectErrMessage);
}
Expand Down

0 comments on commit 8e40d2d

Please sign in to comment.