Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration module with quarkus-websockets-next #487

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-websockets-next-deployment</artifactId>
<version>${project.version}</version>
<optional>true</optional> <!-- conditional dependency -->
</dependency>

<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcutil-jdk18on</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkiverse.langchain4j.runtime.RequestScopeStateDefaultMemoryIdProvider;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
Expand All @@ -65,16 +66,17 @@
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsTimedWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand All @@ -86,6 +88,7 @@
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
import io.quarkus.deployment.metrics.MetricsCapabilityBuildItem;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
Expand Down Expand Up @@ -130,7 +133,8 @@ public class AiServicesProcessor {
@BuildStep
public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
List<AiServicesMethodBuildItem> aiServicesMethodBuildItems,
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer) {
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
BuildProducer<ServiceProviderBuildItem> serviceProviderProducer) {
IndexView index = indexBuildItem.getIndex();
Collection<AnnotationInstance> instances = index.getAnnotations(LangChain4jDotNames.DESCRIPTION);
Set<ClassInfo> classesUsingDescription = new HashSet<>();
Expand Down Expand Up @@ -163,10 +167,14 @@ public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
.constructors(false)
.build());
}

serviceProviderProducer.produce(new ServiceProviderBuildItem(DefaultMemoryIdProvider.class.getName(),
RequestScopeStateDefaultMemoryIdProvider.class.getName()));
}

@BuildStep
public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
CustomScopeAnnotationsBuildItem customScopes,
BuildProducer<RequestChatModelBeanBuildItem> requestChatModelBeanProducer,
BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
Expand Down Expand Up @@ -299,8 +307,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo);
ScopeInfo cdiScope = declaredScope != null ? declaredScope.getInfo() : BuiltinScope.REQUEST.getInfo();
DotName cdiScope = BuiltinScope.REQUEST.getInfo().getDotName();
Optional<AnnotationInstance> scopeAnnotation = customScopes.getScope(declarativeAiServiceClassInfo.annotations());
if (scopeAnnotation.isPresent()) {
cdiScope = scopeAnnotation.get().name();
}

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
Expand Down Expand Up @@ -670,12 +681,12 @@ public void handleAiServices(AiServicesRecorder recorder,
try (ClassCreator classCreator = classCreatorBuilder.build()) {
if (isRegisteredService) {
// we need to make this a bean, so we need to add the proper scope annotation
ScopeInfo scopeInfo = declarativeAiServiceItems.stream()
DotName scopeInfo = declarativeAiServiceItems.stream()
.filter(bi -> bi.getServiceClassInfo().equals(iface))
.findFirst().orElseThrow(() -> new IllegalStateException(
"Unable to determine the CDI scope of " + iface))
.getCdiScope();
classCreator.addAnnotation(scopeInfo.getDotName().toString());
classCreator.addAnnotation(scopeInfo.toString());
}

FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;

import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;

/**
Expand All @@ -23,7 +22,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final boolean customRetrievalAugmentorSupplierClassIsABean;
private final DotName auditServiceClassSupplierDotName;
private final DotName moderationModelSupplierDotName;
private final ScopeInfo cdiScope;
private final DotName cdiScope;
private final String chatModelName;
private final String moderationModelName;

Expand All @@ -35,7 +34,7 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag
boolean customRetrievalAugmentorSupplierClassIsABean,
DotName auditServiceClassSupplierDotName,
DotName moderationModelSupplierDotName,
ScopeInfo cdiScope,
DotName cdiScope,
String chatModelName,
String moderationModelName) {
this.serviceClassInfo = serviceClassInfo;
Expand Down Expand Up @@ -88,7 +87,7 @@ public DotName getModerationModelSupplierDotName() {
return moderationModelSupplierDotName;
}

public ScopeInfo getCdiScope() {
public DotName getCdiScope() {
return cdiScope;
}

Expand Down
1 change: 1 addition & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
<modules>
<module>deployment</module>
<module>runtime</module>
<module>runtime-spi</module>
</modules>

</project>
16 changes: 16 additions & 0 deletions core/runtime-spi/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>


<artifactId>quarkus-langchain4j-core-runtime-spi</artifactId>
<name>Quarkus LangChain4j - Core - Runtime - SPI</name>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.quarkiverse.langchain4j.spi;

/**
* Quarkus extension can decide whether they can provide a default value for the memory ID object if none was explicitly
* provided.
* <p>
* The idea behind this is that depending on the type of request that is being served, Quarkus can determine a unique
* (per request) object to be used.
*/
public interface DefaultMemoryIdProvider {
int DEFAULT_PRIORITY = 0;

/**
* Defines the priority of the providers.
* A lower integer value means that the customizer will be considered before one with a higher priority
*/
default int priority() {
return DEFAULT_PRIORITY;
}

/**
* Determines the object to be used as the default memory ID.
* A value of {@code null} means that the provider is not going to give a value and therefore he next
* provider should be tried.
*/
Object getMemoryId();
}
14 changes: 14 additions & 0 deletions core/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx</artifactId>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-runtime-spi</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-websockets-next</artifactId>
<version>${project.version}</version>
<optional>true</optional> <!-- conditional dependency -->
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.quarkiverse.langchain4j.runtime;

import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ManagedContext;

/**
* This implementation uses the state of the request scope as the default value
*/
public class RequestScopeStateDefaultMemoryIdProvider implements DefaultMemoryIdProvider {

@Override
public int priority() {
return DefaultMemoryIdProvider.DEFAULT_PRIORITY + 100;
}

@Override
public Object getMemoryId() {
ArcContainer container = Arc.container();
if (container == null) {
return null;
}
ManagedContext requestContext = container.requestContext();
if (requestContext.isActive()) {
return requestContext.getState();
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -40,11 +42,10 @@
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.spi.ServiceHelper;
import io.quarkiverse.langchain4j.audit.Audit;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ManagedContext;
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.smallrye.mutiny.subscription.MultiEmitter;
Expand All @@ -58,6 +59,24 @@ public class AiServiceMethodImplementationSupport {

private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;

private static final List<DefaultMemoryIdProvider> DEFAULT_MEMORY_ID_PROVIDERS;

static {
var defaultMemoryIdProviders = ServiceHelper.loadFactories(
DefaultMemoryIdProvider.class);
if (defaultMemoryIdProviders.isEmpty()) {
DEFAULT_MEMORY_ID_PROVIDERS = Collections.emptyList();
} else {
DEFAULT_MEMORY_ID_PROVIDERS = new ArrayList<>(defaultMemoryIdProviders);
DEFAULT_MEMORY_ID_PROVIDERS.sort(new Comparator<>() {
@Override
public int compare(DefaultMemoryIdProvider o1, DefaultMemoryIdProvider o2) {
return Integer.compare(o1.priority(), o2.priority());
}
});
}
}

/**
* This method is called by the implementations of each ai service method.
*/
Expand Down Expand Up @@ -322,16 +341,16 @@ private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] me
if (createInfo.getMemoryIdParamPosition().isPresent()) {
return methodArgs[createInfo.getMemoryIdParamPosition().get()];
}

if (hasChatMemoryProvider) {
// first we try to use the current context in order to make sure that we don't interleave chat messages of concurrent requests
ArcContainer container = Arc.container();
if (container != null) {
ManagedContext requestContext = container.requestContext();
if (requestContext.isActive()) {
return requestContext.getState();
for (DefaultMemoryIdProvider provider : DEFAULT_MEMORY_ID_PROVIDERS) {
Object memoryId = provider.getMemoryId();
if (memoryId != null) {
return memoryId;
}
}
}

// fallback to the default since there is nothing else we can really use here
return "default";
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
io.quarkiverse.langchain4j.runtime.RequestScopeStateDefaultMemoryIdProvider
1 change: 1 addition & 0 deletions docs/modules/ROOT/nav.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@

* Advanced topics
** xref:fault-tolerance.adoc[Fault Tolerance]
** xref:websockets.adoc[WebSockets]
** xref:enable-disable-integrations.adoc[Enabling and Disabling Integrations]
70 changes: 70 additions & 0 deletions docs/modules/ROOT/pages/websockets.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
= Using AI services in WebSockets

include::./includes/attributes.adoc[]

Using a chatbot in a WebSockets environment is quite common, which is why the extension provides a few facilities to make such usages as easy as possible.

* 1. Start by adding the link:https://quarkus.io/guides/websockets-next-tutorial[quarkus-websockets-next] dependency to your `pom.xml` file:

[source,xml]
----
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-websockets-next</artifactId>
</dependency>
----

IMPORTANT: `quarkus-websockets-next` is available as of Quarkus 3.9.

* 2. Annotated your AiService with `@SessionScoped`

[source,java]
----
import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import jakarta.enterprise.context.SessionScoped;

@RegisterAiService
@SessionScoped
public interface SessionScopedChatBot {

@SystemMessage("You are chatbot that helps users with their queries")
String chat(String message);
}
----

* 3. Create a WebSocket endpoint

[source,java]
----
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;

@WebSocket(path = "/websocket")
public static class WebSocketChatBot {

private final SessionScopedChatBot bot;

public WebSocketChatBot(SessionScopedChatBot bot) {
this.bot = bot;
}

@OnOpen
public String onOpen() {
return bot.chat("Hello, how can I help you?");
}

@OnTextMessage
public String onMessage(String message) {
return bot.chat(message);
}
}
----

Two things are important to note in the snippets above:

* There is no `@MemoryId` field being used in the AI service. Quarkus will automatically all the WebSocket connection ID as the memory ID.
This ensures that each WebSocket session has its own chat memory.
* The use of `@SessionScoped` is important as the scope of the AI service is tied to the scope of the WebSocket endpoint.
This allows Quarkus to automatically clear chat memory when the WebSocket connection is closed for whatever reason.
jmartisk marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading