Skip to content

Commit

Permalink
Allow instantiating beans for dynamic authorization in OpenAI
Browse files Browse the repository at this point in the history
fixing PR comments

added documentation and renamed Interface to AuthProvider
  • Loading branch information
csotiriou committed Jun 2, 2024
1 parent 11c3638 commit de43ea5
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 10 deletions.
60 changes: 60 additions & 0 deletions docs/modules/ROOT/pages/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,63 @@ We used `null` as the `apiVersion` parameter in the call to `streamingCompletion
This parameter is only required when using Azure OpenAI.
====


=== Dynamic Authorization Headers

There are cases where one may need to provide dynamic authorization headers, to be passed in OpenAI endpoints (in OpenAI or Azure OpenAI)

There are two ways to achieve this:

==== Using a ContainerRequestFilter annotated with `@Provider`.
Since OpenAI client relies in the Quarkus Rest Client, you can have apply a filter that will be called in all OpenAI requests and set the headers accordingly.

[source,java]
----
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.ws.rs.ext.Provider;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter;
@Provider
@ApplicationScoped
public class RequestFilter implements ResteasyReactiveClientRequestFilter {
@Inject
MyAuthorizationService myAuthorizationService;
@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
/*
* All requests will be filtered here, therefore make sure that you make
* the necessary checks to avoid putting the Authorization header in
* requests that do not need it.
*/
requestContext.getHeaders().putSingle("Authorization", ...);
}
}
----

==== Using `AuthProvider`
One can implement the `AuthProvider` interface and provide the implementation of the `getAuthorization` method.

This is useful when you need to provide different authorization headers for different OpenAI models. The `@Named` annotation can be used to specify the model name in this scenario.

[source,java]
----
import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Named;
@ApplicationScoped
@Named("my-model-name") //you can omit this if you have only one model or if you want to use the default model
public class TestClass implements OpenAiRestApi.AuthProvider {
@Inject MyTokenProviderService tokenProviderService;
@Override
public String getAuthorization() {
return "Bearer " + tokenProviderService.getToken();
}
}
----

Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public AzureOpenAiChatModel(String endpoint,
Proxy proxy,
String responseFormat,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));

Expand All @@ -86,6 +87,7 @@ public AzureOpenAiChatModel(String endpoint,
.userAgent(Consts.DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();

this.temperature = getOrDefault(temperature, 0.7);
Expand Down Expand Up @@ -172,6 +174,7 @@ public static class Builder {
private String responseFormat;
private Boolean logRequests;
private Boolean logResponses;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand All @@ -196,6 +199,11 @@ public Builder apiVersion(String apiVersion) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

/**
* Sets the Azure OpenAI API key. This is a mandatory parameter.
*
Expand Down Expand Up @@ -288,7 +296,8 @@ public AzureOpenAiChatModel build() {
proxy,
responseFormat,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public AzureOpenAiEmbeddingModel(String endpoint,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));
if (maxRetries < 1) {
Expand All @@ -78,6 +79,7 @@ public AzureOpenAiEmbeddingModel(String endpoint,
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = tokenizer;
Expand Down Expand Up @@ -149,6 +151,7 @@ public static class Builder {
private Boolean logRequests;
private Boolean logResponses;
private String adToken;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand Down Expand Up @@ -219,6 +222,11 @@ public Builder logResponses(Boolean logResponses) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiEmbeddingModel build() {
return new AzureOpenAiEmbeddingModel(endpoint,
apiVersion,
Expand All @@ -229,7 +237,8 @@ public AzureOpenAiEmbeddingModel build() {
maxRetries,
proxy,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, Str
String size,
String quality, String style, Optional<String> user, String responseFormat, Duration timeout,
Integer maxRetries, Boolean logRequests, Boolean logResponses,
Optional<Path> persistDirectory) {
Optional<Path> persistDirectory, String configName) {
this.modelName = modelName;
this.size = size;
this.quality = quality;
Expand All @@ -69,6 +69,7 @@ public AzureOpenAiImageModel(String endpoint, String apiKey, String adToken, Str
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
}

Expand Down Expand Up @@ -159,6 +160,7 @@ public static class Builder {
private Boolean logRequests;
private Boolean logResponses;
private Optional<Path> persistDirectory;
private String configName;

public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
Expand Down Expand Up @@ -235,10 +237,15 @@ public Builder persistDirectory(Optional<Path> persistDirectory) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiImageModel build() {
return new AzureOpenAiImageModel(endpoint, apiKey, adToken, apiVersion, modelName, size, quality, style, user,
responseFormat, timeout, maxRetries, logRequests, logResponses,
persistDirectory);
persistDirectory, configName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public AzureOpenAiStreamingChatModel(String endpoint,
Proxy proxy,
String responseFormat,
Boolean logRequests,
Boolean logResponses) {
Boolean logResponses,
String configName) {

timeout = getOrDefault(timeout, ofSeconds(60));

Expand All @@ -91,6 +92,7 @@ public AzureOpenAiStreamingChatModel(String endpoint,
.userAgent(DEFAULT_USER_AGENT)
.azureAdToken(adToken)
.azureApiKey(apiKey)
.configName(configName)
.build();
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
Expand Down Expand Up @@ -203,6 +205,7 @@ public static class Builder {
private String responseFormat;
private Boolean logRequests;
private Boolean logResponses;
private String configName;

/**
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
Expand Down Expand Up @@ -298,6 +301,11 @@ public Builder logResponses(Boolean logResponses) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

public AzureOpenAiStreamingChatModel build() {
return new AzureOpenAiStreamingChatModel(endpoint,
apiVersion,
Expand All @@ -313,7 +321,8 @@ public AzureOpenAiStreamingChatModel build() {
proxy,
responseFormat,
logRequests,
logResponses);
logResponses,
configName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jAzureOpenAiConfig runtim

var builder = AzureOpenAiChatModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
Expand Down Expand Up @@ -97,6 +98,7 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jAzureO
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(chatModelConfig.logRequests().orElse(false))
Expand Down Expand Up @@ -141,6 +143,7 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runt
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
Expand Down Expand Up @@ -182,6 +185,7 @@ public Supplier<ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfi
.logRequests(imageModelConfig.logRequests().orElse(false))
.logResponses(imageModelConfig.logResponses().orElse(false))
.modelName(imageModelConfig.modelName())
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.size(imageModelConfig.size())
.quality(imageModelConfig.quality())
.style(imageModelConfig.style())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.client.api.ClientLogger;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter;
import org.jboss.resteasy.reactive.common.providers.serialisers.AbstractJsonMessageBodyReader;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -179,6 +181,23 @@ public boolean test(SseEvent<String> event) {
}
}

interface AuthProvider {
String getAuthorization();
}

class OpenAIRestAPIFilter implements ResteasyReactiveClientRequestFilter {
AuthProvider authorizer;

public OpenAIRestAPIFilter(AuthProvider authorizer) {
this.authorizer = authorizer;
}

@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
requestContext.getHeaders().putSingle("Authorization", authorizer.getAuthorization());
}
}

@Priority(Priorities.USER + 100) // this priority ensures that our Writer has priority over the standard Jackson one
class OpenAiRestApiJacksonWriter implements MessageBodyWriter<Object> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.inject.spi.CDI;
import jakarta.ws.rs.client.ClientRequestContext;
import jakarta.ws.rs.client.ClientRequestFilter;

Expand Down Expand Up @@ -113,6 +116,21 @@ public void filter(ClientRequestContext requestContext) {
});
}

OpenAiRestApi.AuthProvider authorizer = null;
Instance<OpenAiRestApi.AuthProvider> beanInstance = builder.configName == null
? CDI.current().select(OpenAiRestApi.AuthProvider.class)
: CDI.current().select(OpenAiRestApi.AuthProvider.class, NamedLiteral.of(builder.configName));

//get the first one without causing a bean resolution exception
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}

if (authorizer != null) {
var filterProvider = new OpenAiRestApi.OpenAIRestAPIFilter(authorizer);
restApiBuilder.register(filterProvider);
}
return restApiBuilder.build(OpenAiRestApi.class);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
Expand Down Expand Up @@ -515,6 +533,7 @@ public static class Builder extends OpenAiClient.Builder<QuarkusOpenAiClient, Bu

private String userAgent;
private String azureAdToken;
private String configName;

public Builder userAgent(String userAgent) {
this.userAgent = userAgent;
Expand All @@ -526,6 +545,11 @@ public Builder azureAdToken(String azureAdToken) {
return this;
}

public Builder configName(String configName) {
this.configName = configName;
return this;
}

@Override
public Builder openAiApiKey(String openAiApiKey) {
this.openAiApiKey = openAiApiKey;
Expand Down Expand Up @@ -564,14 +588,15 @@ public boolean equals(Object o) {
builder.writeTimeout)
&& Objects.equals(proxy, builder.proxy)
&& Objects.equals(azureAdToken, builder.azureAdToken)
&& Objects.equals(userAgent, builder.userAgent);
&& Objects.equals(userAgent, builder.userAgent)
&& Objects.equals(configName, builder.configName);
}

@Override
public int hashCode() {
return Objects.hash(baseUrl, apiVersion, openAiApiKey, azureApiKey, organizationId, callTimeout, connectTimeout,
readTimeout,
writeTimeout, proxy, logRequests, logResponses, logStreamingResponses, userAgent, azureAdToken);
writeTimeout, proxy, logRequests, logResponses, logStreamingResponses, userAgent, azureAdToken, configName);
}
}

Expand Down

0 comments on commit de43ea5

Please sign in to comment.