Skip to content

Commit

Permalink
Merge pull request #708 from sberyozkin/oidc_model_auth_provider
Browse files Browse the repository at this point in the history
Update Vertex AI Gemini provider to use ModelAuthProvider
  • Loading branch information
jmartisk committed Jul 8, 2024
2 parents aab1101 + 75732f0 commit d1b5200
Show file tree
Hide file tree
Showing 21 changed files with 768 additions and 20 deletions.
44 changes: 44 additions & 0 deletions model-auth-providers/oidc-model-auth-provider/deployment/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-deployment</artifactId>
<name>Quarkus LangChain4j - OpenId Connect (OIDC) ModelAuthProvider - Deployment</name>
<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.quarkiverse.langchain4j.oidc.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;

@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.oidc-model-auth-provider")
public interface OidcModelAuthProviderBuildConfig {
/**
* Whether the OIDC ModelAuthProvider should be enabled
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.quarkiverse.langchain4j.oidc.deployment;

import java.util.function.BooleanSupplier;

import io.quarkiverse.langchain4j.oidc.runtime.OidcModelAuthProvider;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.BuildSteps;
import io.quarkus.deployment.builditem.FeatureBuildItem;

@BuildSteps(onlyIf = OidcModelAuthProviderProcessor.IsEnabled.class)
public class OidcModelAuthProviderProcessor {
private static final String FEATURE = "langchain4j-oidc-model-auth-provider";

@BuildStep
FeatureBuildItem feature() {
return new FeatureBuildItem(FEATURE);
}

@BuildStep
public void additionalBeans(BuildProducer<AdditionalBeanBuildItem> additionalBeans) {
AdditionalBeanBuildItem.Builder builder = AdditionalBeanBuildItem.builder().setUnremovable();
builder.addBeanClass(OidcModelAuthProvider.class);
additionalBeans.produce(builder.build());
}

public static class IsEnabled implements BooleanSupplier {
OidcModelAuthProviderBuildConfig config;

public boolean getAsBoolean() {
return config.enabled().orElse(true);
}
}
}
20 changes: 20 additions & 0 deletions model-auth-providers/oidc-model-auth-provider/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parent</artifactId>
<version>999-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-parent</artifactId>
<name>Quarkus LangChain4j - OpenId Connect (OIDC) ModelAuthProvider - Parent</name>
<packaging>pom</packaging>

<modules>
<module>deployment</module>
<module>runtime</module>
</modules>


</project>
65 changes: 65 additions & 0 deletions model-auth-providers/oidc-model-auth-provider/runtime/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider</artifactId>
<name>Quarkus LangChain4j - OpenId Connect (OIDC) ModelAuthProvider - Runtime</name>
<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-arc</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus.security</groupId>
<artifactId>quarkus-security</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-maven-plugin</artifactId>
<version>${quarkus.version}</version>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>extension-descriptor</goal>
</goals>
<configuration>
<deployment>${project.groupId}:${project.artifactId}-deployment:${project.version}</deployment>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkiverse.langchain4j.oidc.runtime;

import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;

import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.security.credential.TokenCredential;

public class OidcModelAuthProvider implements ModelAuthProvider {
@Inject
Instance<TokenCredential> tokenCredential;

@Override
public String getAuthorization(Input input) {
return tokenCredential.isResolvable() ? "Bearer " + tokenCredential.get().getToken() : null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: LangChain4j OpenId Connect (OIDC) ModelAuthProvider
artifact: ${project.groupId}:${project.artifactId}:${project.version}
description: Provides ModelAuthProvider which uses OIDC bearer or authorization code flow access tokens
metadata:
keywords:
- ai
- langchain4j
- oidc
- security
guide: "https://docs.quarkiverse.io/quarkus-langchain4j/dev/index.html"
categories:
- "security"
status: "experimental"

Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import com.github.tomakehurst.wiremock.verification.LoggedRequest;

import dev.langchain4j.model.chat.ChatLanguageModel;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiChatLanguageModel;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertxAiGeminiRestApi;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

Expand Down Expand Up @@ -111,11 +111,13 @@ void test() {
}

@Singleton
public static class DummyAuthProvider implements VertxAiGeminiRestApi.AuthProvider {
public static class DummyAuthProvider implements ModelAuthProvider {

@Override
public String getBearerToken() {
return API_KEY;
public String getAuthorization(Input input) {
return "Bearer " + API_KEY;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.concurrent.ExecutorService;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;
import jakarta.ws.rs.BeanParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.MultivaluedMap;

import org.eclipse.microprofile.context.ManagedExecutor;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.RestPath;
Expand All @@ -25,7 +29,9 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.auth.oauth2.GoogleCredentials;

import io.quarkus.arc.DefaultBean;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider.Input;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.config.ChatModelConfig;
import io.quarkus.rest.client.reactive.jackson.ClientObjectMapper;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
Expand Down Expand Up @@ -102,17 +108,10 @@ public ApiMetadata build() {
}
}

interface AuthProvider {

String getBearerToken();
}

@ApplicationScoped
@DefaultBean
class ApplicationDefaultAuthProvider implements AuthProvider {
class ApplicationDefaultAuthProvider implements ModelAuthProvider {

@Override
public String getBearerToken() {
public String getAuthorization(Input input) {
try {
var credentials = GoogleCredentials.getApplicationDefault();
credentials.refreshIfExpired();
Expand All @@ -126,11 +125,17 @@ public String getBearerToken() {
class TokenFilter implements ResteasyReactiveClientRequestFilter {

private final ExecutorService executorService;
private final AuthProvider authProvider;
private final ModelAuthProvider defaultAuthorizer;
private final ModelAuthProvider authorizer;

@Inject
Instance<ChatModelConfig> model;

public TokenFilter(ExecutorService executorService, AuthProvider authProvider) {
public TokenFilter(ManagedExecutor executorService) {
this.executorService = executorService;
this.authProvider = authProvider;
this.defaultAuthorizer = new ApplicationDefaultAuthProvider();
this.authorizer = ModelAuthProvider.resolve(
model != null && model.isResolvable() ? model.get().modelId() : null).orElse(null);
}

@Override
Expand All @@ -140,14 +145,25 @@ public void filter(ResteasyReactiveClientRequestContext context) {
@Override
public void run() {
try {
context.getHeaders().add("Authorization", "Bearer " + authProvider.getBearerToken());
final Input authInput = new AuthInputImpl(context.getMethod(), context.getUri(), context.getHeaders());
String authorization = authorizer != null ? authorizer.getAuthorization(authInput) : null;
if (authorization == null) {
authorization = defaultAuthorizer.getAuthorization(authInput);
}
context.getHeaders().add("Authorization", authorization);
context.resume();
} catch (Exception e) {
context.resume(e);
}
}
});
}

private record AuthInputImpl(
String method,
URI uri,
MultivaluedMap<String, Object> headers) implements ModelAuthProvider.Input {
}
}

class VertxAiClientLogger implements ClientLogger {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ interface VertexAiGeminiConfig {
Optional<String> baseUrl();

/**
* Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Anthropic
* Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Vertex AI Gemini
* provider.
* Set to {@code false} to disable all requests.
*/
Expand Down
3 changes: 3 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
<module>model-providers/vertex-ai-gemini</module>
<module>model-providers/watsonx</module>

<module>model-auth-providers/oidc-model-auth-provider</module>

<module>quarkus-integrations/websockets-next</module>

<module>rag/easy-rag</module>
Expand Down Expand Up @@ -199,6 +201,7 @@
<module>samples/review-triage</module>
<module>samples/fraud-detection</module>
<module>samples/secure-fraud-detection</module>
<module>samples/secure-vertex-ai-gemini-poem</module>
<module>samples/chatbot</module>
<module>samples/chatbot-easy-rag</module>
<module>samples/sql-chatbot</module>
Expand Down
Loading

0 comments on commit d1b5200

Please sign in to comment.