Skip to content

Commit

Permalink
Merge pull request #487 from quarkiverse/websockets
Browse files Browse the repository at this point in the history
Add integration module with quarkus-websockets-next
  • Loading branch information
geoand committed Apr 24, 2024
2 parents 6e7957d + 8efb4ab commit 8c4429e
Show file tree
Hide file tree
Showing 24 changed files with 790 additions and 20 deletions.
7 changes: 7 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-websockets-next-deployment</artifactId>
<version>${project.version}</version>
<optional>true</optional> <!-- conditional dependency -->
</dependency>

<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcutil-jdk18on</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkiverse.langchain4j.runtime.RequestScopeStateDefaultMemoryIdProvider;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
Expand All @@ -65,16 +66,17 @@
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsTimedWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.CustomScopeAnnotationsBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand All @@ -86,6 +88,7 @@
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
import io.quarkus.deployment.metrics.MetricsCapabilityBuildItem;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
Expand Down Expand Up @@ -130,7 +133,8 @@ public class AiServicesProcessor {
@BuildStep
public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
List<AiServicesMethodBuildItem> aiServicesMethodBuildItems,
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer) {
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
BuildProducer<ServiceProviderBuildItem> serviceProviderProducer) {
IndexView index = indexBuildItem.getIndex();
Collection<AnnotationInstance> instances = index.getAnnotations(LangChain4jDotNames.DESCRIPTION);
Set<ClassInfo> classesUsingDescription = new HashSet<>();
Expand Down Expand Up @@ -163,10 +167,14 @@ public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
.constructors(false)
.build());
}

serviceProviderProducer.produce(new ServiceProviderBuildItem(DefaultMemoryIdProvider.class.getName(),
RequestScopeStateDefaultMemoryIdProvider.class.getName()));
}

@BuildStep
public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
CustomScopeAnnotationsBuildItem customScopes,
BuildProducer<RequestChatModelBeanBuildItem> requestChatModelBeanProducer,
BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
Expand Down Expand Up @@ -299,8 +307,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo);
ScopeInfo cdiScope = declaredScope != null ? declaredScope.getInfo() : BuiltinScope.REQUEST.getInfo();
DotName cdiScope = BuiltinScope.REQUEST.getInfo().getDotName();
Optional<AnnotationInstance> scopeAnnotation = customScopes.getScope(declarativeAiServiceClassInfo.annotations());
if (scopeAnnotation.isPresent()) {
cdiScope = scopeAnnotation.get().name();
}

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
Expand Down Expand Up @@ -670,12 +681,12 @@ public void handleAiServices(AiServicesRecorder recorder,
try (ClassCreator classCreator = classCreatorBuilder.build()) {
if (isRegisteredService) {
// we need to make this a bean, so we need to add the proper scope annotation
ScopeInfo scopeInfo = declarativeAiServiceItems.stream()
DotName scopeInfo = declarativeAiServiceItems.stream()
.filter(bi -> bi.getServiceClassInfo().equals(iface))
.findFirst().orElseThrow(() -> new IllegalStateException(
"Unable to determine the CDI scope of " + iface))
.getCdiScope();
classCreator.addAnnotation(scopeInfo.getDotName().toString());
classCreator.addAnnotation(scopeInfo.toString());
}

FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;

import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;

/**
Expand All @@ -23,7 +22,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final boolean customRetrievalAugmentorSupplierClassIsABean;
private final DotName auditServiceClassSupplierDotName;
private final DotName moderationModelSupplierDotName;
private final ScopeInfo cdiScope;
private final DotName cdiScope;
private final String chatModelName;
private final String moderationModelName;

Expand All @@ -35,7 +34,7 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag
boolean customRetrievalAugmentorSupplierClassIsABean,
DotName auditServiceClassSupplierDotName,
DotName moderationModelSupplierDotName,
ScopeInfo cdiScope,
DotName cdiScope,
String chatModelName,
String moderationModelName) {
this.serviceClassInfo = serviceClassInfo;
Expand Down Expand Up @@ -88,7 +87,7 @@ public DotName getModerationModelSupplierDotName() {
return moderationModelSupplierDotName;
}

public ScopeInfo getCdiScope() {
public DotName getCdiScope() {
return cdiScope;
}

Expand Down
1 change: 1 addition & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
<modules>
<module>deployment</module>
<module>runtime</module>
<module>runtime-spi</module>
</modules>

</project>
16 changes: 16 additions & 0 deletions core/runtime-spi/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>


<artifactId>quarkus-langchain4j-core-runtime-spi</artifactId>
<name>Quarkus LangChain4j - Core - Runtime - SPI</name>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.quarkiverse.langchain4j.spi;

/**
* Quarkus extension can decide whether they can provide a default value for the memory ID object if none was explicitly
* provided.
* <p>
* The idea behind this is that depending on the type of request that is being served, Quarkus can determine a unique
* (per request) object to be used.
*/
public interface DefaultMemoryIdProvider {
int DEFAULT_PRIORITY = 0;

/**
* Defines the priority of the providers.
* A lower integer value means that the customizer will be considered before one with a higher priority
*/
default int priority() {
return DEFAULT_PRIORITY;
}

/**
* Determines the object to be used as the default memory ID.
* A value of {@code null} means that the provider is not going to give a value and therefore he next
* provider should be tried.
*/
Object getMemoryId();
}
14 changes: 14 additions & 0 deletions core/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx</artifactId>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-runtime-spi</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-websockets-next</artifactId>
<version>${project.version}</version>
<optional>true</optional> <!-- conditional dependency -->
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.quarkiverse.langchain4j.runtime;

import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ManagedContext;

/**
* This implementation uses the state of the request scope as the default value
*/
public class RequestScopeStateDefaultMemoryIdProvider implements DefaultMemoryIdProvider {

@Override
public int priority() {
return DefaultMemoryIdProvider.DEFAULT_PRIORITY + 100;
}

@Override
public Object getMemoryId() {
ArcContainer container = Arc.container();
if (container == null) {
return null;
}
ManagedContext requestContext = container.requestContext();
if (requestContext.isActive()) {
return requestContext.getState();
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -40,11 +42,10 @@
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.spi.ServiceHelper;
import io.quarkiverse.langchain4j.audit.Audit;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ManagedContext;
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.smallrye.mutiny.subscription.MultiEmitter;
Expand All @@ -58,6 +59,24 @@ public class AiServiceMethodImplementationSupport {

private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;

private static final List<DefaultMemoryIdProvider> DEFAULT_MEMORY_ID_PROVIDERS;

static {
var defaultMemoryIdProviders = ServiceHelper.loadFactories(
DefaultMemoryIdProvider.class);
if (defaultMemoryIdProviders.isEmpty()) {
DEFAULT_MEMORY_ID_PROVIDERS = Collections.emptyList();
} else {
DEFAULT_MEMORY_ID_PROVIDERS = new ArrayList<>(defaultMemoryIdProviders);
DEFAULT_MEMORY_ID_PROVIDERS.sort(new Comparator<>() {
@Override
public int compare(DefaultMemoryIdProvider o1, DefaultMemoryIdProvider o2) {
return Integer.compare(o1.priority(), o2.priority());
}
});
}
}

/**
* This method is called by the implementations of each ai service method.
*/
Expand Down Expand Up @@ -322,16 +341,16 @@ private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] me
if (createInfo.getMemoryIdParamPosition().isPresent()) {
return methodArgs[createInfo.getMemoryIdParamPosition().get()];
}

if (hasChatMemoryProvider) {
// first we try to use the current context in order to make sure that we don't interleave chat messages of concurrent requests
ArcContainer container = Arc.container();
if (container != null) {
ManagedContext requestContext = container.requestContext();
if (requestContext.isActive()) {
return requestContext.getState();
for (DefaultMemoryIdProvider provider : DEFAULT_MEMORY_ID_PROVIDERS) {
Object memoryId = provider.getMemoryId();
if (memoryId != null) {
return memoryId;
}
}
}

// fallback to the default since there is nothing else we can really use here
return "default";
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
io.quarkiverse.langchain4j.runtime.RequestScopeStateDefaultMemoryIdProvider
1 change: 1 addition & 0 deletions docs/modules/ROOT/nav.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@
* Advanced topics
** xref:fault-tolerance.adoc[Fault Tolerance]
** xref:websockets.adoc[WebSockets]
** xref:enable-disable-integrations.adoc[Enabling and Disabling Integrations]
70 changes: 70 additions & 0 deletions docs/modules/ROOT/pages/websockets.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
= Using AI services in WebSockets

include::./includes/attributes.adoc[]

Using a chatbot in a WebSockets environment is quite common, which is why the extension provides a few facilities to make such usages as easy as possible.

* 1. Start by adding the link:https://quarkus.io/guides/websockets-next-tutorial[quarkus-websockets-next] dependency to your `pom.xml` file:
[source,xml]
----
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-websockets-next</artifactId>
</dependency>
----

IMPORTANT: `quarkus-websockets-next` is available as of Quarkus 3.9.

* 2. Annotated your AiService with `@SessionScoped`
[source,java]
----
import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import jakarta.enterprise.context.SessionScoped;
@RegisterAiService
@SessionScoped
public interface SessionScopedChatBot {
@SystemMessage("You are chatbot that helps users with their queries")
String chat(String message);
}
----

* 3. Create a WebSocket endpoint
[source,java]
----
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
@WebSocket(path = "/websocket")
public static class WebSocketChatBot {
private final SessionScopedChatBot bot;
public WebSocketChatBot(SessionScopedChatBot bot) {
this.bot = bot;
}
@OnOpen
public String onOpen() {
return bot.chat("Hello, how can I help you?");
}
@OnTextMessage
public String onMessage(String message) {
return bot.chat(message);
}
}
----

Two things are important to note in the snippets above:

* There is no `@MemoryId` field being used in the AI service. Quarkus will automatically all the WebSocket connection ID as the memory ID.
This ensures that each WebSocket session has its own chat memory.
* The use of `@SessionScoped` is important as the scope of the AI service is tied to the scope of the WebSocket endpoint.
This allows Quarkus to automatically clear chat memory when the WebSocket connection is closed for whatever reason.
Loading

0 comments on commit 8c4429e

Please sign in to comment.