From c50e59484354b08a5d022d21a947d1b394bce797 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Wed, 13 Mar 2024 09:05:26 +0200 Subject: [PATCH] Introduce a Mistral AI module Closes: #371 --- integration-tests/mistralai/pom.xml | 126 ++++++++++ .../chat/ChatLanguageModelResource.java | 58 +++++ .../chat/EmbeddingModelResource.java | 22 ++ .../mistralai/chat/ModelsResource.java | 22 ++ .../src/main/resources/application.properties | 2 + integration-tests/pom.xml | 1 + mistral/deployment/pom.xml | 62 +++++ .../deployment/ChatModelBuildConfig.java | 16 ++ .../deployment/EmbeddingModelBuildConfig.java | 16 ++ .../LangChain4jMistralAiBuildConfig.java | 22 ++ .../deployment/MistralAiProcessor.java | 102 ++++++++ .../MistralAiChatLanguageModelSmokeTest.java | 106 ++++++++ mistral/pom.xml | 21 ++ mistral/runtime/pom.xml | 129 ++++++++++ .../mistralai/MistralAiRestApi.java | 176 +++++++++++++ .../mistralai/QuarkusMistralAiClient.java | 233 ++++++++++++++++++ .../mistralai/runtime/MistralAiRecorder.java | 191 ++++++++++++++ .../runtime/config/ChatModelConfig.java | 67 +++++ .../runtime/config/EmbeddingModelConfig.java | 30 +++ .../config/LangChain4jMistralAiConfig.java | 88 +++++++ .../runtime/jackson/MistralAiRoleMixin.java | 14 ++ .../runtime/jackson/RoleDeserializer.java | 24 ++ .../runtime/jackson/RoleSerializer.java | 21 ++ .../src/main/resources/META-INF/beans.xml | 0 .../resources/META-INF/quarkus-extension.yaml | 17 ++ ...el.mistralai.MistralAiClientBuilderFactory | 1 + pom.xml | 1 + 27 files changed, 1568 insertions(+) create mode 100644 integration-tests/mistralai/pom.xml create mode 100644 integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ChatLanguageModelResource.java create mode 100644 integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/EmbeddingModelResource.java create mode 100644 integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ModelsResource.java create mode 100644 integration-tests/mistralai/src/main/resources/application.properties create mode 100644 mistral/deployment/pom.xml create mode 100644 mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/ChatModelBuildConfig.java create mode 100644 mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/EmbeddingModelBuildConfig.java create mode 100644 mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/LangChain4jMistralAiBuildConfig.java create mode 100644 mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiProcessor.java create mode 100644 mistral/deployment/src/test/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiChatLanguageModelSmokeTest.java create mode 100644 mistral/pom.xml create mode 100644 mistral/runtime/pom.xml create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/MistralAiRestApi.java create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/QuarkusMistralAiClient.java create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/ChatModelConfig.java create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/EmbeddingModelConfig.java create mode 100644 
mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/MistralAiRoleMixin.java create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleDeserializer.java create mode 100644 mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleSerializer.java create mode 100644 mistral/runtime/src/main/resources/META-INF/beans.xml create mode 100644 mistral/runtime/src/main/resources/META-INF/quarkus-extension.yaml create mode 100644 mistral/runtime/src/main/resources/META-INF/services/dev.langchain4j.model.mistralai.MistralAiClientBuilderFactory diff --git a/integration-tests/mistralai/pom.xml b/integration-tests/mistralai/pom.xml new file mode 100644 index 000000000..3c680f372 --- /dev/null +++ b/integration-tests/mistralai/pom.xml @@ -0,0 +1,126 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-test-mistralai + Quarkus LangChain4j - Integration Tests - MistralAI + + true + + + + io.quarkus + quarkus-resteasy-reactive-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-mistral-ai + ${project.version} + + + io.quarkus + quarkus-micrometer + + + io.quarkus + quarkus-smallrye-fault-tolerance + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-mistral-ai-deployment + ${project.version} + pom + test + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + maven-failsafe-plugin + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + + native-image + + + native + + + + + + maven-surefire-plugin + + ${native.surefire.skip} + + + + + + false + native + + + + diff --git a/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ChatLanguageModelResource.java b/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ChatLanguageModelResource.java new file mode 100644 index 000000000..e6487abe7 --- /dev/null +++ b/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ChatLanguageModelResource.java @@ -0,0 +1,58 @@ +package org.acme.example.mistralai.chat; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.core.MediaType; + +import org.jboss.resteasy.reactive.RestStreamElementType; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import io.smallrye.mutiny.Multi; + +@Path("chat") +public class ChatLanguageModelResource { + + private final ChatLanguageModel chatModel; + private final StreamingChatLanguageModel streamingChatModel; + + public ChatLanguageModelResource(ChatLanguageModel chatModel, StreamingChatLanguageModel streamingChatModel) { + this.chatModel = chatModel; + this.streamingChatModel = streamingChatModel; + } + + @GET + @Path("blocking") + public String blocking() { + return chatModel.generate("When was the nobel prize 
for economics first awarded?"); + } + + @GET + @Path("streaming") + @RestStreamElementType(MediaType.TEXT_PLAIN) + public Multi streaming() { + return Multi.createFrom().emitter(emitter -> { + streamingChatModel.generate("When was the nobel prize for economics first awarded?", + new StreamingResponseHandler<>() { + @Override + public void onNext(String token) { + emitter.emit(token); + } + + @Override + public void onError(Throwable error) { + emitter.fail(error); + } + + @Override + public void onComplete(Response response) { + emitter.complete(); + } + }); + }); + + } +} diff --git a/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/EmbeddingModelResource.java b/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/EmbeddingModelResource.java new file mode 100644 index 000000000..f3ac7f5c9 --- /dev/null +++ b/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/EmbeddingModelResource.java @@ -0,0 +1,22 @@ +package org.acme.example.mistralai.chat; + +import dev.langchain4j.model.embedding.EmbeddingModel; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import dev.langchain4j.model.mistralai.MistralAiEmbeddingModel; + +@Path("embedding") +public class EmbeddingModelResource { + + private final EmbeddingModel embeddingModel; + + public EmbeddingModelResource(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @GET + public int blocking() { + return embeddingModel.embed("When was the nobel prize for economics first awarded?").content().dimension(); + } +} diff --git a/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ModelsResource.java b/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ModelsResource.java new file mode 100644 index 000000000..06711cadb --- /dev/null +++ b/integration-tests/mistralai/src/main/java/org/acme/example/mistralai/chat/ModelsResource.java @@ -0,0 +1,22 @@ +package org.acme.example.mistralai.chat; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import dev.langchain4j.model.mistralai.MistralAiClient; +import dev.langchain4j.model.mistralai.MistralAiModelResponse; + +@Path("models") +public class ModelsResource { + + private final MistralAiClient client; + + public ModelsResource(MistralAiClient client) { + this.client = client; + } + + @GET + public MistralAiModelResponse models() { + return client.listModels(); + } +} diff --git a/integration-tests/mistralai/src/main/resources/application.properties b/integration-tests/mistralai/src/main/resources/application.properties new file mode 100644 index 000000000..f7c0441cb --- /dev/null +++ b/integration-tests/mistralai/src/main/resources/application.properties @@ -0,0 +1,2 @@ +quarkus.langchain4j.mistralai.log-requests=true +quarkus.langchain4j.mistralai.log-responses=true diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 83e52bb34..7ba3c6ca9 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -18,6 +18,7 @@ simple-ollama azure-openai multiple-providers + mistralai devui embed-all-minilm-l6-v2-q embed-all-minilm-l6-v2 diff --git a/mistral/deployment/pom.xml b/mistral/deployment/pom.xml new file mode 100644 index 000000000..bc998e268 --- /dev/null +++ b/mistral/deployment/pom.xml @@ -0,0 +1,62 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-mistral-ai-parent + 999-SNAPSHOT + + quarkus-langchain4j-mistral-ai-deployment + Quarkus LangChain4j - Mistral AI - Deployment + + + io.quarkiverse.langchain4j 
+ quarkus-langchain4j-mistral-ai + ${project.version} + + + io.quarkus + quarkus-rest-client-reactive-jackson-deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + org.wiremock + wiremock-standalone + ${wiremock.version} + test + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + diff --git a/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/ChatModelBuildConfig.java b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/ChatModelBuildConfig.java new file mode 100644 index 000000000..5ee92c5da --- /dev/null +++ b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/ChatModelBuildConfig.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.mistralai.deployment; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface ChatModelBuildConfig { + + /** + * Whether the model should be enabled + */ + @ConfigDocDefault("true") + Optional enabled(); +} diff --git a/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/EmbeddingModelBuildConfig.java b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/EmbeddingModelBuildConfig.java new file mode 100644 index 000000000..ae5870f40 --- /dev/null +++ b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/EmbeddingModelBuildConfig.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.mistralai.deployment; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface EmbeddingModelBuildConfig { + + /** + * Whether the model should be enabled + */ + @ConfigDocDefault("true") + Optional enabled(); +} diff --git a/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/LangChain4jMistralAiBuildConfig.java b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/LangChain4jMistralAiBuildConfig.java new file mode 100644 index 000000000..43aa9d173 --- /dev/null +++ b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/LangChain4jMistralAiBuildConfig.java @@ -0,0 +1,22 @@ +package io.quarkiverse.langchain4j.mistralai.deployment; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; + +@ConfigRoot(phase = BUILD_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.mistralai") +public interface LangChain4jMistralAiBuildConfig { + + /** + * Chat model related settings + */ + ChatModelBuildConfig chatModel(); + + /** + * Embedding model related settings + */ + EmbeddingModelBuildConfig embeddingModel(); + +} diff --git a/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiProcessor.java b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiProcessor.java new file mode 100644 index 000000000..6c7942076 --- /dev/null +++ b/mistral/deployment/src/main/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiProcessor.java @@ -0,0 +1,102 @@ +package 
io.quarkiverse.langchain4j.mistralai.deployment; + +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL; +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.EMBEDDING_MODEL; +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL; + +import java.util.List; + +import jakarta.enterprise.context.ApplicationScoped; + +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; +import io.quarkiverse.langchain4j.mistralai.runtime.MistralAiRecorder; +import io.quarkiverse.langchain4j.mistralai.runtime.config.LangChain4jMistralAiConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; + +public class MistralAiProcessor { + + private static final String FEATURE = "langchain4j-mistralai"; + private static final String PROVIDER = "mistralai"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } + + @BuildStep + public void providerCandidates(BuildProducer chatProducer, + BuildProducer embeddingProducer, + LangChain4jMistralAiBuildConfig config) { + if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) { + chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER)); + } + if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) { + embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER)); + } + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + void generateBeans(MistralAiRecorder recorder, + List selectedChatItem, + List selectedEmbedding, + LangChain4jMistralAiConfig config, + BuildProducer beanProducer) { + + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + + var streamingBuilder = SyntheticBeanBuildItem + .configure(STREAMING_CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.streamingChatModel(config, modelName)); + addQualifierIfNecessary(streamingBuilder, modelName); + beanProducer.produce(streamingBuilder.done()); + } + } + + for (var selected : selectedEmbedding) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(EMBEDDING_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.embeddingModel(config, modelName)); + 
addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } + } + } + + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); + } + } +} diff --git a/mistral/deployment/src/test/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiChatLanguageModelSmokeTest.java b/mistral/deployment/src/test/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiChatLanguageModelSmokeTest.java new file mode 100644 index 000000000..810c0e64e --- /dev/null +++ b/mistral/deployment/src/test/java/io/quarkiverse/langchain4j/mistralai/deployment/MistralAiChatLanguageModelSmokeTest.java @@ -0,0 +1,106 @@ +package io.quarkiverse.langchain4j.mistralai.deployment; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.stubbing.ServeEvent; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.mistralai.MistralAiChatModel; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.QuarkusUnitTest; + +class MistralAiChatLanguageModelSmokeTest { + private static final int WIREMOCK_PORT = 8089; + private static final String CHAT_MODEL_ID = "mistral-tiny"; + private static final String API_KEY = "somekey"; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.mistralai.api-key", API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.mistralai.log-requests", "true") + .overrideRuntimeConfigKey("quarkus.langchain4j.mistralai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1"); + + static WireMockServer wireMockServer; + + @BeforeAll + static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + wireMockServer.resetAll(); + } + + @Inject + ChatLanguageModel chatLanguageModel; + + @Test + void test() { + assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(MistralAiChatModel.class); + + wireMockServer.stubFor( + post(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer " + API_KEY)) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(""" + { + "id": "0bdf265cb18d493d96b62029f024d897", + "object": "chat.completion", + "created": 1711442725, + "model": "mistral-tiny", + 
"choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Nice to meet you" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 19, + "total_tokens": 127, + "completion_tokens": 108 + } + } + """))); + + String response = chatLanguageModel.generate("hello"); + assertThat(response).isEqualTo("Nice to meet you"); + + assertThat(wireMockServer.getAllServeEvents()).hasSize(1); + ServeEvent serveEvent = wireMockServer.getAllServeEvents().get(0); // this works because we reset requests for Wiremock before each test + LoggedRequest loggedRequest = serveEvent.getRequest(); + assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Resteasy Reactive Client"); + String requestBody = new String(loggedRequest.getBody()); + assertThat(requestBody).contains("hello").contains(CHAT_MODEL_ID); + } + +} diff --git a/mistral/pom.xml b/mistral/pom.xml new file mode 100644 index 000000000..33ccb79d7 --- /dev/null +++ b/mistral/pom.xml @@ -0,0 +1,21 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + quarkus-langchain4j-mistral-ai-parent + Quarkus LangChain4j - Mistral AI - Parent + pom + + + deployment + runtime + + + + diff --git a/mistral/runtime/pom.xml b/mistral/runtime/pom.xml new file mode 100644 index 000000000..889a531b0 --- /dev/null +++ b/mistral/runtime/pom.xml @@ -0,0 +1,129 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-mistral-ai-parent + 999-SNAPSHOT + + quarkus-langchain4j-mistral-ai + Quarkus LangChain4j - Mistral AI - Runtime + + + io.quarkus + quarkus-arc + + + io.quarkus + quarkus-rest-client-reactive-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + + dev.langchain4j + langchain4j-mistral-ai + + + com.squareup.retrofit2 + retrofit + + + com.squareup.okhttp3 + okhttp + + + com.google.code.gson + gson + + + com.squareup.retrofit2 + converter-gson + + + com.squareup.okhttp3 + okhttp-sse + + + com.github.spullara.mustache.java + compiler + + + + + + io.quarkus + quarkus-junit5-internal + test + + + org.mockito + mockito-core + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + maven-jar-plugin + + + generate-codestart-jar + generate-resources + + jar + + + ${project.basedir}/src/main + + codestarts/** + + codestarts + true + + + + + + + + diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/MistralAiRestApi.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/MistralAiRestApi.java new file mode 100644 index 000000000..500948b9d --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/MistralAiRestApi.java @@ -0,0 +1,176 @@ +package io.quarkiverse.langchain4j.mistralai; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.function.Predicate; + +import jakarta.annotation.Priority; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Priorities; +import 
jakarta.ws.rs.Produces; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.ext.MessageBodyWriter; +import jakarta.ws.rs.ext.WriterInterceptor; +import jakarta.ws.rs.ext.WriterInterceptorContext; + +import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam; +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; +import org.jboss.resteasy.reactive.RestStreamElementType; +import org.jboss.resteasy.reactive.client.SseEvent; +import org.jboss.resteasy.reactive.client.SseEventFilter; +import org.jboss.resteasy.reactive.common.providers.serialisers.AbstractJsonMessageBodyReader; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectReader; + +import dev.langchain4j.model.mistralai.MistralAiChatCompletionRequest; +import dev.langchain4j.model.mistralai.MistralAiChatCompletionResponse; +import dev.langchain4j.model.mistralai.MistralAiEmbeddingRequest; +import dev.langchain4j.model.mistralai.MistralAiEmbeddingResponse; +import dev.langchain4j.model.mistralai.MistralAiModelResponse; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; +import io.quarkus.rest.client.reactive.NotBody; +import io.smallrye.mutiny.Multi; + +/** + * This Microprofile REST client is used as the building block of all the API calls to MistralAI. + * The implementation is provided by the Reactive REST Client in Quarkus. + */ + +@Path("") +@ClientHeaderParam(name = "Authorization", value = "Bearer {token}") +@Consumes(MediaType.APPLICATION_JSON) +@Produces(MediaType.APPLICATION_JSON) +@RegisterProvider(MistralAiRestApi.MistralAiRestApiJacksonReader.class) +@RegisterProvider(MistralAiRestApi.MistralAiRestApiJacksonWriter.class) +@RegisterProvider(MistralAiRestApi.MistralAiRestApiWriterInterceptor.class) +public interface MistralAiRestApi { + + /** + * Perform a blocking request for a completion response + */ + @Path("chat/completions") + @POST + MistralAiChatCompletionResponse blockingChatCompletion(MistralAiChatCompletionRequest request, @NotBody String token); + + /** + * Performs a non-blocking request for a streaming completion request + */ + @Path("chat/completions") + @POST + @RestStreamElementType(MediaType.APPLICATION_JSON) + @SseEventFilter(DoneFilter.class) + Multi streamingChatCompletion(MistralAiChatCompletionRequest request, + @NotBody String token); + + @Path("embeddings") + @POST + MistralAiEmbeddingResponse embedding(MistralAiEmbeddingRequest request, @NotBody String token); + + @Path("models") + @GET + MistralAiModelResponse models(@NotBody String token); + + /** + * The point of this is to properly set the {@code stream} value of the request + * so users don't have to remember to set it manually + */ + class MistralAiRestApiWriterInterceptor implements WriterInterceptor { + @Override + public void aroundWriteTo(WriterInterceptorContext context) throws IOException, WebApplicationException { + Object entity = context.getEntity(); + if (entity instanceof MistralAiChatCompletionRequest request) { + MultivaluedMap headers = context.getHeaders(); + List acceptList = headers.get(HttpHeaders.ACCEPT); + if ((acceptList != null) && (acceptList.size() == 1)) { + String accept = (String) acceptList.get(0); + if (MediaType.APPLICATION_JSON.equals(accept)) { + if (Boolean.TRUE.equals(request.getStream())) { + context.setEntity(from(request).stream(null).build()); + } + } else if 
(MediaType.SERVER_SENT_EVENTS.equals(accept)) { + if (!Boolean.TRUE.equals(request.getStream())) { + context.setEntity(from(request).stream(true).build()); + } + } + } + } + context.proceed(); + } + + private MistralAiChatCompletionRequest.MistralAiChatCompletionRequestBuilder from( + MistralAiChatCompletionRequest request) { + var builder = MistralAiChatCompletionRequest.builder(); + builder.model(request.getModel()); + builder.messages(request.getMessages()); + builder.temperature(request.getTemperature()); + builder.topP(request.getTopP()); + builder.maxTokens(request.getMaxTokens()); + builder.stream(request.getStream()); + builder.safePrompt(request.getSafePrompt()); + builder.randomSeed(request.getRandomSeed()); + return builder; + } + } + + /** + * Ensures that the terminal event sent by OpenAI is not processed (as it is not a valid json event) + */ + class DoneFilter implements Predicate> { + + @Override + public boolean test(SseEvent event) { + return !"[DONE]".equals(event.data()); + } + } + + @Priority(Priorities.USER - 100) // this priority ensures that our Reader has priority over the standard Jackson one + class MistralAiRestApiJacksonReader extends AbstractJsonMessageBodyReader { + + /** + * We need a custom version of the Jackson provider because reading SSE values does not work properly with + * {@code @ClientObjectMapper} due to the lack of a complete context in those requests + */ + @Override + public Object readFrom(Class type, Type genericType, Annotation[] annotations, MediaType mediaType, + MultivaluedMap httpHeaders, InputStream entityStream) + throws IOException, WebApplicationException { + return ObjectMapperHolder.READER + .forType(ObjectMapperHolder.READER.getTypeFactory().constructType(genericType != null ? genericType : type)) + .readValue(entityStream); + } + } + + @Priority(Priorities.USER + 100) // this priority ensures that our Writer has priority over the standard Jackson one + class MistralAiRestApiJacksonWriter implements MessageBodyWriter { + + @Override + public boolean isWriteable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { + return true; + } + + @Override + public void writeTo(Object o, Class type, Type genericType, Annotation[] annotations, MediaType mediaType, + MultivaluedMap httpHeaders, OutputStream entityStream) + throws IOException, WebApplicationException { + entityStream.write(ObjectMapperHolder.MAPPER.writeValueAsString(o).getBytes(StandardCharsets.UTF_8)); + } + } + + class ObjectMapperHolder { + public static final ObjectMapper MAPPER = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER; + + private static final ObjectReader READER = MAPPER.reader(); + } +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/QuarkusMistralAiClient.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/QuarkusMistralAiClient.java new file mode 100644 index 000000000..61aaf873a --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/QuarkusMistralAiClient.java @@ -0,0 +1,233 @@ +package io.quarkiverse.langchain4j.mistralai; + +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.finishReasonFrom; +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.tokenUsageFrom; +import static java.util.stream.Collectors.joining; +import static java.util.stream.StreamSupport.stream; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.concurrent.TimeUnit; +import 
java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.jboss.logging.Logger; +import org.jboss.resteasy.reactive.client.api.ClientLogger; +import org.jboss.resteasy.reactive.client.api.LoggingScope; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.mistralai.MistralAiChatCompletionChoice; +import dev.langchain4j.model.mistralai.MistralAiChatCompletionRequest; +import dev.langchain4j.model.mistralai.MistralAiChatCompletionResponse; +import dev.langchain4j.model.mistralai.MistralAiClient; +import dev.langchain4j.model.mistralai.MistralAiClientBuilderFactory; +import dev.langchain4j.model.mistralai.MistralAiEmbeddingRequest; +import dev.langchain4j.model.mistralai.MistralAiEmbeddingResponse; +import dev.langchain4j.model.mistralai.MistralAiModelResponse; +import dev.langchain4j.model.mistralai.MistralAiUsage; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; +import io.vertx.core.Handler; +import io.vertx.core.MultiMap; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpClientRequest; +import io.vertx.core.http.HttpClientResponse; + +public class QuarkusMistralAiClient extends MistralAiClient { + + private final String apiKey; + private final MistralAiRestApi restApi; + + public QuarkusMistralAiClient(Builder builder) { + this.apiKey = builder.apiKey; + + try { + QuarkusRestClientBuilder restApiBuilder = QuarkusRestClientBuilder.newBuilder() + .baseUri(new URI(builder.baseUrl)) + .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) + .readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS); + if (builder.logRequests || builder.logResponses) { + restApiBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE); + restApiBuilder.clientLogger(new QuarkusMistralAiClient.MistralAiClientLogger(builder.logRequests, + builder.logResponses)); + } + restApi = restApiBuilder.build(MistralAiRestApi.class); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + public MistralAiChatCompletionResponse chatCompletion(MistralAiChatCompletionRequest request) { + return restApi.blockingChatCompletion(request, apiKey); + } + + @Override + public void streamingChatCompletion(MistralAiChatCompletionRequest request, + StreamingResponseHandler handler) { + AtomicReference contentBuilder = new AtomicReference<>(new StringBuffer()); + AtomicReference tokenUsage = new AtomicReference<>(); + AtomicReference finishReason = new AtomicReference<>(); + restApi.streamingChatCompletion(request, apiKey).subscribe().with(new Consumer<>() { + @Override + public void accept(MistralAiChatCompletionResponse response) { + MistralAiChatCompletionChoice choice = response.getChoices().get(0); + String chunk = choice.getDelta().getContent(); + contentBuilder.get().append(chunk); + handler.onNext(chunk); + + MistralAiUsage usageInfo = response.getUsage(); + if (usageInfo != null) { + tokenUsage.set(tokenUsageFrom(usageInfo)); + } + + String finishReasonString = choice.getFinishReason(); + if (finishReasonString != null) { + finishReason.set(finishReasonFrom(finishReasonString)); + } + } + }, new Consumer<>() { + @Override + public void accept(Throwable t) { + handler.onError(t); + } + }, new Runnable() { + @Override + public void 
run() { + Response response = Response.from( + AiMessage.from(contentBuilder.get().toString()), + tokenUsage.get(), + finishReason.get()); + handler.onComplete(response); + } + }); + } + + @Override + public MistralAiEmbeddingResponse embedding(MistralAiEmbeddingRequest request) { + return restApi.embedding(request, apiKey); + } + + @Override + public MistralAiModelResponse listModels() { + return restApi.models(apiKey); + } + + public static class QuarkusMistralAiClientBuilderFactory implements MistralAiClientBuilderFactory { + + @Override + public Builder get() { + return new Builder(); + } + } + + public static class Builder extends MistralAiClient.Builder { + + @Override + public QuarkusMistralAiClient build() { + return new QuarkusMistralAiClient(this); + } + } + + /** + * Introduce a custom logger as the stock one logs at the DEBUG level by default... + */ + static class MistralAiClientLogger implements ClientLogger { + private static final Logger log = Logger.getLogger(MistralAiClientLogger.class); + + private static final Pattern BEARER_PATTERN = Pattern.compile("(Bearer\\s*)(\\w{2})(\\w+)(\\w{2})"); + + private final boolean logRequests; + private final boolean logResponses; + + public MistralAiClientLogger(boolean logRequests, boolean logResponses) { + this.logRequests = logRequests; + this.logResponses = logResponses; + } + + @Override + public void setBodySize(int bodySize) { + // ignore + } + + @Override + public void logRequest(HttpClientRequest request, Buffer body, boolean omitBody) { + if (!logRequests || !log.isInfoEnabled()) { + return; + } + try { + log.infof("Request:\n- method: %s\n- url: %s\n- headers: %s\n- body: %s", + request.getMethod(), + request.absoluteURI(), + inOneLine(request.headers()), + bodyToString(body)); + } catch (Exception e) { + log.warn("Failed to log request", e); + } + } + + @Override + public void logResponse(HttpClientResponse response, boolean redirect) { + if (!logResponses || !log.isInfoEnabled()) { + return; + } + response.bodyHandler(new Handler<>() { + @Override + public void handle(Buffer body) { + try { + log.infof( + "Response:\n- status code: %s\n- headers: %s\n- body: %s", + response.statusCode(), + inOneLine(response.headers()), + bodyToString(body)); + } catch (Exception e) { + log.warn("Failed to log response", e); + } + } + }); + } + + private String bodyToString(Buffer body) { + if (body == null) { + return ""; + } + return body.toString(); + } + + private String inOneLine(MultiMap headers) { + + return stream(headers.spliterator(), false) + .map(header -> { + String headerKey = header.getKey(); + String headerValue = header.getValue(); + if (headerKey.equals("Authorization")) { + headerValue = maskAuthorizationHeaderValue(headerValue); + } + return String.format("[%s: %s]", headerKey, headerValue); + }) + .collect(joining(", ")); + } + + private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { + try { + + Matcher matcher = BEARER_PATTERN.matcher(authorizationHeaderValue); + + StringBuilder sb = new StringBuilder(); + while (matcher.find()) { + matcher.appendReplacement(sb, matcher.group(1) + matcher.group(2) + "..." 
+ matcher.group(4)); + } + matcher.appendTail(sb); + + return sb.toString(); + } catch (Exception e) { + return "Failed to mask the API key."; + } + } + } +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java new file mode 100644 index 000000000..7d688fb5e --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java @@ -0,0 +1,191 @@ +package io.quarkiverse.langchain4j.mistralai.runtime; + +import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; + +import java.util.function.Supplier; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.DisabledChatLanguageModel; +import dev.langchain4j.model.chat.DisabledStreamingChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.embedding.DisabledEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.mistralai.MistralAiChatModel; +import dev.langchain4j.model.mistralai.MistralAiEmbeddingModel; +import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel; +import io.quarkiverse.langchain4j.mistralai.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.mistralai.runtime.config.EmbeddingModelConfig; +import io.quarkiverse.langchain4j.mistralai.runtime.config.LangChain4jMistralAiConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; +import io.quarkus.runtime.annotations.Recorder; +import io.smallrye.config.ConfigValidationException; + +@Recorder +public class MistralAiRecorder { + private static final String DUMMY_KEY = "dummy"; + + public Supplier chatModel(LangChain4jMistralAiConfig runtimeConfig, String modelName) { + LangChain4jMistralAiConfig.MistralAiConfig mistralAiConfig = correspondingMistralAiConfig(runtimeConfig, + modelName); + + if (mistralAiConfig.enableIntegration()) { + String apiKey = mistralAiConfig.apiKey(); + ChatModelConfig chatModelConfig = mistralAiConfig.chatModel(); + + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } + + var builder = MistralAiChatModel.builder() + .baseUrl(mistralAiConfig.baseUrl()) + .apiKey(apiKey) + .modelName(chatModelConfig.modelName()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), mistralAiConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), mistralAiConfig.logResponses())) + .timeout(mistralAiConfig.timeout()); + + if (chatModelConfig.temperature().isPresent()) { + builder.temperature(chatModelConfig.temperature().getAsDouble()); + } + if (chatModelConfig.topP().isPresent()) { + builder.topP(chatModelConfig.topP().getAsDouble()); + } + if (chatModelConfig.maxTokens().isPresent()) { + builder.maxTokens(chatModelConfig.maxTokens().getAsInt()); + } + if (chatModelConfig.safePrompt().isPresent()) { + builder.safePrompt(chatModelConfig.safePrompt().get()); + } + if (chatModelConfig.randomSeed().isPresent()) { + builder.randomSeed(chatModelConfig.randomSeed().getAsInt()); + } + + return new Supplier<>() { + @Override + public ChatLanguageModel get() { + return builder.build(); + } + }; + } else { + return new Supplier<>() { + @Override + public ChatLanguageModel get() { + return new DisabledChatLanguageModel(); + } + }; + } + } + + public Supplier 
streamingChatModel(LangChain4jMistralAiConfig runtimeConfig, String modelName) { + LangChain4jMistralAiConfig.MistralAiConfig mistralAiConfig = correspondingMistralAiConfig(runtimeConfig, + modelName); + + if (mistralAiConfig.enableIntegration()) { + String apiKey = mistralAiConfig.apiKey(); + ChatModelConfig chatModelConfig = mistralAiConfig.chatModel(); + + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } + + var builder = MistralAiStreamingChatModel.builder() + .baseUrl(mistralAiConfig.baseUrl()) + .apiKey(apiKey) + .modelName(chatModelConfig.modelName()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), mistralAiConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), mistralAiConfig.logResponses())) + .timeout(mistralAiConfig.timeout()); + + if (chatModelConfig.temperature().isPresent()) { + builder.temperature(chatModelConfig.temperature().getAsDouble()); + } + if (chatModelConfig.topP().isPresent()) { + builder.topP(chatModelConfig.topP().getAsDouble()); + } + if (chatModelConfig.maxTokens().isPresent()) { + builder.maxTokens(chatModelConfig.maxTokens().getAsInt()); + } + if (chatModelConfig.safePrompt().isPresent()) { + builder.safePrompt(chatModelConfig.safePrompt().get()); + } + if (chatModelConfig.randomSeed().isPresent()) { + builder.randomSeed(chatModelConfig.randomSeed().getAsInt()); + } + + return new Supplier<>() { + @Override + public StreamingChatLanguageModel get() { + return builder.build(); + } + }; + } else { + return new Supplier<>() { + @Override + public StreamingChatLanguageModel get() { + return new DisabledStreamingChatLanguageModel(); + } + }; + } + } + + public Supplier embeddingModel(LangChain4jMistralAiConfig runtimeConfig, String modelName) { + LangChain4jMistralAiConfig.MistralAiConfig mistralAiConfig = correspondingMistralAiConfig(runtimeConfig, + modelName); + + if (mistralAiConfig.enableIntegration()) { + String apiKey = mistralAiConfig.apiKey(); + EmbeddingModelConfig embeddingModelConfig = mistralAiConfig.embeddingModel(); + + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } + + var builder = MistralAiEmbeddingModel.builder() + .baseUrl(mistralAiConfig.baseUrl()) + .apiKey(apiKey) + .modelName(embeddingModelConfig.modelName()) + .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), mistralAiConfig.logRequests())) + .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), mistralAiConfig.logResponses())) + .timeout(mistralAiConfig.timeout()); + + return new Supplier<>() { + @Override + public EmbeddingModel get() { + return builder.build(); + } + }; + } else { + return new Supplier<>() { + @Override + public EmbeddingModel get() { + return new DisabledEmbeddingModel(); + } + }; + } + } + + private LangChain4jMistralAiConfig.MistralAiConfig correspondingMistralAiConfig( + LangChain4jMistralAiConfig runtimeConfig, String modelName) { + LangChain4jMistralAiConfig.MistralAiConfig huggingFaceConfig; + if (NamedModelUtil.isDefault(modelName)) { + huggingFaceConfig = runtimeConfig.defaultConfig(); + } else { + huggingFaceConfig = runtimeConfig.namedConfig().get(modelName); + } + return huggingFaceConfig; + } + + private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) { + return createConfigProblems("api-key", modelName); + } + + private ConfigValidationException.Problem[] createConfigProblems(String key, String 
modelName) { + return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; + } + + private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { + return new ConfigValidationException.Problem(String.format( + "SRCFG00014: The config property quarkus.langchain4j.mistralai%s%s is required but it could not be found in any config source", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), key)); + } +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/ChatModelConfig.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/ChatModelConfig.java new file mode 100644 index 000000000..1446c045c --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/ChatModelConfig.java @@ -0,0 +1,67 @@ +package io.quarkiverse.langchain4j.mistralai.runtime.config; + +import java.util.Optional; +import java.util.OptionalDouble; +import java.util.OptionalInt; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface ChatModelConfig { + + /** + * Model name to use + */ + @WithDefault("mistral-tiny") + String modelName(); + + /** + * What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while + * lower values like 0.2 will make it more focused and deterministic. + *
<p>
+ * It is generally recommended to set this or the {@code top-p} property but not both. + */ + @ConfigDocDefault("0.7") + OptionalDouble temperature(); + + /** + * The maximum number of tokens to generate in the completion. + *
<p>
+ * The token count of your prompt plus {@code max_tokens} cannot exceed the model's context length. + */ + OptionalInt maxTokens(); + + /** + * Double (0.0-1.0). Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. + * So 0.1 means only the tokens comprising the top 10% probability mass are considered. + *
<p>
+ * It is generally recommended to set this or the {@code temperature} property but not both. + */ + @ConfigDocDefault("1.0") + OptionalDouble topP(); + + /** + * Whether to inject a safety prompt before all conversations + */ + Optional safePrompt(); + + /** + * The seed to use for random sampling. If set, different calls will generate deterministic results. + */ + OptionalInt randomSeed(); + + /** + * Whether chat model requests should be logged + */ + @ConfigDocDefault("false") + Optional logRequests(); + + /** + * Whether chat model responses should be logged + */ + @ConfigDocDefault("false") + Optional logResponses(); + +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/EmbeddingModelConfig.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/EmbeddingModelConfig.java new file mode 100644 index 000000000..c4fb6e826 --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/EmbeddingModelConfig.java @@ -0,0 +1,30 @@ +package io.quarkiverse.langchain4j.mistralai.runtime.config; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface EmbeddingModelConfig { + + /** + * Model name to use + */ + @WithDefault("mistral-embed") + String modelName(); + + /** + * Whether embedding model requests should be logged + */ + @ConfigDocDefault("false") + Optional logRequests(); + + /** + * Whether embedding model responses should be logged + */ + @ConfigDocDefault("false") + Optional logResponses(); + +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java new file mode 100644 index 000000000..2add62bf8 --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java @@ -0,0 +1,88 @@ +package io.quarkiverse.langchain4j.mistralai.runtime.config; + +import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; + +import java.time.Duration; +import java.util.Map; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; + +@ConfigRoot(phase = RUN_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.mistralai") +public interface LangChain4jMistralAiConfig { + + /** + * Default model config. + */ + @WithParentName + MistralAiConfig defaultConfig(); + + /** + * Named model config. 
+ */ + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); + + @ConfigGroup + interface MistralAiConfig { + /** + * Base URL of Mistral API + */ + @WithDefault("https://api.mistral.ai/v1/") + String baseUrl(); + + /** + * Mistral API key + */ + @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it + String apiKey(); + + /** + * Timeout for Mistral calls + */ + @WithDefault("10s") + Duration timeout(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + + /** + * Embedding model related settings + */ + EmbeddingModelConfig embeddingModel(); + + /** + * Whether the Mistral client should log requests + */ + @ConfigDocDefault("false") + Optional logRequests(); + + /** + * Whether the Mistral client should log responses + */ + @ConfigDocDefault("false") + Optional logResponses(); + + /** + * Whether or not to enable the integration. Defaults to {@code true}, which means requests are made to the Mistral AI + * provider. + * Set to {@code false} to disable all requests. + */ + @WithDefault("true") + Boolean enableIntegration(); + } +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/MistralAiRoleMixin.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/MistralAiRoleMixin.java new file mode 100644 index 000000000..36926e72e --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/MistralAiRoleMixin.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.mistralai.runtime.jackson; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + +import dev.langchain4j.model.mistralai.MistralAiRole; +import io.quarkus.jackson.JacksonMixin; + +@JacksonMixin(MistralAiRole.class) +@JsonSerialize(using = RoleSerializer.class) +@JsonDeserialize(using = RoleDeserializer.class) +public abstract class MistralAiRoleMixin { + +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleDeserializer.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleDeserializer.java new file mode 100644 index 000000000..c3ba914f1 --- /dev/null +++ b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleDeserializer.java @@ -0,0 +1,24 @@ +package io.quarkiverse.langchain4j.mistralai.runtime.jackson; + +import java.io.IOException; +import java.util.Locale; + +import com.fasterxml.jackson.core.JacksonException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; + +import dev.langchain4j.model.mistralai.MistralAiRole; + +public class RoleDeserializer extends StdDeserializer { + public RoleDeserializer() { + super(MistralAiRole.class); + } + + @Override + public MistralAiRole deserialize(JsonParser jp, DeserializationContext deserializationContext) + throws IOException, JacksonException { + return MistralAiRole.valueOf(jp.getValueAsString().toUpperCase(Locale.ROOT)); + } + +} diff --git a/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleSerializer.java b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleSerializer.java new file mode 100644 index 000000000..bb331f8a0 --- /dev/null +++ 
b/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/jackson/RoleSerializer.java @@ -0,0 +1,21 @@ +package io.quarkiverse.langchain4j.mistralai.runtime.jackson; + +import java.io.IOException; +import java.util.Locale; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.ser.std.StdSerializer; + +import dev.langchain4j.model.mistralai.MistralAiRole; + +public class RoleSerializer extends StdSerializer { + public RoleSerializer() { + super(MistralAiRole.class); + } + + @Override + public void serialize(MistralAiRole value, JsonGenerator gen, SerializerProvider provider) throws IOException { + gen.writeString(value.toString().toLowerCase(Locale.ROOT)); + } +} diff --git a/mistral/runtime/src/main/resources/META-INF/beans.xml b/mistral/runtime/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/mistral/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/mistral/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 000000000..3c473d83a --- /dev/null +++ b/mistral/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,17 @@ +name: LangChain4j Mistral AI +artifact: ${project.groupId}:${project.artifactId}:${project.version} +description: Provides integration of Quarkus LangChain4j with Mistral AI +metadata: + keywords: + - ai + - langchain4j + - mistral + # guide: https://quarkiverse.github.io/quarkiverse-docs/langchain4j/dev/ # To create and publish this guide, see https://github.com/quarkiverse/quarkiverse/wiki#documenting-your-extension + categories: + - "miscellaneous" + status: "experimental" + codestart: + name: langchain4j-mistral + languages: + - "java" + artifact: "io.quarkiverse.langchain4j:quarkus-langchain4j-mistral-ai:codestarts:jar:${project.version}" diff --git a/mistral/runtime/src/main/resources/META-INF/services/dev.langchain4j.model.mistralai.MistralAiClientBuilderFactory b/mistral/runtime/src/main/resources/META-INF/services/dev.langchain4j.model.mistralai.MistralAiClientBuilderFactory new file mode 100644 index 000000000..a69c62f5e --- /dev/null +++ b/mistral/runtime/src/main/resources/META-INF/services/dev.langchain4j.model.mistralai.MistralAiClientBuilderFactory @@ -0,0 +1 @@ +io.quarkiverse.langchain4j.mistralai.QuarkusMistralAiClient$QuarkusMistralAiClientBuilderFactory diff --git a/pom.xml b/pom.xml index fe09cc167..268282e54 100644 --- a/pom.xml +++ b/pom.xml @@ -21,6 +21,7 @@ hugging-face infinispan milvus + mistral ollama openai/azure-openai openai/openai-common