Skip to content

Commit

Permalink
WIP - Introduce a Mistral AI module
Browse files Browse the repository at this point in the history
Closes: #371
  • Loading branch information
geoand committed Mar 14, 2024
1 parent 5b336f9 commit 4437d1c
Show file tree
Hide file tree
Showing 26 changed files with 1,457 additions and 3 deletions.
4 changes: 2 additions & 2 deletions docs/modules/ROOT/pages/includes/attributes.adoc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
:project-version: 0.9.1
:langchain4j-version: 0.27.1
:examples-dir: ./../examples/
:langchain4j-version: 0.28.0
:examples-dir: ./../examples/
126 changes: 126 additions & 0 deletions integration-tests/mistralai/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-integration-tests-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-integration-test-mistralai</artifactId>
<name>Quarkus LangChain4j - Integration Tests - MistralAI</name>
<properties>
<skipITs>true</skipITs>
</properties>
<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive-jackson</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-mistral-ai</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-micrometer</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.rest-assured</groupId>
<artifactId>rest-assured</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-devtools-testing</artifactId>
<scope>test</scope>
</dependency>

<!-- Make sure the deployment artifact is built before executing this module -->
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-mistral-ai-deployment</artifactId>
<version>${project.version}</version>
<type>pom</type>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>*</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-maven-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>build</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-failsafe-plugin</artifactId>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
<configuration>
<systemPropertyVariables>
<native.image.path>${project.build.directory}/${project.build.finalName}-runner</native.image.path>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<profiles>
<profile>
<id>native-image</id>
<activation>
<property>
<name>native</name>
</property>
</activation>
<build>
<plugins>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>${native.surefire.skip}</skipTests>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<skipITs>false</skipITs>
<quarkus.package.type>native</quarkus.package.type>
</properties>
</profile>
</profiles>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package org.acme.example.mistralai.chat;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.MediaType;

import org.jboss.resteasy.reactive.RestStreamElementType;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.output.Response;
import io.smallrye.mutiny.Multi;

@Path("chat")
public class ChatLanguageModelResource {

private final MistralAiChatModel chatModel;
private final MistralAiStreamingChatModel streamingChatModel;

public ChatLanguageModelResource() {
this.chatModel = MistralAiChatModel.builder().apiKey("MISTRAL_KEY").logRequests(true)
.logResponses(true).build();
this.streamingChatModel = MistralAiStreamingChatModel.builder().apiKey("MISTRAL_KEY").logRequests(true)
.logResponses(true).build();
}

@GET
@Path("blocking")
public String blocking() {
return chatModel.generate("When was the nobel prize for economics first awarded?");
}

@GET
@Path("streaming")
@RestStreamElementType(MediaType.TEXT_PLAIN)
public Multi<String> streaming() {
return Multi.createFrom().emitter(emitter -> {
streamingChatModel.generate("When was the nobel prize for economics first awarded?",
new StreamingResponseHandler<>() {
@Override
public void onNext(String token) {
emitter.emit(token);
}

@Override
public void onError(Throwable error) {
emitter.fail(error);
}

@Override
public void onComplete(Response<AiMessage> response) {
emitter.complete();
}
});
});

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.acme.example.mistralai.chat;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;

import dev.langchain4j.model.mistralai.MistralAiEmbeddingModel;

@Path("embedding")
public class EmbeddingModelResource {

private final MistralAiEmbeddingModel embeddingModel;

public EmbeddingModelResource() {
this.embeddingModel = MistralAiEmbeddingModel.builder().apiKey("MISTRAL_KEY").logRequests(true)
.logResponses(false).build();
}

@GET
public int blocking() {
return embeddingModel.embed("When was the nobel prize for economics first awarded?").content().dimension();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.acme.example.mistralai.chat;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;

import dev.langchain4j.model.mistralai.MistralAiClient;
import dev.langchain4j.model.mistralai.MistralAiModelResponse;

@Path("models")
public class ModelsResource {

private final MistralAiClient client;

public ModelsResource() {
this.client = MistralAiClient.builder().apiKey("YqbbnSRf67zHXySuiPCi2sySfmstvEBP").logRequests(true)
.logResponses(true).build();
}

@GET
public MistralAiModelResponse models() {
return client.listModels();
}
}
1 change: 1 addition & 0 deletions integration-tests/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<module>simple-ollama</module>
<module>azure-openai</module>
<module>multiple-providers</module>
<module>mistralai</module>
<module>devui</module>
<module>embed-all-minilm-l6-v2-q</module>
<module>embed-all-minilm-l6-v2</module>
Expand Down
62 changes: 62 additions & 0 deletions mistral/deployment/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-mistral-ai-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-mistral-ai-deployment</artifactId>
<name>Quarkus LangChain4j - Mistral AI - Deployment</name>
<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-mistral-ai</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-rest-client-reactive-jackson-deployment</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.wiremock</groupId>
<artifactId>wiremock-standalone</artifactId>
<version>${wiremock.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.mistralai.deployment;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;

@ConfigGroup
public interface ChatModelBuildConfig {

/**
* Whether the model should be enabled
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.mistralai.deployment;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;

@ConfigGroup
public interface EmbeddingModelBuildConfig {

/**
* Whether the model should be enabled
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.quarkiverse.langchain4j.mistralai.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;

@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.mistralai")
public interface LangChain4jMistralAiBuildConfig {

/**
* Chat model related settings
*/
ChatModelBuildConfig chatModel();

/**
* Embedding model related settings
*/
EmbeddingModelBuildConfig embeddingModel();

}
Loading

0 comments on commit 4437d1c

Please sign in to comment.