Skip to content

Commit

Permalink
[agents] New agent 'dispatch' for message routing (LangStream#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli committed Oct 5, 2023
1 parent 9dff827 commit f863f2a
Show file tree
Hide file tree
Showing 27 changed files with 976 additions and 221 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,33 @@
*/
package ai.langstream.ai.agents.commons;

import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.code.SimpleRecord;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.sql.Time;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Data;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -358,4 +371,172 @@ public static Object safeClone(Object object) {
}
throw new UnsupportedOperationException("Cannot copy a value of " + object.getClass());
}

public static TransformContext recordToTransformContext(
Record record, boolean attemptJsonConversion) {
TransformContext context = new TransformContext();
context.setKeyObject(record.key());
context.setKeySchemaType(
record.key() == null ? null : getSchemaType(record.key().getClass()));
// TODO: temporary hack. We should be able to get the schema from the record
if (record.key() instanceof GenericRecord) {
context.setKeyNativeSchema(((GenericRecord) record.key()).getSchema());
}
context.setValueObject(record.value());
context.setValueSchemaType(
record.value() == null ? null : getSchemaType(record.value().getClass()));
// TODO: temporary hack. We should be able to get the schema from the record
if (record.value() instanceof GenericRecord) {
context.setKeyNativeSchema(((GenericRecord) record.value()).getSchema());
}
context.setInputTopic(record.origin());
context.setEventTime(record.timestamp());
if (attemptJsonConversion) {
context.setKeyObject(attemptJsonConversion(context.getKeyObject()));
context.setValueObject(attemptJsonConversion(context.getValueObject()));
}
// the headers must be Strings, this is a tentative conversion
// in the future we need a better way to handle headers
context.setProperties(
record.headers().stream()
.filter(h -> h.key() != null && h.value() != null)
.collect(
Collectors.toMap(
Header::key,
(h -> {
if (h.value() == null) {
return null;
}
if (h.value() instanceof byte[]) {
return new String(
(byte[]) h.value(), StandardCharsets.UTF_8);
} else {
return h.value().toString();
}
}))));
return context;
}

public static Optional<Record> transformContextToRecord(TransformContext context) {
if (context.isDropCurrentRecord()) {
return Optional.empty();
}
List<Header> headers = new ArrayList<>();
context.getProperties()
.forEach(
(key, value) -> {
SimpleRecord.SimpleHeader header =
new SimpleRecord.SimpleHeader(key, value);
headers.add(header);
});
return Optional.of(new TransformRecord(context, headers));
}

private record TransformRecord(TransformContext context, Collection<Header> headers)
implements Record {
private TransformRecord(TransformContext context, Collection<Header> headers) {
this.context = context;
this.headers = new ArrayList<>(headers);
}

@Override
public Object key() {
return context.getKeyObject();
}

@Override
public Object value() {
return context.getValueObject();
}

@Override
public String origin() {
return context.getInputTopic();
}

@Override
public Long timestamp() {
return context.getEventTime();
}
}

private static TransformSchemaType getSchemaType(Class<?> javaType) {
if (String.class.isAssignableFrom(javaType)) {
return TransformSchemaType.STRING;
}
if (Byte.class.isAssignableFrom(javaType)) {
return TransformSchemaType.INT8;
}
if (Short.class.isAssignableFrom(javaType)) {
return TransformSchemaType.INT16;
}
if (Integer.class.isAssignableFrom(javaType)) {
return TransformSchemaType.INT32;
}
if (Long.class.isAssignableFrom(javaType)) {
return TransformSchemaType.INT64;
}
if (Double.class.isAssignableFrom(javaType)) {
return TransformSchemaType.DOUBLE;
}
if (Float.class.isAssignableFrom(javaType)) {
return TransformSchemaType.FLOAT;
}
if (Boolean.class.isAssignableFrom(javaType)) {
return TransformSchemaType.BOOLEAN;
}
if (byte[].class.isAssignableFrom(javaType)) {
return TransformSchemaType.BYTES;
}
// Must be before DATE
if (Time.class.isAssignableFrom(javaType)) {
return TransformSchemaType.TIME;
}
// Must be before DATE
if (Timestamp.class.isAssignableFrom(javaType)) {
return TransformSchemaType.TIMESTAMP;
}
if (Date.class.isAssignableFrom(javaType)) {
return TransformSchemaType.DATE;
}
if (Instant.class.isAssignableFrom(javaType)) {
return TransformSchemaType.INSTANT;
}
if (LocalDate.class.isAssignableFrom(javaType)) {
return TransformSchemaType.LOCAL_DATE;
}
if (LocalTime.class.isAssignableFrom(javaType)) {
return TransformSchemaType.LOCAL_TIME;
}
if (LocalDateTime.class.isAssignableFrom(javaType)) {
return TransformSchemaType.LOCAL_DATE_TIME;
}
if (GenericRecord.class.isAssignableFrom(javaType)) {
return TransformSchemaType.AVRO;
}
if (JsonNode.class.isAssignableFrom(javaType)) {
return TransformSchemaType.JSON;
}
if (Map.class.isAssignableFrom(javaType)) {
return TransformSchemaType.MAP;
}
throw new IllegalArgumentException("Unsupported data type: " + javaType);
}

public static Object attemptJsonConversion(Object value) {
try {
if (value instanceof String) {
return OBJECT_MAPPER.readValue(
(String) value, new TypeReference<Map<String, Object>>() {});
} else if (value instanceof byte[]) {
return OBJECT_MAPPER.readValue(
(byte[]) value, new TypeReference<Map<String, Object>>() {});
}
} catch (IOException e) {
if (log.isDebugEnabled()) {
log.debug("Cannot convert value to json", e);
}
}
return value;
}
}
91 changes: 91 additions & 0 deletions langstream-agents/langstream-agents-flow-control/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright DataStax, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<groupId>ai.langstream</groupId>
<artifactId>langstream-agents</artifactId>
<version>0.1.1-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>langstream-agents-flow-control</artifactId>
<packaging>jar</packaging>
<name>LangStream - Flow control agents</name>
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>langstream-api</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>langstream-agents-commons</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-nar-maven-plugin</artifactId>
<extensions>true</extensions>
<configuration>
<classifier>nar</classifier>
</configuration>
<executions>
<execution>
<id>default-nar</id>
<phase>package</phase>
<goals>
<goal>nar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
Loading

0 comments on commit f863f2a

Please sign in to comment.