Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -783,8 +783,7 @@ private Optional<Integer> gatherMemoryIdParamName(MethodInfo method) {
}

private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodInfo method,
List<TemplateParameterInfo> templateParams,
Class<?> returnType) {
List<TemplateParameterInfo> templateParams, Class<?> returnType) {
String outputFormatInstructions = outputFormatInstructions(returnType);

Optional<Integer> userNameParamName = method.annotations(Langchain4jDotNames.USER_NAME).stream().filter(
Expand All @@ -794,8 +793,7 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
if (userMessageInstance != null) {
AnnotationValue delimiterValue = userMessageInstance.value("delimiter");
String delimiter = delimiterValue != null ? delimiterValue.asString() : DEFAULT_DELIMITER;
String userMessageTemplate = String.join(delimiter, userMessageInstance.value().asStringArray())
+ outputFormatInstructions;
String userMessageTemplate = String.join(delimiter, userMessageInstance.value().asStringArray());

if (userMessageTemplate.contains("{{it}}")) {
if (method.parametersCount() != 1) {
Expand All @@ -810,25 +808,22 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
new AiServiceMethodCreateInfo.TemplateInfo(userMessageTemplate,
TemplateParameterInfo.toNameToArgsPositionMap(templateParams)),
userNameParamName);
userNameParamName, outputFormatInstructions);
} else {
Optional<AnnotationInstance> userMessageOnMethodParam = method.annotations(Langchain4jDotNames.USER_MESSAGE)
.stream()
.filter(IS_METHOD_PARAMETER_ANNOTATION).findFirst();
if (userMessageOnMethodParam.isPresent()) {
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(
userMessageOnMethodParam.get().target().asMethodParameter().position(),
outputFormatInstructions,
userNameParamName);
userNameParamName, outputFormatInstructions);
} else {
if (method.parametersCount() == 0) {
throw illegalConfigurationForMethod("Method should have at least one argument", method);
}
if (method.parametersCount() == 1) {
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(
0,
outputFormatInstructions,
userNameParamName);
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamName,
outputFormatInstructions);
}

throw illegalConfigurationForMethod(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.UserMessage;
import io.quarkus.test.QuarkusUnitTest;

Expand Down Expand Up @@ -54,6 +55,16 @@ void should_serialize_and_deserialize_user_message_without_name() {
assertThat(deserializedMessage).isEqualTo(message);
}

@Test
void should_serialize_and_deserialize_user_message_with_image_content() {
UserMessage message = UserMessage.from(ImageContent.from("http://image.url"));

String json = messageToJson(message);
ChatMessage deserializedMessage = messageFromJson(json);

assertThat(deserializedMessage).isEqualTo(message);
}

@Test
void should_serialize_and_deserialize_empty_list() {

Expand All @@ -76,7 +87,7 @@ void should_serialize_and_deserialize_list_with_one_message() {
List<ChatMessage> messages = singletonList(userMessage("hello"));

String json = messagesToJson(messages);
assertThat(json).isEqualTo("[{\"text\":\"hello\",\"type\":\"USER\"}]");
assertThat(json).isEqualTo("[{\"contents\":[{\"text\":\"hello\",\"type\":\"TEXT\"}],\"type\":\"USER\"}]");

List<ChatMessage> deserializedMessages = messagesFromJson(json);
assertThat(deserializedMessages).isEqualTo(messages);
Expand All @@ -99,8 +110,8 @@ void should_serialize_and_deserialize_list_with_all_types_of_messages() {
String json = ChatMessageSerializer.messagesToJson(messages);
assertThat(json).isEqualTo("[" +
"{\"text\":\"Hello from system\",\"type\":\"SYSTEM\"}," +
"{\"text\":\"Hello from user\",\"type\":\"USER\"}," +
"{\"name\":\"Klaus\",\"text\":\"Hello from Klaus\",\"type\":\"USER\"}," +
"{\"contents\":[{\"text\":\"Hello from user\",\"type\":\"TEXT\"}],\"type\":\"USER\"}," +
"{\"name\":\"Klaus\",\"contents\":[{\"text\":\"Hello from Klaus\",\"type\":\"TEXT\"}],\"type\":\"USER\"}," +
"{\"text\":\"Hello from AI\",\"type\":\"AI\"}," +
"{\"toolExecutionRequests\":[{\"name\":\"calculator\",\"arguments\":\"{}\"}],\"type\":\"AI\"}," +
"{\"text\":\"4\",\"id\":\"12345\",\"toolName\":\"calculator\",\"type\":\"TOOL_EXECUTION_RESULT\"}" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,28 @@ public Optional<SpanInfo> getSpanInfo() {
public static class UserMessageInfo {
private final Optional<TemplateInfo> template;
private final Optional<Integer> paramPosition;
private final Optional<String> instructions;
private final Optional<Integer> userNameParamPosition;
private final String outputFormatInstructions;

@RecordableConstructor
public UserMessageInfo(Optional<TemplateInfo> template, Optional<Integer> paramPosition, Optional<String> instructions,
Optional<Integer> userNameParamPosition) {
public UserMessageInfo(Optional<TemplateInfo> template, Optional<Integer> paramPosition,
Optional<Integer> userNameParamPosition, String outputFormatInstructions) {
this.template = template;
this.paramPosition = paramPosition;
this.instructions = instructions;
this.userNameParamPosition = userNameParamPosition;
this.outputFormatInstructions = outputFormatInstructions == null ? "" : outputFormatInstructions;
}

public static UserMessageInfo fromMethodParam(int paramPosition, String instructions,
Optional<Integer> userNameParamPosition) {
return new UserMessageInfo(Optional.empty(), Optional.of(paramPosition), Optional.of(instructions),
userNameParamPosition);
public static UserMessageInfo fromMethodParam(int paramPosition, Optional<Integer> userNameParamPosition,
String outputFormatInstructions) {
return new UserMessageInfo(Optional.empty(), Optional.of(paramPosition),
userNameParamPosition, outputFormatInstructions);
}

public static UserMessageInfo fromTemplate(TemplateInfo templateInfo, Optional<Integer> userNameParamPosition) {
return new UserMessageInfo(Optional.of(templateInfo), Optional.empty(), Optional.empty(), userNameParamPosition);
public static UserMessageInfo fromTemplate(TemplateInfo templateInfo, Optional<Integer> userNameParamPosition,
String outputFormatInstructions) {
return new UserMessageInfo(Optional.of(templateInfo), Optional.empty(), userNameParamPosition,
outputFormatInstructions);
}

public Optional<TemplateInfo> getTemplate() {
Expand All @@ -115,13 +117,13 @@ public Optional<Integer> getParamPosition() {
return paramPosition;
}

public Optional<String> getInstructions() {
return instructions;
}

public Optional<Integer> getUserNameParamPosition() {
return userNameParamPosition;
}

public String getOutputFormatInstructions() {
return outputFormatInstructions;
}
}

public static class TemplateInfo {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Exceptions.runtime;
import static dev.langchain4j.service.AiServices.removeToolMessages;
import static dev.langchain4j.service.AiServices.verifyModerationIfNeeded;
import static dev.langchain4j.service.ServiceOutputParser.parse;
import static java.util.stream.Collectors.joining;

import java.lang.reflect.Array;
import java.util.ArrayList;
Expand All @@ -29,7 +27,6 @@
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
Expand All @@ -38,6 +35,7 @@
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.TokenStream;
Expand Down Expand Up @@ -96,29 +94,19 @@ private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[]
audit.initialMessages(systemMessage, userMessage);
}

if (context.retriever != null) { // TODO extract method/class
List<TextSegment> relevant = context.retriever.findRelevant(userMessage.text());

if (relevant == null || relevant.isEmpty()) {
log.debug("No relevant information was found");
} else {
String relevantConcatenated = relevant.stream()
.map(TextSegment::text)
.collect(joining("\n\n"));

log.debugv("Retrieved relevant information:\n{0}\n", relevantConcatenated);

userMessage = userMessage(userMessage.text()
+ "\n\nHere is some information that might be useful for answering:\n\n"
+ relevantConcatenated);
Object memoryId = memoryId(createInfo, methodArgs).orElse("default");

if (audit != null) {
audit.addRelevantDocument(relevant, userMessage);
}
}
if (context.retrievalAugmentor != null) { // TODO extract method/class
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
: null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
userMessage = context.retrievalAugmentor.augment(userMessage, metadata);
}

Object memoryId = memoryId(createInfo, methodArgs).orElse("default");
// TODO give user ability to provide custom OutputParser
String outputFormatInstructions = createInfo.getUserMessageInfo().getOutputFormatInstructions();
userMessage = UserMessage.from(userMessage.text() + outputFormatInstructions);

if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
Expand Down Expand Up @@ -267,7 +255,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
// we do not need to apply the instructions as they have already been added to the template text at build time
Prompt prompt = PromptTemplate.from(templateInfo.getText()).apply(templateParams);

return userMessage(userName, prompt.text());
return createUserMessage(userName, prompt.text());
} else if (userMessageInfo.getParamPosition().isPresent()) {
Integer paramIndex = userMessageInfo.getParamPosition().get();
Object argValue = methodArgs[paramIndex];
Expand All @@ -277,13 +265,21 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
+ "' because parameter with index "
+ paramIndex + " is null");
}
return userMessage(userName, toString(argValue) + userMessageInfo.getInstructions().orElse(""));
return createUserMessage(userName, toString(argValue));
} else {
throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName()
+ "'. Please contact the maintainers");
}
}

private static UserMessage createUserMessage(String name, String text) {
if (name == null) {
return userMessage(text);
} else {
return userMessage(name, text);
}
}

private static Object transformTemplateParamValue(Object value) {
if (value.getClass().isArray()) {
// Qute does not transform these values but Langchain4j expects to be converted to a [item1, item2, item3] like systax
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.quarkiverse.langchain4j.runtime.jackson;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;

import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ContentType;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.TextContent;
import io.quarkus.jackson.JacksonMixin;

@JacksonMixin(Content.class)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, setterVisibility = JsonAutoDetect.Visibility.NONE)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type")
@JsonSubTypes({
@JsonSubTypes.Type(value = TextContent.class, name = "TEXT"),
@JsonSubTypes.Type(value = ImageContent.class, name = "IMAGE"),
})
public abstract class ContextMixin {

@JsonProperty
public abstract ContentType type();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.quarkiverse.langchain4j.runtime.jackson;

import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;

import dev.langchain4j.data.image.Image;
import io.quarkus.jackson.JacksonMixin;

@JacksonMixin(Image.Builder.class)
@JsonPOJOBuilder(withPrefix = "")
public abstract class ImageBuilderMixin {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkiverse.langchain4j.runtime.jackson;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.ImageContent;
import io.quarkus.jackson.JacksonMixin;

@JacksonMixin(ImageContent.class)
public abstract class ImageContentMixin {

@JsonCreator
public ImageContentMixin(@JsonProperty("image") Image image,
@JsonProperty("detailLevel") ImageContent.DetailLevel detailLevel) {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package io.quarkiverse.langchain4j.runtime.jackson;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;

import dev.langchain4j.data.image.Image;
import io.quarkus.jackson.JacksonMixin;

@JacksonMixin(Image.class)
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonDeserialize(builder = Image.Builder.class)
public abstract class ImageMixin {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.quarkiverse.langchain4j.runtime.jackson;

import java.io.IOException;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;

import dev.langchain4j.data.message.TextContent;

public class TextContentDeserializer extends StdDeserializer<TextContent> {

public TextContentDeserializer() {
super(TextContent.class);
}

@Override
public TextContent deserialize(JsonParser p, DeserializationContext deserializationContext)
throws IOException {
JsonNode node = p.getCodec().readTree(p);
return new TextContent(node.get("text").asText());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.quarkiverse.langchain4j.runtime.jackson;

import com.fasterxml.jackson.databind.annotation.JsonDeserialize;

import dev.langchain4j.data.message.TextContent;
import io.quarkus.jackson.JacksonMixin;

@JacksonMixin(TextContent.class)
@JsonDeserialize(using = TextContentDeserializer.class)
public abstract class TextContentMixin {

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class ToolExecutionResultMessageMixin {

@JsonCreator
public ToolExecutionResultMessageMixin(@JsonProperty("id") String id, @JsonProperty("toolName") String toolName,
@JsonProperty("toolExecutionResult") String toolExecutionResult) {
@JsonProperty("text") String text) {

}
}
Loading