Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow dynamic authorization headers for OpenAI + Azure OpenAI #646

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions docs/modules/ROOT/pages/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,68 @@ We used `null` as the `apiVersion` parameter in the call to `streamingCompletion
This parameter is only required when using Azure OpenAI.
====


=== Dynamic Authorization Headers

There are cases where one may need to provide dynamic authorization headers, to be passed in OpenAI endpoints (in OpenAI or Azure OpenAI)

There are two ways to achieve this:

==== Using a ContainerRequestFilter annotated with `@Provider`.
Since OpenAI client relies in the Quarkus Rest Client, you can have apply a filter that will be called in all OpenAI requests and set the headers accordingly.

[source,java]
----
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.ws.rs.ext.Provider;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter;

@Provider
@ApplicationScoped
public class RequestFilter implements ResteasyReactiveClientRequestFilter {

@Inject
MyAuthorizationService myAuthorizationService;

@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
/*
* All requests will be filtered here, therefore make sure that you make
* the necessary checks to avoid putting the Authorization header in
* requests that do not need it.
*/
requestContext.getHeaders().putSingle("Authorization", ...);
}
}
----

==== Using `AuthProvider`
One can implement the `AuthProvider` interface and provide the implementation of the `getAuthorization` method.

This is useful when you need to provide different authorization headers for different OpenAI models. The `@Named` annotation can be used to specify the model name in this scenario.

[source,java]
----
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

@ApplicationScoped
@ModelName("my-model-name") //you can omit this if you have only one model or if you want to use the default model
public class TestClass implements OpenAiRestApi.AuthProvider {
@Inject MyTokenProviderService tokenProviderService;

@Override
public String getAuthorization(Input input) {
/*
the `input` will contain some information about the request
about to be passed to the remote model endpoints
*/
return "Bearer " + tokenProviderService.getToken();
}
}
----

Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public AzureOpenAiChatModel(String endpoint,
Proxy proxy,
String responseFormat,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));

Expand All @@ -86,6 +87,7 @@ public AzureOpenAiChatModel(String endpoint,
.userAgent(Consts.DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();

this.temperature = getOrDefault(temperature, 0.7);
Expand Down Expand Up @@ -172,6 +174,7 @@ public static class Builder {
private String responseFormat;
private Boolean logRequests;
private Boolean logResponses;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand All @@ -196,6 +199,11 @@ public Builder apiVersion(String apiVersion) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

/**
* Sets the Azure OpenAI API key. This is a mandatory parameter.
*
Expand Down Expand Up @@ -288,7 +296,8 @@ public AzureOpenAiChatModel build() {
proxy,
responseFormat,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public AzureOpenAiEmbeddingModel(String endpoint,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));
if (maxRetries < 1) {
Expand All @@ -78,6 +79,7 @@ public AzureOpenAiEmbeddingModel(String endpoint,
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = tokenizer;
Expand Down Expand Up @@ -149,6 +151,7 @@ public static class Builder {
private Boolean logRequests;
private Boolean logResponses;
private String adToken;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand Down Expand Up @@ -219,6 +222,11 @@ public Builder logResponses(Boolean logResponses) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiEmbeddingModel build() {
return new AzureOpenAiEmbeddingModel(endpoint,
apiVersion,
Expand All @@ -229,7 +237,8 @@ public AzureOpenAiEmbeddingModel build() {
maxRetries,
proxy,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, Str
String size,
String quality, String style, Optional<String> user, String responseFormat, Duration timeout,
Integer maxRetries, Boolean logRequests, Boolean logResponses,
Optional<Path> persistDirectory) {
Optional<Path> persistDirectory, String configName) {
this.modelName = modelName;
this.size = size;
this.quality = quality;
Expand All @@ -69,6 +69,7 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, Str
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
}

Expand Down Expand Up @@ -159,6 +160,7 @@ public static class Builder {
private Boolean logRequests;
private Boolean logResponses;
private Optional<Path> persistDirectory;
private String configName;

public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
Expand Down Expand Up @@ -235,10 +237,15 @@ public Builder persistDirectory(Optional<Path> persistDirectory) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiImageModel build() {
return new AzureOpenAiImageModel(endpoint, apiKey, adToken, apiVersion, modelName, size, quality, style, user,
responseFormat, timeout, maxRetries, logRequests, logResponses,
persistDirectory);
persistDirectory, configName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public AzureOpenAiStreamingChatModel(String endpoint,
Proxy proxy,
String responseFormat,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));

Expand All @@ -91,6 +92,7 @@ public AzureOpenAiStreamingChatModel(String endpoint,
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
Expand Down Expand Up @@ -203,6 +205,7 @@ public static class Builder {
private String responseFormat;
private Boolean logRequests;
private Boolean logResponses;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand Down Expand Up @@ -298,6 +301,11 @@ public Builder logResponses(Boolean logResponses) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiStreamingChatModel build() {
return new AzureOpenAiStreamingChatModel(endpoint,
apiVersion,
Expand All @@ -313,7 +321,8 @@ public AzureOpenAiStreamingChatModel build() {
proxy,
responseFormat,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jAzureOpenAiConfig runtim

var builder = AzureOpenAiChatModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
Expand Down Expand Up @@ -97,6 +98,7 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jAzureO
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(chatModelConfig.logRequests().orElse(false))
Expand Down Expand Up @@ -141,6 +143,7 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runt
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
Expand Down Expand Up @@ -182,6 +185,7 @@ public Supplier<ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfi
.logRequests(imageModelConfig.logRequests().orElse(false))
.logResponses(imageModelConfig.logResponses().orElse(false))
.modelName(imageModelConfig.modelName())
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.size(imageModelConfig.size())
.quality(imageModelConfig.quality())
.style(imageModelConfig.style())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.knuddels.jtokkit.Encodings;

import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.builditem.IndexDependencyBuildItem;
Expand All @@ -27,6 +29,11 @@ void indexDependencies(BuildProducer<IndexDependencyBuildItem> producer) {
producer.produce(new IndexDependencyBuildItem("dev.ai4j", "openai4j"));
}

@BuildStep
UnremovableBeanBuildItem unremovableBeans() {
return UnremovableBeanBuildItem.beanTypes(OpenAiRestApi.AuthProvider.class);
}

@BuildStep
void nativeImageSupport(BuildProducer<NativeImageResourceBuildItem> resourcesProducer) {
registerJtokkitResources(resourcesProducer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.OutputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Type;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.function.Predicate;
Expand Down Expand Up @@ -38,6 +39,8 @@
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.client.api.ClientLogger;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter;
import org.jboss.resteasy.reactive.common.providers.serialisers.AbstractJsonMessageBodyReader;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -179,6 +182,38 @@ public boolean test(SseEvent<String> event) {
}
}

interface AuthProvider {
String getAuthorization(Input input);

interface Input {
String method();

URI uri();

MultivaluedMap<String, Object> headers();
}
}

class OpenAIRestAPIFilter implements ResteasyReactiveClientRequestFilter {
AuthProvider authorizer;

public OpenAIRestAPIFilter(AuthProvider authorizer) {
this.authorizer = authorizer;
}

@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
requestContext.getHeaders().putSingle("Authorization", authorizer.getAuthorization(
new AuthInputImpl(requestContext.getMethod(), requestContext.getUri(), requestContext.getHeaders())));
}

private record AuthInputImpl(
String method,
URI uri,
MultivaluedMap<String, Object> headers) implements AuthProvider.Input {
}
}

@Priority(Priorities.USER + 100) // this priority ensures that our Writer has priority over the standard Jackson one
class OpenAiRestApiJacksonWriter implements MessageBodyWriter<Object> {

Expand Down
Loading
Loading