Skip to content

Commit

Permalink
Allow instantiating beans for dynamic authorization in OpenAI
Browse files Browse the repository at this point in the history
fixing PR comments

added documentation and renamed Interface to AuthProvider

changed signature of of the "getAuthorization" method

changed signature of the getAuthorization method

made the Authorizer beans unremovable

updated documentation

used ModelName instead of Named annotation

made Input an interface and separated instantiation in a private class
  • Loading branch information
csotiriou committed Jun 3, 2024
1 parent 11c3638 commit 78cd849
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 10 deletions.
65 changes: 65 additions & 0 deletions docs/modules/ROOT/pages/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,68 @@ We used `null` as the `apiVersion` parameter in the call to `streamingCompletion
This parameter is only required when using Azure OpenAI.
====


=== Dynamic Authorization Headers

There are cases where one may need to provide dynamic authorization headers, to be passed in OpenAI endpoints (in OpenAI or Azure OpenAI)

There are two ways to achieve this:

==== Using a ContainerRequestFilter annotated with `@Provider`.
Since OpenAI client relies in the Quarkus Rest Client, you can have apply a filter that will be called in all OpenAI requests and set the headers accordingly.

[source,java]
----
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.ws.rs.ext.Provider;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter;
@Provider
@ApplicationScoped
public class RequestFilter implements ResteasyReactiveClientRequestFilter {
@Inject
MyAuthorizationService myAuthorizationService;
@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
/*
* All requests will be filtered here, therefore make sure that you make
* the necessary checks to avoid putting the Authorization header in
* requests that do not need it.
*/
requestContext.getHeaders().putSingle("Authorization", ...);
}
}
----

==== Using `AuthProvider`
One can implement the `AuthProvider` interface and provide the implementation of the `getAuthorization` method.

This is useful when you need to provide different authorization headers for different OpenAI models. The `@Named` annotation can be used to specify the model name in this scenario.

[source,java]
----
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
@ApplicationScoped
@ModelName("my-model-name") //you can omit this if you have only one model or if you want to use the default model
public class TestClass implements OpenAiRestApi.AuthProvider {
@Inject MyTokenProviderService tokenProviderService;
@Override
public String getAuthorization(Input input) {
/*
the `input` will contain some information about the request
about to be passed to the remote model endpoints
*/
return "Bearer " + tokenProviderService.getToken();
}
}
----

Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public AzureOpenAiChatModel(String endpoint,
Proxy proxy,
String responseFormat,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));

Expand All @@ -86,6 +87,7 @@ public AzureOpenAiChatModel(String endpoint,
.userAgent(Consts.DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();

this.temperature = getOrDefault(temperature, 0.7);
Expand Down Expand Up @@ -172,6 +174,7 @@ public static class Builder {
private String responseFormat;
private Boolean logRequests;
private Boolean logResponses;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand All @@ -196,6 +199,11 @@ public Builder apiVersion(String apiVersion) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

/**
* Sets the Azure OpenAI API key. This is a mandatory parameter.
*
Expand Down Expand Up @@ -288,7 +296,8 @@ public AzureOpenAiChatModel build() {
proxy,
responseFormat,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public AzureOpenAiEmbeddingModel(String endpoint,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));
if (maxRetries < 1) {
Expand All @@ -78,6 +79,7 @@ public AzureOpenAiEmbeddingModel(String endpoint,
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = tokenizer;
Expand Down Expand Up @@ -149,6 +151,7 @@ public static class Builder {
private Boolean logRequests;
private Boolean logResponses;
private String adToken;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand Down Expand Up @@ -219,6 +222,11 @@ public Builder logResponses(Boolean logResponses) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiEmbeddingModel build() {
return new AzureOpenAiEmbeddingModel(endpoint,
apiVersion,
Expand All @@ -229,7 +237,8 @@ public AzureOpenAiEmbeddingModel build() {
maxRetries,
proxy,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, Str
String size,
String quality, String style, Optional<String> user, String responseFormat, Duration timeout,
Integer maxRetries, Boolean logRequests, Boolean logResponses,
Optional<Path> persistDirectory) {
Optional<Path> persistDirectory, String configName) {
this.modelName = modelName;
this.size = size;
this.quality = quality;
Expand All @@ -69,6 +69,7 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, Str
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
}

Expand Down Expand Up @@ -159,6 +160,7 @@ public static class Builder {
private Boolean logRequests;
private Boolean logResponses;
private Optional<Path> persistDirectory;
private String configName;

public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
Expand Down Expand Up @@ -235,10 +237,15 @@ public Builder persistDirectory(Optional<Path> persistDirectory) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiImageModel build() {
return new AzureOpenAiImageModel(endpoint, apiKey, adToken, apiVersion, modelName, size, quality, style, user,
responseFormat, timeout, maxRetries, logRequests, logResponses,
persistDirectory);
persistDirectory, configName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public AzureOpenAiStreamingChatModel(String endpoint,
Proxy proxy,
String responseFormat,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));

Expand All @@ -91,6 +92,7 @@ public AzureOpenAiStreamingChatModel(String endpoint,
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
Expand Down Expand Up @@ -203,6 +205,7 @@ public static class Builder {
private String responseFormat;
private Boolean logRequests;
private Boolean logResponses;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand Down Expand Up @@ -298,6 +301,11 @@ public Builder logResponses(Boolean logResponses) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiStreamingChatModel build() {
return new AzureOpenAiStreamingChatModel(endpoint,
apiVersion,
Expand All @@ -313,7 +321,8 @@ public AzureOpenAiStreamingChatModel build() {
proxy,
responseFormat,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jAzureOpenAiConfig runtim

var builder = AzureOpenAiChatModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
Expand Down Expand Up @@ -97,6 +98,7 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jAzureO
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(chatModelConfig.logRequests().orElse(false))
Expand Down Expand Up @@ -141,6 +143,7 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runt
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
Expand Down Expand Up @@ -182,6 +185,7 @@ public Supplier<ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfi
.logRequests(imageModelConfig.logRequests().orElse(false))
.logResponses(imageModelConfig.logResponses().orElse(false))
.modelName(imageModelConfig.modelName())
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.size(imageModelConfig.size())
.quality(imageModelConfig.quality())
.style(imageModelConfig.style())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.knuddels.jtokkit.Encodings;

import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.builditem.IndexDependencyBuildItem;
Expand All @@ -27,6 +29,11 @@ void indexDependencies(BuildProducer<IndexDependencyBuildItem> producer) {
producer.produce(new IndexDependencyBuildItem("dev.ai4j", "openai4j"));
}

@BuildStep
UnremovableBeanBuildItem unremovableBeans() {
return UnremovableBeanBuildItem.beanTypes(OpenAiRestApi.AuthProvider.class);
}

@BuildStep
void nativeImageSupport(BuildProducer<NativeImageResourceBuildItem> resourcesProducer) {
registerJtokkitResources(resourcesProducer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.OutputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Type;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.function.Predicate;
Expand Down Expand Up @@ -38,6 +39,8 @@
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.client.api.ClientLogger;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter;
import org.jboss.resteasy.reactive.common.providers.serialisers.AbstractJsonMessageBodyReader;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -179,6 +182,40 @@ public boolean test(SseEvent<String> event) {
}
}

interface AuthProvider {
String getAuthorization(Input input);

interface Input {
String method();
URI uri();
MultivaluedMap<String, Object> headers();
}
}

class OpenAIRestAPIFilter implements ResteasyReactiveClientRequestFilter {
AuthProvider authorizer;

public OpenAIRestAPIFilter(AuthProvider authorizer) {
this.authorizer = authorizer;
}

@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
requestContext.getHeaders().putSingle("Authorization", authorizer.getAuthorization(
new AuthInputImpl(requestContext.getMethod(), requestContext.getUri(), requestContext.getHeaders())
));
}

private record AuthInputImpl(
String method,
URI uri,
MultivaluedMap<String, Object> headers
) implements AuthProvider.Input { }
}




@Priority(Priorities.USER + 100) // this priority ensures that our Writer has priority over the standard Jackson one
class OpenAiRestApiJacksonWriter implements MessageBodyWriter<Object> {

Expand Down
Loading

0 comments on commit 78cd849

Please sign in to comment.