Skip to content

Commit

Permalink
Add vector database sink config validation and documentation (LangStr…
Browse files Browse the repository at this point in the history
  • Loading branch information
nicoloboschi committed Oct 16, 2023
1 parent 3507c37 commit c6686f4
Show file tree
Hide file tree
Showing 11 changed files with 948 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import ai.langstream.api.database.VectorDatabaseWriterProvider;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.util.ConfigurationUtils;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.ArrayList;
Expand All @@ -38,9 +36,6 @@
@Slf4j
public class JdbcWriter implements VectorDatabaseWriterProvider {

private static final ObjectMapper MAPPER =
new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

@Override
public boolean supports(Map<String, Object> dataSourceConfig) {
return "jdbc".equals(dataSourceConfig.get("service"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
@AllArgsConstructor
public class AgentConfigurationModel {

private String type;
private String name;
private String description;
private Map<String, ConfigPropertyModel> properties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,8 @@ public static void validateAgentModelFromClass(
Class modelClazz,
Map<String, Object> asMap,
boolean allowUnknownProperties) {
final EntityRef ref =
() ->
"agent configuration (agent: '%s', type: '%s')"
.formatted(
agentConfiguration.getName() == null
? agentConfiguration.getId()
: agentConfiguration.getName(),
agentConfiguration.getType());
validateModelFromClass(ref, modelClazz, asMap, allowUnknownProperties);
validateModelFromClass(
new AgentEntityRef(agentConfiguration), modelClazz, asMap, allowUnknownProperties);
}

@AllArgsConstructor
Expand Down Expand Up @@ -199,6 +192,22 @@ public String ref() {
}
}

/**
 * {@link EntityRef} that identifies an agent configuration in validation messages.
 *
 * <p>The agent is labelled with its name, falling back to its id when no name is set, followed
 * by its type.
 */
@AllArgsConstructor
public static class AgentEntityRef implements EntityRef {

    /** Template of the reference: first placeholder is the agent label, second is its type. */
    private static final String REF_TEMPLATE = "agent configuration (agent: '%s', type: '%s')";

    private final AgentConfiguration agentConfiguration;

    /**
     * Builds the human readable reference for the wrapped agent configuration.
     *
     * @return the formatted reference string
     */
    @Override
    public String ref() {
        return String.format(
                REF_TEMPLATE,
                // prefer the name; the id is only used when the name is missing
                agentConfiguration.getName() == null
                        ? agentConfiguration.getId()
                        : agentConfiguration.getName(),
                agentConfiguration.getType());
    }
}

@SneakyThrows
public static void validateAssetModelFromClass(
AssetDefinition asset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package ai.langstream.runtime.impl.k8s.agents;

import ai.langstream.api.doc.AgentConfig;
import ai.langstream.api.doc.AgentConfigurationModel;
import ai.langstream.api.doc.ConfigProperty;
import ai.langstream.api.model.AgentConfiguration;
import ai.langstream.api.model.Application;
Expand All @@ -28,18 +29,58 @@
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.impl.agents.AbstractComposableAgentProvider;
import ai.langstream.impl.agents.ai.steps.QueryConfiguration;
import ai.langstream.impl.uti.ClassConfigValidator;
import ai.langstream.runtime.impl.k8s.KubernetesClusterRuntime;
import ai.langstream.runtime.impl.k8s.agents.vectors.CassandraVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.JDBCVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.MilvusVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.OpenSearchVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.PineconeVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.SolrVectorDatabaseWriterConfig;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class QueryVectorDBAgentProvider extends AbstractComposableAgentProvider {

// Shared Jackson mapper; deepCopy uses it to clone AgentConfigurationModel instances through a
// serialize/deserialize round trip.
protected static final ObjectMapper MAPPER = new ObjectMapper();

/**
 * Base class of the per-service configuration models of the vector database sink agent.
 *
 * <p>Instances are registered in {@code SUPPORTED_VECTOR_DB_SINK_DATASOURCES}, keyed by the
 * datasource service name, and are used both to validate an agent configuration and to generate
 * the documentation of the supported types.
 */
@Getter
@Setter
public abstract static class VectorDatabaseWriterConfig {
// Property common to every sink configuration: the id of the resource to write to.
@ConfigProperty(
description =
"""
The defined datasource ID to use to store the vectors.
""",
required = true)
String datasource;

/**
 * Returns the class describing the agent configuration for this service; it is handed to
 * ClassConfigValidator for validation and for documentation generation.
 */
public abstract Class getAgentConfigModelClass();

/**
 * Returns the value passed as {@code allowUnknownProperties} when the agent configuration is
 * validated against the model class.
 */
public abstract boolean isAgentConfigModelAllowUnknownProperties();
}

// Agent type identifiers handled by this provider.
protected static final String QUERY_VECTOR_DB = "query-vector-db";
protected static final String VECTOR_DB_SINK = "vector-db-sink";
// Sink configuration models keyed by the "service" value of the datasource resource.
// NOTE: Map.of returns an unmodifiable map that rejects null keys (including on lookup) and
// whose iteration order is unspecified.
protected static final Map<String, VectorDatabaseWriterConfig>
SUPPORTED_VECTOR_DB_SINK_DATASOURCES =
Map.of(
"cassandra", CassandraVectorDatabaseWriterConfig.CASSANDRA,
"astra", CassandraVectorDatabaseWriterConfig.ASTRA,
"jdbc", JDBCVectorDatabaseWriterConfig.INSTANCE,
"pinecone", PineconeVectorDatabaseWriterConfig.INSTANCE,
"opensearch", OpenSearchVectorDatabaseWriterConfig.INSTANCE,
"solr", SolrVectorDatabaseWriterConfig.INSTANCE,
"milvus", MilvusVectorDatabaseWriterConfig.INSTANCE);

public QueryVectorDBAgentProvider() {
super(
Expand Down Expand Up @@ -76,30 +117,79 @@ protected Map<String, Object> computeAgentConfiguration(
// get the datasource configuration and inject it into the agent configuration
String resourceId = (String) originalConfiguration.remove("datasource");
if (resourceId == null) {
throw new IllegalStateException(
"datasource is required but this exception should have been raised before ?");
throw new IllegalArgumentException(
ClassConfigValidator.formatErrString(
new ClassConfigValidator.AgentEntityRef(agentConfiguration),
"datasource",
"is required"));
}
generateDataSourceConfiguration(
resourceId,
executionPlan.getApplication(),
originalConfiguration,
clusterRuntime,
pluginsRegistry);
pluginsRegistry,
agentConfiguration);

return originalConfiguration;
}

/**
 * Tells whether unknown properties are tolerated when validating the configuration of an agent.
 *
 * @param type the agent type, either {@code QUERY_VECTOR_DB} or {@code VECTOR_DB_SINK}
 * @param service the "service" value of the target datasource; only used for the sink type
 * @return {@code false} for the query agent, otherwise the value declared by the sink
 *     configuration registered for {@code service}
 * @throws IllegalArgumentException if the type is the sink type and {@code service} is missing
 *     or not supported
 * @throws IllegalStateException if {@code type} is not handled by this provider
 */
private boolean isAgentConfigModelAllowUnknownProperties(String type, String service) {
    return switch (type) {
        case QUERY_VECTOR_DB -> false;
        case VECTOR_DB_SINK -> requireVectorDatabaseWriterConfig(service)
                .isAgentConfigModelAllowUnknownProperties();
        default -> throw new IllegalStateException("Unexpected agent type: " + type);
    };
}

/**
 * Resolves the class used to validate the configuration of an agent.
 *
 * @param type the agent type, either {@code QUERY_VECTOR_DB} or {@code VECTOR_DB_SINK}
 * @param service the "service" value of the target datasource; only used for the sink type
 * @return {@code QueryVectorDBConfig.class} for the query agent, otherwise the model class
 *     declared by the sink configuration registered for {@code service}
 * @throws IllegalArgumentException if the type is the sink type and {@code service} is missing
 *     or not supported
 * @throws IllegalStateException if {@code type} is not handled by this provider
 */
private Class getAgentConfigModelClass(String type, String service) {
    return switch (type) {
        case QUERY_VECTOR_DB -> QueryVectorDBConfig.class;
        case VECTOR_DB_SINK -> requireVectorDatabaseWriterConfig(service)
                .getAgentConfigModelClass();
        default -> throw new IllegalStateException("Unexpected agent type: " + type);
    };
}

/**
 * Looks up the sink configuration registered for a datasource service.
 *
 * @param service the datasource service name, may be {@code null}
 * @return the registered configuration, never {@code null}
 * @throws IllegalArgumentException if {@code service} is {@code null} or not registered
 */
private static VectorDatabaseWriterConfig requireVectorDatabaseWriterConfig(String service) {
    // Map.of maps throw NullPointerException on get(null): guard so that a datasource without
    // a "service" entry is reported with the same explicit error as an unknown service.
    final VectorDatabaseWriterConfig config =
            service == null ? null : SUPPORTED_VECTOR_DB_SINK_DATASOURCES.get(service);
    if (config == null) {
        throw new IllegalArgumentException(
                "Unsupported vector database service: "
                        + service
                        + ". Supported services are: "
                        // sorted copy: the iteration order of Map.of is unspecified
                        + new java.util.TreeSet<>(
                                SUPPORTED_VECTOR_DB_SINK_DATASOURCES.keySet()));
    }
    return config;
}

private void generateDataSourceConfiguration(
String resourceId,
Application applicationInstance,
Map<String, Object> configuration,
ComputeClusterRuntime computeClusterRuntime,
PluginsRegistry pluginsRegistry) {
PluginsRegistry pluginsRegistry,
AgentConfiguration agentConfiguration) {

Resource resource = applicationInstance.getResources().get(resourceId);
log.info("Generating datasource configuration for {}", resourceId);
if (resource != null) {
Map<String, Object> resourceImplementation =
Map<String, Object> resourceConfiguration =
computeClusterRuntime.getResourceImplementation(resource, pluginsRegistry);
if (!resource.type().equals("datasource")
&& !resource.type().equals("vector-database")) {
Expand All @@ -108,57 +198,60 @@ private void generateDataSourceConfiguration(
+ resourceId
+ "' is not type=datasource or type=vector-database");
}
if (configuration.containsKey("datasource")) {
throw new IllegalArgumentException("Only one datasource is supported");
configuration.put("datasource", resourceConfiguration);
final String type = agentConfiguration.getType();
final String service = (String) resourceConfiguration.get("service");
final Class modelClass = getAgentConfigModelClass(type, service);
if (modelClass != null) {
ClassConfigValidator.validateAgentModelFromClass(
agentConfiguration,
modelClass,
agentConfiguration.getConfiguration(),
isAgentConfigModelAllowUnknownProperties(type, service));
}
configuration.put("datasource", resourceImplementation);
} else {
throw new IllegalArgumentException("Resource '" + resourceId + "' not found");
}
}

/**
 * Maps an agent type to the class describing its configuration.
 *
 * @param type the agent type
 * @return the configuration model class for {@code type}
 * @throws IllegalStateException carrying {@code type} as message when the type is unknown
 */
@Override
protected Class getAgentConfigModelClass(String type) {
    switch (type) {
        case QUERY_VECTOR_DB:
            return QueryVectorDBConfig.class;
        case VECTOR_DB_SINK:
            return VectorDBSinkConfig.class;
        default:
            throw new IllegalStateException(type);
    }
}

/**
 * Tells whether the configuration of the given agent type may contain undeclared properties.
 *
 * @param type the agent type
 * @return {@code false} for the query agent, {@code true} for the sink agent
 * @throws IllegalStateException carrying {@code type} as message when the type is unknown
 */
@Override
protected boolean isAgentConfigModelAllowUnknownProperties(String type) {
    switch (type) {
        case QUERY_VECTOR_DB:
            return false;
        case VECTOR_DB_SINK:
            return true;
        default:
            throw new IllegalStateException(type);
    }
}

@AgentConfig(
name = "Query a vector database",
description =
"""
Query a vector database using Vector Search capabilities.
""")
Query a vector database using Vector Search capabilities.
""")
@Data
public static class QueryVectorDBConfig extends QueryConfiguration {}

@AgentConfig(
name = "Vector database sink",
description =
"""
Store vectors in a vector database.
Configuration properties depends on the vector database implementation, specified by the "datasource" property.
""")
@Data
public static class VectorDBSinkConfig {
@ConfigProperty(
description =
"""
The defined datasource ID to use to store the vectors.
""",
required = true)
private String datasource;
/**
 * Generates the documentation models of the agent types handled by this provider.
 *
 * <p>The result holds one entry for {@code QUERY_VECTOR_DB}, followed by one entry per supported
 * sink service, keyed {@code VECTOR_DB_SINK + "_" + service}. Sink entries are emitted in
 * alphabetical order of the service name so the generated documentation is stable across runs.
 *
 * @return an insertion-ordered map from documentation key to configuration model
 */
@Override
public Map<String, AgentConfigurationModel> generateSupportedTypesDocumentation() {
    Map<String, AgentConfigurationModel> result = new LinkedHashMap<>();
    result.put(
            QUERY_VECTOR_DB,
            ClassConfigValidator.generateAgentModelFromClass(QueryVectorDBConfig.class));

    // The registry is built with Map.of, whose iteration order is unspecified: iterate over a
    // sorted copy so the order of the generated entries does not change between JVM runs.
    final Map<String, VectorDatabaseWriterConfig> sortedDatasources =
            new java.util.TreeMap<>(SUPPORTED_VECTOR_DB_SINK_DATASOURCES);
    for (Map.Entry<String, VectorDatabaseWriterConfig> datasource :
            sortedDatasources.entrySet()) {
        final String service = datasource.getKey();
        AgentConfigurationModel value =
                ClassConfigValidator.generateAgentModelFromClass(
                        datasource.getValue().getAgentConfigModelClass());
        // Work on a copy before mutating; presumably the generated model may be shared between
        // services using the same model class (e.g. cassandra/astra) - TODO confirm.
        value = deepCopy(value);
        value.getProperties()
                .get("datasource")
                .setDescription(
                        "Resource id. The target resource must be type: 'datasource' or 'vector-database' and "
                                + "service: '"
                                + service
                                + "'.");
        value.setType(VECTOR_DB_SINK);
        result.put(VECTOR_DB_SINK + "_" + service, value);
    }
    return result;
}

/**
 * Clones a configuration model by serializing it with {@code MAPPER} and reading it back.
 *
 * @param instance the model to copy
 * @return a new instance deserialized from the serialized form of {@code instance}
 */
@SneakyThrows
private static AgentConfigurationModel deepCopy(AgentConfigurationModel instance) {
    final byte[] serialized = MAPPER.writeValueAsBytes(instance);
    final AgentConfigurationModel copy =
            MAPPER.readValue(serialized, AgentConfigurationModel.class);
    return copy;
}
}
Loading

0 comments on commit c6686f4

Please sign in to comment.