diff --git a/README.md b/README.md index 444880a16..de294fc28 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,6 @@ Official Weaviate Java Client. To start using Weaviate Java Client add the dependency to `pom.xml`: ```xml - io.weaviate client6 @@ -19,9 +18,18 @@ To start using Weaviate Java Client add the dependency to `pom.xml`: ### Uber JAR🫙 -If you're building a uber-JAR with something like `maven-assembly-plugin`, use a shaded version with classifier `all`. +If you're building an uber-JAR with something like `maven-assembly-plugin`, use a shaded version with classifier `all`. This ensures that all dynamically-loaded dependecies of `io.grpc` are resolved correctly. +```xml + + io.weaviate + client6 + 6.0.0-beta4 + all + +``` + ### SNAPSHOT releases The latest development version of `client6` is released after every merged pull request. To include it in you project set the version to `6.0.0-SNAPSHOT` and [configure your `` section accordingly](https://central.sonatype.org/publish/publish-portal-snapshots/#consuming-snapshot-releases-for-your-project). diff --git a/src/it/java/io/weaviate/containers/Container.java b/src/it/java/io/weaviate/containers/Container.java index cca0bcc3e..81b450389 100644 --- a/src/it/java/io/weaviate/containers/Container.java +++ b/src/it/java/io/weaviate/containers/Container.java @@ -18,16 +18,6 @@ public class Container { public static final Img2VecNeural IMG2VEC_NEURAL = Img2VecNeural.createDefault(); public static final MinIo MINIO = MinIo.createDefault(); - static { - startAll(); - } - - /** Start all shared Testcontainers. */ - // TODO: start lazily! - static void startAll() { - // WEAVIATE.start(); - } - /** * Stop all shared Testcontainers created in {@link #startAll}. *

diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index 745e0c626..e1bb7157f 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -6,41 +6,81 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.function.Function; import org.testcontainers.weaviate.WeaviateContainer; +import io.weaviate.client6.v1.api.Config; import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.internal.ObjectBuilder; public class Weaviate extends WeaviateContainer { - private WeaviateClient clientInstance; - - public static final String VERSION = "1.29.1"; + public static final String VERSION = "1.32.3"; public static final String DOCKER_IMAGE = "semitechnologies/weaviate"; + private volatile SharedClient clientInstance; + + public WeaviateClient getClient() { + return getClient(ObjectBuilder.identity()); + } + /** - * Get a client for the current Weaviate container. - * As we aren't running tests in parallel at the moment, - * this is not made thread-safe. + * Create a new instance of WeaviateClient connected to this container if none + * exist. Get an existing client otherwise. + * + * The lifetime of this client is tied to that of its container, which means + * that you do not need to {@code close} it manually. It will only truly close + * after the parent Testcontainer is stopped. */ - public WeaviateClient getClient() { - // FIXME: control from containers? + public WeaviateClient getClient(Function> fn) { if (!isRunning()) { start(); } - if (clientInstance == null) { - try { - clientInstance = WeaviateClient.local( + if (clientInstance != null) { + return clientInstance; + } + + synchronized (this) { + if (clientInstance == null) { + var host = getHost(); + var customFn = ObjectBuilder.partial(fn, conn -> conn - .host(getHost()) + .scheme("http") + .httpHost(host) + .grpcHost(host) .httpPort(getMappedPort(8080)) .grpcPort(getMappedPort(50051))); - } catch (Exception e) { - throw new RuntimeException("create WeaviateClient for Weaviate container", e); + var config = customFn.apply(new Config.Custom()).build(); + try { + clientInstance = new SharedClient(config, this); + } catch (Exception e) { + throw new RuntimeException("create WeaviateClient for Weaviate container", e); + } } } return clientInstance; } + /** + * Create a new instance of WeaviateClient connected to this container. + * Prefer using {@link #getClient} unless your test needs the initialization + * steps to run, e.g. OIDC authorization grant exchange. + */ + public WeaviateClient getNewClient(Function> fn) { + if (!isRunning()) { + start(); + } + var host = getHost(); + var customFn = ObjectBuilder.partial(fn, + conn -> conn + .scheme("http") + .httpHost(host) + .grpcHost(host) + .httpPort(getMappedPort(8080)) + .grpcPort(getMappedPort(50051))); + return WeaviateClient.custom(customFn); + } + public static Weaviate createDefault() { return new Builder().build(); } @@ -99,7 +139,22 @@ public Builder withOffloadS3(String accessKey, String secretKey) { } public Builder enableTelemetry(boolean enable) { - telemetry = enable; + environment.put("DISABLE_TELEMETRY", Boolean.toString(!enable)); + return this; + } + + public Builder enableAnonymousAccess(boolean enable) { + environment.put("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", Boolean.toString(enable)); + return this; + } + + public Builder withOIDC(String clientId, String issuer, String usernameClaim, String groupsClaim) { + enableAnonymousAccess(false); + environment.put("AUTHENTICATION_OIDC_ENABLED", "true"); + environment.put("AUTHENTICATION_OIDC_CLIENT_ID", clientId); + environment.put("AUTHENTICATION_OIDC_ISSUER", issuer); + environment.put("AUTHENTICATION_OIDC_USERNAME_CLAIM", usernameClaim); + environment.put("AUTHENTICATION_OIDC_GROUPS_CLAIM", groupsClaim); return this; } @@ -107,12 +162,9 @@ public Weaviate build() { var c = new Weaviate(DOCKER_IMAGE + ":" + versionTag); if (!enableModules.isEmpty()) { - c.withEnv("ENABLE_API_BASED_MODULES", "'true'"); + c.withEnv("ENABLE_API_BASED_MODULES", Boolean.TRUE.toString()); c.withEnv("ENABLE_MODULES", String.join(",", enableModules)); } - if (!telemetry) { - c.withEnv("DISABLE_TELEMETRY", "true"); - } environment.forEach((name, value) -> c.withEnv(name, value)); c.withCreateContainerCmdModifier(cmd -> cmd.withHostName("weaviate")); @@ -134,10 +186,32 @@ public void stop() { if (clientInstance == null) { return; } - try { - clientInstance.close(); - } catch (IOException e) { - // TODO: log error + synchronized (this) { + try { + clientInstance.close(this); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** SharedClient's lifetime is tied to that of it's parent container. */ + private class SharedClient extends WeaviateClient { + private final Weaviate parent; + + private SharedClient(Config config, Weaviate parent) { + super(config); + this.parent = parent; + } + + private void close(Weaviate caller) throws Exception { + if (caller == parent) { + super.close(); + } + } + + @Override + public void close() throws IOException { } } } diff --git a/src/it/java/io/weaviate/integration/OIDCSupportITest.java b/src/it/java/io/weaviate/integration/OIDCSupportITest.java new file mode 100644 index 000000000..9b9a36691 --- /dev/null +++ b/src/it/java/io/weaviate/integration/OIDCSupportITest.java @@ -0,0 +1,222 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.List; +import java.util.UUID; + +import org.assertj.core.api.Assertions; +import org.junit.Assume; +import org.junit.Test; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.v1.api.Authentication; +import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.rest.RestTransport; +import io.weaviate.containers.Weaviate; + +/** + * Test that the client can use one of the supported authentication flows to + * obtain a token from the OIDC provider and use it in a request to Weaviate. + * + * Running this test suite successfully requires talking to external services, + * so tests will be skipped if we don't have internet. See + * {@link #hasInternetConnection}. + * Additionally, {@code WCS_DUMMY_CI_PW} and {@code OKTA_CLIENT_SECRET} + * environment variables must be set. + */ +public class OIDCSupportITest extends ConcurrentTest { + private static final String WCS_DUMMY_CI_USERNAME = "oidc-test-user@weaviate.io"; + private static final String WCS_DUMMY_CI_PW = System.getenv("WCS_DUMMY_CI_PW"); + + /** + * Weaviate container that uses WCS-backed OIDC provider. + * Supports ResourceOwnerPassword and RefreshToken authentication flows. + */ + private static final Weaviate wcsContainer = Weaviate.custom() + .withOIDC("wcs", "https://auth.wcs.api.weaviate.io/auth/realms/SeMI", "email", "groups") + .build(); + + private static final String OKTA_CLIENT_ID = "0oa7e9ipdkVZRUcxo5d7"; + private static final String OKTA_CLIENT_SECRET = System.getenv("OKTA_CLIENT_SECRET"); + + /** + * Weaviate container that uses Okta's dummy OIDC provider. + * Supports ClientCredentials flow. + */ + private static final Weaviate oktaContainer = Weaviate.custom() + .withOIDC(OKTA_CLIENT_ID, "https://dev-32300990.okta.com/oauth2/aus7e9kxbwYQB0eht5d7", "cid", "groups") + .build(); + + /** + * Exchange a Resource Owner Password grant for a bearer token + * and authenticate with it. + */ + @Test + public void test_bearerToken() throws Exception { + Assume.assumeTrue("WCS_DUMMY_CI_PW is not set", WCS_DUMMY_CI_PW != null); + Assume.assumeTrue("no internet connection", hasInternetConnection()); + + var passwordAuth = Authentication.resourceOwnerPassword(WCS_DUMMY_CI_USERNAME, WCS_DUMMY_CI_PW, List.of()); + var t = SpyTokenProvider.stealToken(passwordAuth); + Assertions.assertThat(t.isValid()).as("bearer token is valid").isTrue(); + + // Expire this token immediately to force the client to fetch a new one. + var auth = SpyTokenProvider.spyOn(Authentication.bearerToken(t.accessToken(), t.refreshToken(), 0)); + pingWeaviate(wcsContainer, auth); + + var newT = auth.getToken(); + Assertions.assertThat(newT.accessToken()) + .as("expect access_token was refreshed") + .isNotEqualTo(t.accessToken()); + + // Check that the new token authenticates requests. + pingWeaviate(wcsContainer, auth); + pingWeaviateAsync(wcsContainer, auth); + } + + @Test + public void test_resourceOwnerPassword() throws Exception { + Assume.assumeTrue("WCS_DUMMY_CI_PW is not set", WCS_DUMMY_CI_PW != null); + Assume.assumeTrue("no internet connection", hasInternetConnection()); + + // Check norwal resource owner password flow works. + var password = Authentication.resourceOwnerPassword(WCS_DUMMY_CI_USERNAME, WCS_DUMMY_CI_PW, List.of()); + var auth = SpyTokenProvider.spyOn(password); + pingWeaviate(wcsContainer, auth); + pingWeaviateAsync(wcsContainer, auth); + + // Get the token obtained by the wrapped TokenProvider. + var t = auth.getToken(); + + // Now make all tokens expire immediately, forcing the client to refresh.. + // Verify the new token is different from the one before. + auth.setExpiresIn(0); + pingWeaviate(wcsContainer, auth); + + var newT = auth.getToken(); + Assertions.assertThat(newT.accessToken()) + .as("expect access_token was refreshed") + .isNotEqualTo(t.accessToken()); + } + + @Test + public void test_clientCredentials() throws Exception { + Assume.assumeTrue("OKTA_CLIENT_SECRET is not set", OKTA_CLIENT_SECRET != null); + Assume.assumeTrue("no internet connection", hasInternetConnection()); + + // Check norwal client credentials flow works. + var cc = Authentication.clientCredentials(OKTA_CLIENT_ID, OKTA_CLIENT_SECRET, List.of()); + var auth = SpyTokenProvider.spyOn(cc); + pingWeaviate(oktaContainer, auth); + pingWeaviateAsync(oktaContainer, auth); + + // Get the token obtained by the wrapped TokenProvider. + var t = auth.getToken(); + + // Now make all tokens expire immediately, forcing the client to refresh.. + // Verify the new token is different from the one before. + auth.setExpiresIn(0); + pingWeaviate(oktaContainer, auth); + + var newT = auth.getToken(); + Assertions.assertThat(newT.accessToken()) + .as("expect access_token was refreshed") + .isNotEqualTo(t.accessToken()); + } + + /** Send an HTTP and gRPC requests using a "sync" client. */ + private static void pingWeaviate(final Weaviate container, Authentication auth) throws Exception { + try (final var client = container.getNewClient(conn -> conn.authentication(auth))) { + // Make an authenticated HTTP call + Assertions.assertThat(client.isLive()).isTrue(); + + // Make an authenticated gRPC call + var nsThings = unique("Things"); + client.collections.create(nsThings); + var things = client.collections.use(nsThings); + var randomUuid = UUID.randomUUID().toString(); + Assertions.assertThat(things.data.exists(randomUuid)).isFalse(); + } + } + + /** Send an HTTP and gRPC requests using an "async" client. */ + private static void pingWeaviateAsync(final Weaviate container, Authentication auth) throws Exception { + try (final var client = container.getNewClient(conn -> conn.authentication(auth))) { + try (final var async = client.async()) { + // Make an authenticated HTTP call + Assertions.assertThat(async.isLive().join()).isTrue(); + + // Make an authenticated gRPC call + var nsThings = unique("Things"); + async.collections.create(nsThings).join(); + var things = async.collections.use(nsThings); + var randomUuid = UUID.randomUUID().toString(); + Assertions.assertThat(things.data.exists(randomUuid).join()).isFalse(); + } + } + } + + private static boolean hasInternetConnection() { + return ping("www.google.com"); + } + + private static boolean ping(String site) { + InetSocketAddress addr = new InetSocketAddress(site, 80); + try (final var sock = new Socket()) { + sock.connect(addr, 3000); + return true; + } catch (IOException e) { + return false; + } + } + + /** + * SpyTokenProvider is an Authentication implementation that spies on the + * TokenProvider it creates and can expose tokens generated by it. + */ + private static class SpyTokenProvider implements Authentication, TokenProvider { + + /** Spy on the TokenProvider returned by thie Authentication. */ + static SpyTokenProvider spyOn(Authentication auth) { + return new SpyTokenProvider(auth); + } + + /** Spy a token obtained by another TokenProvider. */ + static Token stealToken(Authentication auth) throws Exception { + var spy = spyOn(auth); + pingWeaviate(wcsContainer, spy); + return spy.getToken(); + } + + private Long expiresIn; + private Authentication authentication; + private TokenProvider tokenProvider; + + private SpyTokenProvider(Authentication actual) { + this.authentication = actual; + } + + @Override + public TokenProvider getTokenProvider(RestTransport transport) { + tokenProvider = authentication.getTokenProvider(transport); + return this; + } + + @Override + public Token getToken() { + var t = tokenProvider.getToken(); + if (expiresIn != null) { + t = Token.expireAfter(t.accessToken(), t.refreshToken(), expiresIn); + } + return t; + } + + /** Expire all tokens in {@code expiresIn} seconds. */ + void setExpiresIn(long expiresIn) { + this.expiresIn = expiresIn; + } + + } +} diff --git a/src/it/java/io/weaviate/integration/PaginationITest.java b/src/it/java/io/weaviate/integration/PaginationITest.java index 3d97d0a21..5044bdb73 100644 --- a/src/it/java/io/weaviate/integration/PaginationITest.java +++ b/src/it/java/io/weaviate/integration/PaginationITest.java @@ -126,7 +126,7 @@ public void testWithQueryOptions() throws IOException { } @Test - public void testAsyncPaginator() throws IOException, InterruptedException, ExecutionException { + public void testAsyncPaginator() throws Exception, InterruptedException, ExecutionException { // Arrange var nsThings = ns("Things"); var count = 10; diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index d4c132e0e..2d5d78bb9 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -288,7 +288,7 @@ public void testBm25() throws IOException, InterruptedException, ExecutionExcept * test both sync/async clients. */ @Test - public void testBm25_async() throws IOException, InterruptedException, ExecutionException { + public void testBm25_async() throws Exception, InterruptedException, ExecutionException { var nsWords = ns("Words"); try (final var async = client.async()) { diff --git a/src/main/java/io/weaviate/client6/v1/api/Authentication.java b/src/main/java/io/weaviate/client6/v1/api/Authentication.java new file mode 100644 index 000000000..22a7d4a34 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/Authentication.java @@ -0,0 +1,79 @@ +package io.weaviate.client6.v1.api; + +import java.util.List; + +import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.oidc.OidcConfig; +import io.weaviate.client6.v1.internal.oidc.OidcUtils; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public interface Authentication { + TokenProvider getTokenProvider(RestTransport transport); + + /** + * Authenticate using a static API key. + * + * @param apiKey Weaviate API key. + */ + public static Authentication apiKey(String apiKey) { + return __ -> TokenProvider.staticToken(apiKey); + } + + /** + * Authenticate using an existing access_token + refresh_token + * pair. + * + * @param accessToken Access token. + * @param refreshToken Refresh token. + * @param expiresIn Remaining token lifetime in seconds. + * + * @return Authentication provider. + * @throws WeaviateOAuthException if an error occurred at any point of the + * exchange process. + */ + public static Authentication bearerToken(String accessToken, String refreshToken, long expiresIn) { + return transport -> { + OidcConfig oidc = OidcUtils.getConfig(transport); + return TokenProvider.bearerToken(oidc, accessToken, refreshToken, expiresIn); + }; + } + + /** + * Authenticate using Resource Owner Password authorization grant. + * + * @param username Resource owner username. + * @param password Resource owner password. + * @param scopes Client scopes. + * + * @return Authentication provider. + * @throws WeaviateOAuthException if an error occurred at any point of the token + * exchange process. + */ + public static Authentication resourceOwnerPassword(String username, String password, List scopes) { + return transport -> { + OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes).withScopes("offline_access"); + return TokenProvider.resourceOwnerPassword(oidc, username, password); + }; + } + + /** + * Authenticate using Client Credentials authorization grant. + * + * @param clientId Client ID. + * @param clientSecret Client secret. + * @param scopes Client scopes. + * + * @return Authentication provider. + * @throws WeaviateOAuthException if an error occurred at any point while + * obtaining a new token. + */ + public static Authentication clientCredentials(String clientId, String clientSecret, List scopes) { + return transport -> { + OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes); + if (oidc.scopes().isEmpty() && TokenProvider.isMicrosoft(oidc)) { + oidc = oidc.withScopes(clientId + "/.default"); + } + return TokenProvider.clientCredentials(oidc, clientId, clientSecret); + }; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/Authorization.java b/src/main/java/io/weaviate/client6/v1/api/Authorization.java deleted file mode 100644 index 9bcfc2ed1..000000000 --- a/src/main/java/io/weaviate/client6/v1/api/Authorization.java +++ /dev/null @@ -1,10 +0,0 @@ -package io.weaviate.client6.v1.api; - -import io.weaviate.client6.v1.internal.TokenProvider; -import io.weaviate.client6.v1.internal.TokenProvider.Token; - -public class Authorization { - public static TokenProvider apiKey(String apiKey) { - return TokenProvider.staticToken(new Token(apiKey)); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java index a9a6c9dda..a14f01290 100644 --- a/src/main/java/io/weaviate/client6/v1/api/Config.java +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -19,7 +19,7 @@ public record Config( String grpcHost, int grpcPort, Map headers, - TokenProvider tokenProvider, + Authentication authentication, TrustManagerFactory trustManagerFactory) { public static Config of(Function> fn) { @@ -34,15 +34,23 @@ private Config(Builder builder) { builder.grpcHost, builder.grpcPort, builder.headers, - builder.tokenProvider, + builder.authentication, builder.trustManagerFactory); } RestTransportOptions restTransportOptions() { + return restTransportOptions(null); + } + + RestTransportOptions restTransportOptions(TokenProvider tokenProvider) { return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory); } GrpcChannelOptions grpcTransportOptions() { + return grpcTransportOptions(null); + } + + GrpcChannelOptions grpcTransportOptions(TokenProvider tokenProvider) { return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory); } @@ -53,7 +61,7 @@ private abstract static class Builder> implements Obj protected int httpPort; protected String grpcHost; protected int grpcPort; - protected TokenProvider tokenProvider; + protected Authentication authentication; protected TrustManagerFactory trustManagerFactory; protected Map headers = new HashMap<>(); @@ -102,6 +110,16 @@ protected SELF trustManagerFactory(TrustManagerFactory tmf) { return (SELF) this; } + /** + * Set authentication method. Setting this to {@code null} or omitting + * will not use any authentication mechanism. + */ + @SuppressWarnings("unchecked") + public SELF authentication(Authentication authz) { + this.authentication = authz; + return (SELF) this; + } + /** * Set a single request header. The client does not support header lists, * so there is no equivalent {@code addHeader} to append to existing header. @@ -144,7 +162,7 @@ private static boolean isWeaviateDomain(String host) { public Config build() { // For clusters hosted on Weaviate Cloud, Weaviate Embedding Service // will be available under the same domain. - if (isWeaviateDomain(httpHost) && tokenProvider != null) { + if (isWeaviateDomain(httpHost) && authentication != null) { setHeader(HEADER_X_WEAVIATE_CLUSTER_URL, "https://" + httpHost + ":" + httpPort); } return new Config(this); @@ -203,11 +221,11 @@ public Local grpcPort(int port) { * {@link #trustManagerFactory}. */ public static class WeaviateCloud extends Builder { - public WeaviateCloud(String httpHost, TokenProvider tokenProvider) { - this(URI.create(httpHost), tokenProvider); + public WeaviateCloud(String httpHost, Authentication authentication) { + this(URI.create(httpHost), authentication); } - public WeaviateCloud(URI clusterUri, TokenProvider tokenProvider) { + public WeaviateCloud(URI clusterUri, Authentication authentication) { scheme("https"); super.httpHost(clusterUri.getHost() != null ? clusterUri.getHost() // https://[example.com]/about @@ -215,7 +233,7 @@ public WeaviateCloud(URI clusterUri, TokenProvider tokenProvider) { super.grpcHost("grpc-" + this.httpHost); this.httpPort = 443; this.grpcPort = 443; - this.tokenProvider = tokenProvider; + this.authentication = authentication; } /** @@ -283,15 +301,6 @@ public Custom grpcPort(int port) { return this; } - /** - * Set authorization method. Setting this to {@code null} or omitting - * will not use any authorization mechanism. - */ - public Custom authorization(TokenProvider tokenProvider) { - this.tokenProvider = tokenProvider; - return this; - } - /** * Configure a custom TrustStore to validate third-party SSL certificates. * diff --git a/src/main/java/io/weaviate/client6/v1/api/InstanceMetadata.java b/src/main/java/io/weaviate/client6/v1/api/InstanceMetadata.java new file mode 100644 index 000000000..9664e8490 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/InstanceMetadata.java @@ -0,0 +1,12 @@ +package io.weaviate.client6.v1.api; + +import java.util.Map; + +import com.google.gson.annotations.SerializedName; + +public record InstanceMetadata( + @SerializedName("hostname") String hostName, + @SerializedName("version") String version, + @SerializedName("modules") Map modules, + @SerializedName("grpcMaxMessageSize") Long grpcMaxMessageSize) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/InstanceMetadataRequest.java b/src/main/java/io/weaviate/client6/v1/api/InstanceMetadataRequest.java new file mode 100644 index 000000000..09635c9df --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/InstanceMetadataRequest.java @@ -0,0 +1,14 @@ +package io.weaviate.client6.v1.api; + +import java.util.Collections; + +import io.weaviate.client6.v1.internal.rest.Endpoint; +import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; + +public class InstanceMetadataRequest { + public static final Endpoint _ENDPOINT = SimpleEndpoint.noBody( + __ -> "GET", + __ -> "/meta", + __ -> Collections.emptyMap(), + InstanceMetadata.class); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/IsLiveRequest.java b/src/main/java/io/weaviate/client6/v1/api/IsLiveRequest.java new file mode 100644 index 000000000..e098a6e09 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/IsLiveRequest.java @@ -0,0 +1,13 @@ +package io.weaviate.client6.v1.api; + +import java.util.Collections; + +import io.weaviate.client6.v1.internal.rest.BooleanEndpoint; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record IsLiveRequest() { + public static final Endpoint _ENDPOINT = BooleanEndpoint.noBody( + request -> "GET", + request -> "/.well-known/live", + request -> Collections.emptyMap()); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/IsReadyRequest.java b/src/main/java/io/weaviate/client6/v1/api/IsReadyRequest.java new file mode 100644 index 000000000..bc597f8ca --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/IsReadyRequest.java @@ -0,0 +1,13 @@ +package io.weaviate.client6.v1.api; + +import java.util.Collections; + +import io.weaviate.client6.v1.internal.rest.BooleanEndpoint; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record IsReadyRequest() { + public static final Endpoint _ENDPOINT = BooleanEndpoint.noBody( + request -> "GET", + request -> "/.well-known/ready", + request -> Collections.emptyMap()); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java index 1e2127e22..956137ac2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java @@ -1,17 +1,19 @@ package io.weaviate.client6.v1.api; -import java.io.Closeable; import java.io.IOException; import java.util.function.Function; import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.TokenProvider; import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport; +import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; import io.weaviate.client6.v1.internal.rest.RestTransport; +import io.weaviate.client6.v1.internal.rest.RestTransportOptions; -public class WeaviateClient implements Closeable { +public class WeaviateClient implements AutoCloseable { /** Store this for {@link #async()} helper. */ private final Config config; @@ -22,8 +24,27 @@ public class WeaviateClient implements Closeable { public WeaviateClient(Config config) { this.config = config; - this.restTransport = new DefaultRestTransport(config.restTransportOptions()); - this.grpcTransport = new DefaultGrpcTransport(config.grpcTransportOptions()); + + RestTransportOptions restOpt; + GrpcChannelOptions grpcOpt; + if (config.authentication() == null) { + restOpt = config.restTransportOptions(); + grpcOpt = config.grpcTransportOptions(); + } else { + TokenProvider tokenProvider; + try (final var noAuthRest = new DefaultRestTransport(config.restTransportOptions())) { + tokenProvider = config.authentication().getTokenProvider(noAuthRest); + } catch (Exception e) { + // Generally exceptions are caught in TokenProvider internals. + // This one may be thrown when noAuthRest transport is auto-closed. + throw new WeaviateOAuthException(e); + } + restOpt = config.restTransportOptions(tokenProvider); + grpcOpt = config.grpcTransportOptions(tokenProvider); + } + + this.restTransport = new DefaultRestTransport(restOpt); + this.grpcTransport = new DefaultGrpcTransport(grpcOpt); this.collections = new WeaviateCollectionsClient(restTransport, grpcTransport); } @@ -90,7 +111,7 @@ public static WeaviateClient wcd(String httpHost, String apiKey) { /** Connect to a Weaviate Cloud instance. */ public static WeaviateClient wcd(String httpHost, String apiKey, Function> fn) { - var config = new Config.WeaviateCloud(httpHost, Authorization.apiKey(apiKey)); + var config = new Config.WeaviateCloud(httpHost, Authentication.apiKey(apiKey)); return new WeaviateClient(fn.apply(config).build()); } @@ -99,12 +120,27 @@ public static WeaviateClient custom(Function + * This call is blocking if {@link Authentication} configured, + * as the client will need to do the initial token exchange. + */ public static WeaviateClientAsync local() { return local(ObjectBuilder.identity()); } - /** Connect to a local Weaviate instance. */ + /** + * Connect to a local Weaviate instance. + * + *

+ * This call is blocking if {@link Authentication} configured, + * as the client will need to do the initial token exchange. + */ public static WeaviateClientAsync local(Function> fn) { return new WeaviateClientAsync(fn.apply(new Config.Local()).build()); } - /** Connect to a Weaviate Cloud instance. */ + /** + * Connect to a Weaviate Cloud instance. + * + *

+ * This call is blocking if {@link Authentication} configured, + * as the client will need to do the initial token exchange. + */ public static WeaviateClientAsync wcd(String httpHost, String apiKey) { return wcd(httpHost, apiKey, ObjectBuilder.identity()); } - /** Connect to a Weaviate Cloud instance. */ + /** + * Connect to a Weaviate Cloud instance. + * + *

+ * This call is blocking if {@link Authentication} configured, + * as the client will need to do the initial token exchange. + */ public static WeaviateClientAsync wcd(String httpHost, String apiKey, Function> fn) { - var config = new Config.WeaviateCloud(httpHost, Authorization.apiKey(apiKey)); + var config = new Config.WeaviateCloud(httpHost, Authentication.apiKey(apiKey)); return new WeaviateClientAsync(fn.apply(config).build()); } - /** Connect to a Weaviate instance with custom configuration. */ + /** + * Connect to a Weaviate instance with custom configuration. + * + *

+ * This call is blocking if {@link Authentication} configured, + * as the client will need to do the initial token exchange. + */ public static WeaviateClientAsync custom(Function> fn) { return new WeaviateClientAsync(Config.of(fn)); } + /** Ping the server for a liveness check. */ + public CompletableFuture isLive() { + return this.restTransport.performRequestAsync(null, IsLiveRequest._ENDPOINT); + } + + /** Ping the server for a readiness check. */ + public CompletableFuture isReady() { + return this.restTransport.performRequestAsync(null, IsReadyRequest._ENDPOINT); + } + + /** Get deployement metadata for the target Weaviate instance. */ + public CompletableFuture meta() throws IOException { + return this.restTransport.performRequestAsync(null, InstanceMetadataRequest._ENDPOINT); + } + /** * Close {@link #restTransport} and {@link #grpcTransport} * and release associated resources. */ @Override - public void close() throws IOException { + public void close() throws Exception { this.restTransport.close(); this.grpcTransport.close(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateOAuthException.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateOAuthException.java new file mode 100644 index 000000000..4c2f7931d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateOAuthException.java @@ -0,0 +1,20 @@ +package io.weaviate.client6.v1.api; + +/** + * Exception class thrown by client API message when the request's reached the + * server, but the operation did not complete successfully either either due to + * a bad request or a server error. + */ +public class WeaviateOAuthException extends WeaviateException { + public WeaviateOAuthException(String message) { + super(message); + } + + public WeaviateOAuthException(String message, Throwable cause) { + super(message, cause); + } + + public WeaviateOAuthException(Throwable cause) { + super(cause); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/AsyncTokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/AsyncTokenProvider.java new file mode 100644 index 000000000..ec083d274 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/AsyncTokenProvider.java @@ -0,0 +1,83 @@ +package io.weaviate.client6.v1.internal; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import io.weaviate.client6.v1.internal.TokenProvider.Token; + +/** + * AsyncTokenProvider obtains authentication tokens asynchronously + * and can be used in a non-blocking context. + * + * As implementors are likely to schedule token fetches on a thread pool, + * instances must be closed to avoid resource leaks. + */ +public interface AsyncTokenProvider extends AutoCloseable { + CompletableFuture getToken(); + + CompletableFuture getToken(Executor executor); + + /** + * Create an {@link AsyncTokenProvider} instance from an existing + * {@link TokenProvider}. The inner provider MAY be called from + * multiple instances and MUST be thread-safe. + * + * Either use in a try-with-resources block or close after usage explicitly. + */ + static AsyncTokenProvider wrap(TokenProvider tp) { + return new Default(tp); + } + + /** + * AsyncTokenProvider fetches tokens in a + * shared single background thread. + */ + public static class Default implements AsyncTokenProvider { + /** + * Shared executor service. This way all instances + * can share the same thread pool. + */ + private static ExecutorService exec; + /** Shut down {@link #exec} once refCount reaches 0. */ + private static int refCount = 0; + + private final TokenProvider provider; + + Default(TokenProvider tp) { + synchronized (Default.class) { + if (refCount == 0) { + exec = Executors.newSingleThreadExecutor(); + } + refCount++; + } + this.provider = tp; + } + + /** Get token with the default single-thread executor. */ + @Override + public CompletableFuture getToken() { + return getToken(exec); + } + + /** Get token with a custom executor. */ + @Override + public CompletableFuture getToken(Executor executor) { + return CompletableFuture.supplyAsync(provider::getToken, executor); + } + + @Override + public void close() throws Exception { + provider.close(); + + synchronized (Default.class) { + refCount--; + if (refCount == 0 && exec != null) { + exec.shutdown(); + exec = null; + } + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/BackgroundTokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/BackgroundTokenProvider.java new file mode 100644 index 000000000..2d95c480d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/BackgroundTokenProvider.java @@ -0,0 +1,61 @@ +package io.weaviate.client6.v1.internal; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +public class BackgroundTokenProvider implements TokenProvider { + private final ScheduledExecutorService exec; + private final TokenProvider provider; + + /** + * Create a background task to periodically refresh the token. + * + *

+ * This method will wrap the TokenProvider in {@link ReuseTokenProvider} + * before passing it on to the constructor to cache the token + * in-between the refreshes. + * + * If TokenProvider is an instance of BackgroundTokenProvider + * it is returned immediately. + */ + public static TokenProvider wrap(TokenProvider tp) { + if (tp instanceof BackgroundTokenProvider) { + return tp; + } + return new BackgroundTokenProvider(ReuseTokenProvider.wrap(null, tp)); + } + + private BackgroundTokenProvider(TokenProvider tp) { + this.provider = tp; + this.exec = Executors.newSingleThreadScheduledExecutor(); + + scheduleNextRefresh(); + } + + @Override + public Token getToken() { + return provider.getToken(); + } + + /** + * Fetch the token and schedule a task to refresh it + * after {@link Token#expiresIn} seconds. The next + * refresh task is scheduled immediately afterwards. + * + * If {@link Token#neverExpires} this method returns + * early and the next refresh task is never scheduled. + */ + private void scheduleNextRefresh() { + var t = getToken(); + if (t.neverExpires()) { + return; + } + exec.schedule(this::scheduleNextRefresh, t.expiresIn(), TimeUnit.SECONDS); + } + + @Override + public void close() throws Exception { + exec.shutdown(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/ExchangeTokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/ExchangeTokenProvider.java new file mode 100644 index 000000000..40bf7b66d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/ExchangeTokenProvider.java @@ -0,0 +1,37 @@ +package io.weaviate.client6.v1.internal; + +import io.weaviate.client6.v1.internal.oidc.OidcConfig; + +/** + * ExchangeTokenProvider obtains a new {@link Token} from "single-use" + * {@link TokenProvider}, usually one using an Resource Owner Password grant. + * It then creates a new internal TokenProvider to refresh the token each time + * {@link #getToken} is called. + * + *

+ * Usage: + * + *

{@code
+ * var initialGrant = TokenProvider.resourceOwnerPassword(oidc, username, password);
+ * var exchange = new ExchangeTokenProvider(oidc, initialGrant);
+ * var token = exchange.getToken();
+ * } 
+ */ +class ExchangeTokenProvider implements TokenProvider { + private final TokenProvider bearer; + + ExchangeTokenProvider(OidcConfig oidc, TokenProvider tp) { + var t = tp.getToken(); + this.bearer = TokenProvider.bearerToken(oidc, t.accessToken(), t.refreshToken(), t.expiresIn()); + } + + @Override + public Token getToken() { + return bearer.getToken(); + } + + @Override + public void close() throws Exception { + bearer.close(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/ObjectBuilder.java b/src/main/java/io/weaviate/client6/v1/internal/ObjectBuilder.java index 550b4266d..a1ed410fc 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/ObjectBuilder.java +++ b/src/main/java/io/weaviate/client6/v1/internal/ObjectBuilder.java @@ -9,6 +9,25 @@ static , T> Function> identity() return builder -> builder; } + /** + * Chain two builder-functions such that {@code partialFn} is applied before + * {@code fn}. + * + *

+ * Usage: + * + *

{@code
+   *  static final Function> defaultConfig = b -> {...};
+   *  void doWithConfig(Function> fn) {
+   *    var withDefault = ObjectBuilder.partial(fn, defaultConfig);
+   *    var config = fn.apply(new Config()).build();
+   *  }
+   * }
+ * + * @param fn Function that will be applied last. + * @param partialFn Function that will be applied first. + * @return ObjectBuilder with "pre-applied" function. + */ static , T> Function> partial(Function> fn, Function partialFn) { return partialFn.andThen(fn); diff --git a/src/main/java/io/weaviate/client6/v1/internal/ReuseTokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/ReuseTokenProvider.java new file mode 100644 index 000000000..8f71580d6 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/ReuseTokenProvider.java @@ -0,0 +1,92 @@ +package io.weaviate.client6.v1.internal; + +import javax.annotation.concurrent.ThreadSafe; + +/** + * ReuseTokenProvider returns the same token as long as its valid and obtains a + * new token from a {@link TokenProvider} otherwise. + * + *

+ * Usage: + * + *

{@code
+ * // Create an TokenProvider that can rotate tokens as they expire.
+ * var myProvider = new MyTokenProvider();
+ *
+ * // Create a reusable TokenProvider.
+ * var tokenProvider = ReuseTokenProvider.wrap(myProvider);
+ * }
+ */ +@ThreadSafe +final class ReuseTokenProvider implements TokenProvider { + private final TokenProvider provider; + private final long expiryDelta; + + private volatile Token token; + + /** + * Create new {@link ReuseTokenProvider} from another {@link TokenProvider}. + * Wrapping an instance ReuseTokenProvider returns that instance if the token is + * {@code null}, so this method is safe to call with any TokenProvider. + * + * @return A TokenProvider. + */ + static TokenProvider wrap(Token t, TokenProvider tp, long expiryDelta) { + if (tp instanceof ReuseTokenProvider rtp) { + if (t == null) { + // Use it directly, but set new expiry delta. + return rtp.withExpiryDelta(expiryDelta); + } + } + return new ReuseTokenProvider(t, tp, expiryDelta); + } + + /** + * Create new {@link ReuseTokenProvider} from another {@link TokenProvider}. + * Wrapping an instance ReuseTokenProvider returns that instance if the token is + * {@code null}, so this method is safe to call with any TokenProvider. + */ + static TokenProvider wrap(Token t, TokenProvider tp) { + if (tp instanceof ReuseTokenProvider rtp) { + if (t == null) { + return rtp; // Use it directly. + } + } + return new ReuseTokenProvider(t, tp, 0); + } + + /** + * Create a new TokenProvider with a different expiryDelta. + * Tokens obtained from this TokenProvider with have the same early expiry. + * + * @param expiryDelta Early expiry in seconds. + * @return A new TokenProvider. + */ + TokenProvider withExpiryDelta(long expirtyDelta) { + return new ReuseTokenProvider(this.token, this.provider, expirtyDelta); + } + + private ReuseTokenProvider(Token t, TokenProvider tp, long expiryDelta) { + this.provider = tp; + this.token = t; + this.expiryDelta = expiryDelta; + } + + @Override + public Token getToken() { + if (token != null && token.isValid()) { + return token; + } + synchronized (this) { + if (token == null || !token.isValid()) { + token = provider.getToken().withExpiryDelta(expiryDelta); + } + } + return token; + } + + @Override + public void close() throws Exception { + provider.close(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java index af69a456b..fe565c345 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java +++ b/src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java @@ -1,13 +1,213 @@ package io.weaviate.client6.v1.internal; +import java.net.URI; +import java.time.Instant; + +import io.weaviate.client6.v1.api.WeaviateOAuthException; +import io.weaviate.client6.v1.internal.oidc.OidcConfig; +import io.weaviate.client6.v1.internal.oidc.nimbus.NimbusTokenProvider; + +/** TokenProvider obtains authentication tokens. */ @FunctionalInterface -public interface TokenProvider { +public interface TokenProvider extends AutoCloseable { Token getToken(); - public record Token(String accessToken) { + /** Token represents an access_token + refresh_token pair. */ + public record Token(String accessToken, String refreshToken, Instant createdAt, long expiresIn, long expiryDelta) { + /** + * Returns {@code true} if remaining lifetime of the token is greater than 0. + * Tokens created with {@link #expireNever} are always valid. + */ + public boolean isValid() { + if (neverExpires()) { + return true; + } + return Instant.now().isBefore(createdAt.plusSeconds(expiresIn - expiryDelta)); + } + + /** Returns {@code true} if this token is always valid. */ + public boolean neverExpires() { + return expiresIn == -1; + } + + /** + * Set early expiry for the Token. + * + *

+ * A Token with {@link #expiresIn} of 10s and {@link #expiryDelta} of 3s + * will be invalid 7s after being created. + * + * @param expiryDelta Early expiry in seconds. A negative value is clamped to 0. + * @return A Token identical to the source one, but with a different expiry. + */ + public Token withExpiryDelta(long expiryDelta) { + return new Token(accessToken, refreshToken, createdAt, expiresIn, Math.max(0, expiryDelta)); + } + + /** Create a token with a different refresh_token. */ + public Token withRefreshToken(String refreshToken) { + return new Token(accessToken, refreshToken, createdAt, expiresIn, expiryDelta); + } + + /** + * Create a token with an expiration and a refresh_token. + * + * @param accessToken Access token. + * @param refreshToken Refresh token. + * @param expiresIn Remaining token lifetime in seconds. + * + * @return A new Token. + */ + public static Token expireAfter(String accessToken, String refreshToken, long expiresIn) { + return new Token(accessToken, refreshToken, Instant.now(), expiresIn, 0); + } + + /** + * Create a token that does not have a refresh_token. + * For example, a token obtained via a Client Credentials grant + * can only be renewed using that grant type. + * + * @param accessToken Access token. + * @param expiresIn Remaining token lifetime in seconds. + * + * @return A new Token. + */ + public static Token expireAfter(String accessToken, long expiresIn) { + return expireAfter(accessToken, null, expiresIn); + } + + /** + * Create a token that never expires. + * + * @param accessToken Access token. + * @return A new Token. + */ + public static Token expireNever(String accessToken) { + return Token.expireAfter(accessToken, -1); + } } - public static TokenProvider staticToken(Token token) { + /** + * Refreshing the token slightly ahead of time will help prevent + * phony unauthorized access errors. + * + * This value is currently not configurable and should be seen + * as an internal implementation detail. + */ + static long DEFAULT_EARLY_EXPIRY = 30; + + /** + * Authorize using a token that never expires and doesn't need to be refreshed. + * + * @param accessToken Access token. + */ + public static TokenProvider staticToken(String accessToken) { + final var token = Token.expireNever(accessToken); return () -> token; } + + /** + * Create a TokenProvider that uses an existing access_token + refresh_token + * pair. + * + * @param oidc OIDC config. + * @param accessToken Access token. + * @param refreshToken Refresh token. + * @param expiresIn Remaining token lifetime in seconds. + * + * @return Internal TokenProvider implementation. + * @throws WeaviateOAuthException if an error occurred at any point of the + * exchange process. + */ + public static TokenProvider bearerToken(OidcConfig oidc, String accessToken, String refreshToken, long expiresIn) { + final var token = Token.expireAfter(accessToken, refreshToken, expiresIn); + final var provider = NimbusTokenProvider.refreshToken(oidc, token); + return background(reuse(token, provider, DEFAULT_EARLY_EXPIRY)); + } + + /** + * Create a TokenProvider that uses Resource Owner Password authorization grant. + * + * @param oidc OIDC config. + * @param username Resource owner username. + * @param password Resource owner password. + * + * @return Internal TokenProvider implementation. + * @throws WeaviateOAuthException if an error occurred at any point of the token + * exchange process. + */ + public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String username, String password) { + final var passwordGrant = NimbusTokenProvider.resourceOwnerPassword(oidc, username, password); + return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY)); + } + + /** + * Create a TokenProvider that uses Client Credentials authorization grant. + * + * @param oidc OIDC config. + * @param clientId Client ID. + * @param clientSecret Client secret. + * + * @return Internal TokenProvider implementation. + * @throws WeaviateOAuthException if an error occurred at any point while + * obtaining a new token. + */ + public static TokenProvider clientCredentials(OidcConfig oidc, String clientId, String clientSecret) { + final var provider = NimbusTokenProvider.clientCredentials(oidc, clientId, clientSecret); + return reuse(null, provider, DEFAULT_EARLY_EXPIRY); + } + + /** + * Obtain a TokenProvider that exchanges an authorization grant for a new Token. + */ + static TokenProvider exchange(OidcConfig oidc, TokenProvider tp) { + return new ExchangeTokenProvider(oidc, tp); + } + + /** + * Obtain a TokenProvider which reuses tokens obtained + * from another TokenProvider until they expire. + */ + static TokenProvider reuse(Token t, TokenProvider tp) { + return ReuseTokenProvider.wrap(t, tp); + } + + /** + * Obtain a TokenProvider which reuses tokens obtained + * from another TokenProvider until they expire. + */ + static TokenProvider reuse(Token t, TokenProvider tp, long expiryDelta) { + return ReuseTokenProvider.wrap(t, tp, expiryDelta); + } + + /** + * Obtain a TokenProvider which refreshes tokens in a background thread. + * This ensures a refresh_token doesn't become stale. + */ + static TokenProvider background(TokenProvider tp) { + return BackgroundTokenProvider.wrap(tp); + } + + public record ProviderMetadata(URI tokenEndpoint) { + } + + /** + * Returns true if this OIDC provider's token endpoint is hosted at + * {@code login.microsoftonline.com}. + * + * @param oidc OIDC config. + * + * @throws WeaviateOAuthException if metadata could not be parsed. + */ + public static boolean isMicrosoft(OidcConfig oidc) { + var metadata = NimbusTokenProvider.parseProviderMetadata(oidc.providerMetadata()); + return metadata.tokenEndpoint().getHost().contains("login.microsoftonline.com"); + } + + /** + * Implementations which need to dispose of created resources, + * e.g. thread pools, should override this method. + */ + default void close() throws Exception { + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index 7655db36a..a9a50c200 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -1,6 +1,5 @@ package io.weaviate.client6.v1.internal.grpc; -import java.io.IOException; import java.util.concurrent.CompletableFuture; import javax.net.ssl.SSLException; @@ -26,6 +25,8 @@ public final class DefaultGrpcTransport implements GrpcTransport { private final WeaviateBlockingStub blockingStub; private final WeaviateFutureStub futureStub; + private TokenCallCredentials callCredentials; + public DefaultGrpcTransport(GrpcChannelOptions transportOptions) { this.channel = buildChannel(transportOptions); @@ -36,9 +37,9 @@ public DefaultGrpcTransport(GrpcChannelOptions transportOptions) { .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers())); if (transportOptions.tokenProvider() != null) { - var credentials = new TokenCallCredentials(transportOptions.tokenProvider()); - blockingStub = blockingStub.withCallCredentials(credentials); - futureStub = futureStub.withCallCredentials(credentials); + this.callCredentials = new TokenCallCredentials(transportOptions.tokenProvider()); + blockingStub = blockingStub.withCallCredentials(callCredentials); + futureStub = futureStub.withCallCredentials(callCredentials); } this.blockingStub = blockingStub; @@ -121,7 +122,10 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions) } @Override - public void close() throws IOException { - this.channel.shutdown(); + public void close() throws Exception { + channel.shutdown(); + if (callCredentials != null) { + callCredentials.close(); + } } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java index a0fddbc67..952bfc336 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcTransport.java @@ -3,7 +3,7 @@ import java.io.Closeable; import java.util.concurrent.CompletableFuture; -public interface GrpcTransport extends Closeable { +public interface GrpcTransport extends AutoCloseable { ResponseT performRequest(RequestT request, Rpc rpc); diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java index c24a9093a..c69696682 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/TokenCallCredentials.java @@ -5,29 +5,40 @@ import io.grpc.CallCredentials; import io.grpc.Metadata; import io.grpc.Status; +import io.weaviate.client6.v1.internal.AsyncTokenProvider; import io.weaviate.client6.v1.internal.TokenProvider; -class TokenCallCredentials extends CallCredentials { +class TokenCallCredentials extends CallCredentials implements AutoCloseable { private static final Metadata.Key AUTHORIZATION = Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER); - private final TokenProvider tokenProvider; + /** + * Since {@link #applyRequestMetadata} accepts an {@link Executor} anyways, + * we can always just use an async provider, instead of creating 2 separate + * instances for it. + */ + private final AsyncTokenProvider tokenProviderAsync; TokenCallCredentials(TokenProvider tokenProvider) { - this.tokenProvider = tokenProvider; + this.tokenProviderAsync = AsyncTokenProvider.wrap(tokenProvider); } @Override public void applyRequestMetadata(RequestInfo requestInfo, Executor executor, MetadataApplier metadataApplier) { - executor.execute(() -> { - try { - var headers = new Metadata(); - var token = tokenProvider.getToken().accessToken(); - headers.put(AUTHORIZATION, "Bearer " + token); - metadataApplier.apply(headers); - } catch (Exception e) { - metadataApplier.fail(Status.UNAUTHENTICATED.withCause(e)); - } - }); + tokenProviderAsync.getToken(executor) + .whenComplete((tok, ex) -> { + if (ex != null) { + metadataApplier.fail(Status.UNAUTHENTICATED.withCause(ex)); + return; + } + var headers = new Metadata(); + headers.put(AUTHORIZATION, "Bearer " + tok.accessToken()); + metadataApplier.apply(headers); + }); + } + + @Override + public void close() throws Exception { + tokenProviderAsync.close(); } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java new file mode 100644 index 000000000..858fd4d82 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java @@ -0,0 +1,36 @@ +package io.weaviate.client6.v1.internal.oidc; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public record OidcConfig( + String clientId, + String providerMetadata, + Set scopes) { + + public OidcConfig(String clientId, String providerMetadata, Set scopes) { + this.clientId = clientId; + this.providerMetadata = providerMetadata; + this.scopes = scopes != null ? Set.copyOf(scopes) : Collections.emptySet(); + } + + public OidcConfig(String clientId, String providerMetadata, List scopes) { + this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes)); + } + + /** Create a new OIDC config with extended scopes. */ + public OidcConfig withScopes(String... scopes) { + return withScopes(Arrays.asList(scopes)); + } + + /** Create a new OIDC config with extended scopes. */ + public OidcConfig withScopes(List scopes) { + var newScopes = Stream.concat(this.scopes.stream(), scopes.stream()).collect(Collectors.toSet()); + return new OidcConfig(clientId, providerMetadata, newScopes); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java new file mode 100644 index 000000000..cafcc1289 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/OidcUtils.java @@ -0,0 +1,59 @@ +package io.weaviate.client6.v1.internal.oidc; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.WeaviateOAuthException; +import io.weaviate.client6.v1.internal.rest.Endpoint; +import io.weaviate.client6.v1.internal.rest.ExternalEndpoint; +import io.weaviate.client6.v1.internal.rest.RestTransport; +import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; + +public final class OidcUtils { + /** Prevents public initialization. */ + private OidcUtils() { + } + + private static final String OPENID_CONFIGURATION_URL = "/.well-known/openid-configuration"; + + private static final Endpoint GET_OPENID_ENDPOINT = SimpleEndpoint.noBody( + request -> "GET", + request -> OPENID_CONFIGURATION_URL, + request -> Collections.emptyMap(), + OpenIdConfiguration.class); + + private static final Endpoint GET_PROVIDER_METADATA_ENDPOINT = new ExternalEndpoint<>( + request -> "GET", + request -> request, // URL is the request body. + requesf -> Collections.emptyMap(), + request -> null, + (__, response) -> response); + + private static record OpenIdConfiguration( + @SerializedName("clientId") String clientId, + @SerializedName("scopes") List scopes, + @SerializedName("href") String endpoint) { + } + + /** Fetch cluster's OIDC config. */ + public static final OidcConfig getConfig(RestTransport transport) { + OpenIdConfiguration openid; + try { + openid = transport.performRequest(null, GET_OPENID_ENDPOINT); + } catch (IOException e) { + throw new WeaviateOAuthException("fetch OpenID configuration", e); + } + + String providerMetadata; + try { + providerMetadata = transport.performRequest(openid.endpoint(), GET_PROVIDER_METADATA_ENDPOINT); + } catch (IOException e) { + throw new WeaviateOAuthException("fetch provider metadata", e); + } + + return new OidcConfig(openid.clientId(), providerMetadata, openid.scopes()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java new file mode 100644 index 000000000..ad12561ff --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/Flow.java @@ -0,0 +1,45 @@ +package io.weaviate.client6.v1.internal.oidc.nimbus; + +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.ClientCredentialsGrant; +import com.nimbusds.oauth2.sdk.ResourceOwnerPasswordCredentialsGrant; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; +import com.nimbusds.oauth2.sdk.auth.ClientSecretPost; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; + +import io.weaviate.client6.v1.internal.TokenProvider.Token; + +@FunctionalInterface +interface Flow { + AuthorizationGrant getAuthorizationGrant(); + + default ClientAuthentication getClientAuthentication() { + return null; + } + + static Flow refreshToken(Token t) { + return new RefreshTokenFlow(t); + } + + static Flow resourceOwnerPassword(String username, String password) { + final var grant = new ResourceOwnerPasswordCredentialsGrant(username, new Secret(password)); + return () -> grant; // Reuse cached authorization grant + } + + static Flow clientCredentials(String clientId, String clientSecret) { + return new Flow() { + private static final AuthorizationGrant GRANT = new ClientCredentialsGrant(); + + @Override + public AuthorizationGrant getAuthorizationGrant() { + return GRANT; + } + + @Override + public ClientAuthentication getClientAuthentication() { + return new ClientSecretPost(new ClientID(clientId), new Secret(clientSecret)); + } + }; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/HttpResponseParser.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/HttpResponseParser.java new file mode 100644 index 000000000..3c7f70436 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/HttpResponseParser.java @@ -0,0 +1,79 @@ +package io.weaviate.client6.v1.internal.oidc.nimbus; + +import com.google.gson.annotations.SerializedName; +import com.nimbusds.oauth2.sdk.ErrorResponse; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser; +import com.nimbusds.openid.connect.sdk.token.OIDCTokens; + +import io.weaviate.client6.v1.api.WeaviateOAuthException; +import io.weaviate.client6.v1.internal.json.JSON; + +final class OIDCTokensParser { + /** + * Prevents public instantiation. + */ + private OIDCTokensParser() { + } + + /** + * Parse HTTP response containing a new set of OIDC tokens (access_token and + * refresh_token). + * + *

+ * Nimbus expects the following format of an error response, as per RFC 6749: + * + *

+   *  {
+   *    "error_code": "invalid_client",
+   *    "error_description": "Invalid value for 'client_id' parameter."
+   *  }
+   * 
+ * + * Unfortunately, not all OIDC servers adhere to it. E.g. Okta returns + * {@code "errorCode"} and {@code "errorSummary"}, which Nimbus's + * {@link OIDCTokenResponseParser} fails to parse. In order to get a meaningful + * error message we make a second pass in case error details are incomplete + * after the first pass. + * + * @throws ParseException if {@link OIDCTokenResponseParser#parse()} + * failed. + * @throws WeaviateOAuthException if response indicates error. As determined by + * {@link TokenResponse#indicatesSuccess()}. + */ + static OIDCTokens parse(HTTPResponse httpResponse) throws ParseException { + var response = OIDCTokenResponseParser.parse(httpResponse); + if (response.indicatesSuccess()) { + return ((OIDCTokenResponse) response).getOIDCTokens(); + } + + var error = fromErrorResponse(response.toErrorResponse()); + if (!error.isComplete()) { + error = fromHttpResponse(httpResponse); + } + throw new WeaviateOAuthException("%s (code=%s)".formatted( + error.description(), + error.code())); + } + + private static record ErrorDetails( + @SerializedName("errorCode") String code, + @SerializedName("errorSummary") String description) { + + private boolean isComplete() { + return code != null && description != null; + } + } + + private static ErrorDetails fromErrorResponse(ErrorResponse response) { + var err = response.getErrorObject(); + return new ErrorDetails(err.getCode(), err.getDescription()); + } + + private static ErrorDetails fromHttpResponse(HTTPResponse response) { + return JSON.deserialize(response.getBody(), ErrorDetails.class); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java new file mode 100644 index 000000000..5ae7bc3ac --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/NimbusTokenProvider.java @@ -0,0 +1,127 @@ +package io.weaviate.client6.v1.internal.oidc.nimbus; + +import java.io.IOException; + +import javax.annotation.concurrent.NotThreadSafe; + +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import com.nimbusds.openid.connect.sdk.token.OIDCTokens; + +import io.weaviate.client6.v1.api.WeaviateOAuthException; +import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.oidc.OidcConfig; + +@NotThreadSafe +public final class NimbusTokenProvider implements TokenProvider { + private final OIDCProviderMetadata metadata; + private final ClientID clientId; + private final Scope scope; + private final Flow flow; + + /** + * Create a TokenProvider that uses Refresh Token authorization grant. + * + * @param oidc OIDC config. + * @param t Current token. Must not be null. + * + * @return A new instance of NimbusTokenProvider. Instances are never cached. + * @throws WeaviateOAuthException if an error occurred at any point of the + * exchange process. + */ + public static NimbusTokenProvider refreshToken(OidcConfig oidc, Token t) { + return new NimbusTokenProvider(oidc, Flow.refreshToken(t)); + } + + /** + * Create a TokenProvider that uses Resource Owner Password authorization grant. + * + * @param oidc OIDC config. + * @param username Resource owner username. + * @param password Resource owner password. + * + * @return A new instance of NimbusTokenProvider. Instances are never cached. + * @throws WeaviateOAuthException if an error occured at any point of the + * exchange process. + */ + public static NimbusTokenProvider resourceOwnerPassword(OidcConfig oidc, String username, String password) { + return new NimbusTokenProvider(oidc, Flow.resourceOwnerPassword(username, password)); + } + + /** + * Create a TokenProvider that uses Client Credentials authorization grant. + * + * @param oidc OIDC config. + * @param clientId Client ID. + * @param clientSecret Client secret. + * + * @return A new instance of NimbusTokenProvider. Instances are never cached. + * @throws WeaviateOAuthException if an error occured at any point of the + * exchange process. + */ + public static NimbusTokenProvider clientCredentials(OidcConfig oidc, String clientId, String clientSecret) { + return new NimbusTokenProvider(oidc, Flow.clientCredentials(clientId, clientSecret)); + } + + private NimbusTokenProvider(OidcConfig oidc, Flow flow) { + this.metadata = _parseProviderMetadata(oidc.providerMetadata()); + this.clientId = new ClientID(oidc.clientId()); + this.scope = new Scope(oidc.scopes().toArray(String[]::new)); + this.flow = flow; + } + + @Override + public Token getToken() { + var uri = metadata.getTokenEndpointURI(); + var grant = flow.getAuthorizationGrant(); + + var clientAuth = flow.getClientAuthentication(); + var tokenRequest = clientAuth == null + ? new TokenRequest(uri, clientId, grant, scope) + : new TokenRequest(uri, clientAuth, grant, scope); + var request = tokenRequest.toHTTPRequest(); + + OIDCTokens tokens; + try { + var response = request.send(); + tokens = OIDCTokensParser.parse(response); + } catch (IOException | ParseException e) { + throw new WeaviateOAuthException(e); + } + + var accessToken = tokens.getAccessToken(); + var refreshToken = tokens.getRefreshToken(); + + var newToken = refreshToken == null + ? Token.expireAfter(accessToken.getValue(), accessToken.getLifetime()) + : Token.expireAfter(accessToken.getValue(), refreshToken.getValue(), accessToken.getLifetime()); + + if (flow instanceof RefreshTokenFlow rtf) { + // Some IdP servers may omit refresh_token from the response if it is + // sufficiently long-lived. In such case we continue reusing the old one. + if (newToken.refreshToken() == null) { + var rt = rtf.getRefreshToken(); + newToken = newToken.withRefreshToken(rt); + } + rtf.setToken(newToken); + } + + return newToken; + } + + public static ProviderMetadata parseProviderMetadata(String providerMetadata) { + var metadata = _parseProviderMetadata(providerMetadata); + return new ProviderMetadata(metadata.getTokenEndpointURI()); + } + + private static OIDCProviderMetadata _parseProviderMetadata(String providerMetadata) { + try { + return OIDCProviderMetadata.parse(providerMetadata); + } catch (ParseException ex) { + throw new WeaviateOAuthException("parse provider metadata: ", ex); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/RefreshTokenFlow.java b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/RefreshTokenFlow.java new file mode 100644 index 000000000..ce7e18df3 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/oidc/nimbus/RefreshTokenFlow.java @@ -0,0 +1,37 @@ +package io.weaviate.client6.v1.internal.oidc.nimbus; + +import javax.annotation.concurrent.NotThreadSafe; + +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.RefreshTokenGrant; +import com.nimbusds.oauth2.sdk.token.RefreshToken; + +import io.weaviate.client6.v1.internal.TokenProvider.Token; + +/** + * RefreshTokenFlow provides {@link RefreshTokenGrant} with a refresh_token. + * Once the caller has obtained a new {@link Token} it must be updated using + * {@link #setToken} to ensure RefreshTokenFlow continues to return valid + * authorization grants. + */ +@NotThreadSafe +final class RefreshTokenFlow implements Flow { + private Token t; + + RefreshTokenFlow(Token t) { + this.t = t; + } + + @Override + public AuthorizationGrant getAuthorizationGrant() { + return new RefreshTokenGrant(new RefreshToken(t.refreshToken())); + } + + public String getRefreshToken() { + return t.refreshToken(); + } + + public void setToken(Token t) { + this.t = t; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/AuthenticationInterceptor.java b/src/main/java/io/weaviate/client6/v1/internal/rest/AuthenticationInterceptor.java new file mode 100644 index 000000000..f4a61a7ea --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/AuthenticationInterceptor.java @@ -0,0 +1,83 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.io.IOException; + +import org.apache.hc.client5.http.async.AsyncExecCallback; +import org.apache.hc.client5.http.async.AsyncExecChain; +import org.apache.hc.client5.http.async.AsyncExecChain.Scope; +import org.apache.hc.client5.http.async.AsyncExecChainHandler; +import org.apache.hc.core5.http.EntityDetails; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.HttpRequest; +import org.apache.hc.core5.http.HttpRequestInterceptor; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.nio.AsyncEntityProducer; +import org.apache.hc.core5.http.protocol.HttpContext; + +import io.weaviate.client6.v1.internal.AsyncTokenProvider; +import io.weaviate.client6.v1.internal.TokenProvider; +import io.weaviate.client6.v1.internal.TokenProvider.Token; + +/** + * AuthenticationInterceptor can supply Authorization headers to both + * synchronous and asynchronous Apache HttpClient. + */ +class AuthenticationInterceptor implements HttpRequestInterceptor, AsyncExecChainHandler, AutoCloseable { + private static final String AUTHORIZATION = "Authorization"; + + private final TokenProvider tokenProvider; + private final AsyncTokenProvider tokenProviderAsync; + + AuthenticationInterceptor(TokenProvider tokenProvider) { + this.tokenProvider = tokenProvider; + this.tokenProviderAsync = AsyncTokenProvider.wrap(tokenProvider); + } + + /** + * Add Authorization header to a blocking request. + * See {@link HttpRequestInterceptor}. + */ + @Override + public void process(HttpRequest request, EntityDetails entity, HttpContext context) + throws HttpException, IOException { + var token = tokenProvider.getToken(); + setAuthorization(request, token); + } + + /** + * Add Authorization header to a non-blocking request. + * See {@link AsyncExecChainHandler}. + */ + @Override + public void execute(HttpRequest request, AsyncEntityProducer entityProducer, Scope scope, AsyncExecChain chain, + AsyncExecCallback callback) throws HttpException, IOException { + + // CloseableHttpClient is backed by an internal I/O reactor, which runs its own + // threads for non-blocking I/O. It does not expose that executor, so we must + // schedule CompletableFutures on the default AsyncTokenProvider's executor. + tokenProviderAsync.getToken().whenComplete((tok, error) -> { + if (error != null) { + callback.failed(error instanceof Exception ex ? ex : new RuntimeException(error)); + return; + } + + setAuthorization(request, tok); + + try { + chain.proceed(request, entityProducer, scope, callback); + } catch (HttpException | IOException e) { + callback.failed(e); + } + }); + } + + private void setAuthorization(HttpRequest request, Token token) { + request.addHeader(new BasicHeader(AUTHORIZATION, "Bearer " + token.accessToken())); + } + + @Override + public void close() throws Exception { + tokenProvider.close(); + tokenProviderAsync.close(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java b/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java deleted file mode 100644 index 9fe109d23..000000000 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/AuthorizationInterceptor.java +++ /dev/null @@ -1,29 +0,0 @@ -package io.weaviate.client6.v1.internal.rest; - -import java.io.IOException; - -import org.apache.hc.core5.http.EntityDetails; -import org.apache.hc.core5.http.HttpException; -import org.apache.hc.core5.http.HttpRequest; -import org.apache.hc.core5.http.HttpRequestInterceptor; -import org.apache.hc.core5.http.message.BasicHeader; -import org.apache.hc.core5.http.protocol.HttpContext; - -import io.weaviate.client6.v1.internal.TokenProvider; - -class AuthorizationInterceptor implements HttpRequestInterceptor { - private static final String AUTHORIZATION = "Authorization"; - - private final TokenProvider tokenProvider; - - AuthorizationInterceptor(TokenProvider tokenProvider) { - this.tokenProvider = tokenProvider; - } - - @Override - public void process(HttpRequest request, EntityDetails entity, HttpContext context) - throws HttpException, IOException { - var token = tokenProvider.getToken().accessToken(); - request.addHeader(new BasicHeader(AUTHORIZATION, "Bearer " + token)); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java index f8b4d255b..66e1f7862 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java @@ -34,6 +34,8 @@ public class DefaultRestTransport implements RestTransport { private final CloseableHttpAsyncClient httpClientAsync; private final RestTransportOptions transportOptions; + private AuthenticationInterceptor authInterceptor; + public DefaultRestTransport(RestTransportOptions transportOptions) { this.transportOptions = transportOptions; @@ -65,9 +67,9 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { } if (transportOptions.tokenProvider() != null) { - var interceptor = new AuthorizationInterceptor(transportOptions.tokenProvider()); - httpClient.addRequestInterceptorFirst(interceptor); - httpClientAsync.addRequestInterceptorFirst(interceptor); + this.authInterceptor = new AuthenticationInterceptor(transportOptions.tokenProvider()); + httpClient.addRequestInterceptorFirst(authInterceptor); + httpClientAsync.addExecInterceptorFirst("auth", authInterceptor); } this.httpClient = httpClient.build(); @@ -76,8 +78,7 @@ public DefaultRestTransport(RestTransportOptions transportOptions) { } private String uri(Endpoint ep, RequestT req) { - return transportOptions.baseUrl() - + ep.requestUrl(req) + return ep.requestUrl(transportOptions, req) + UrlEncoder.encodeQuery(ep.queryParameters(req)); } @@ -85,6 +86,7 @@ private String uri(Endpoint ep, RequestT req) { public ResponseT performRequest(RequestT request, Endpoint endpoint) throws IOException { + var req = prepareClassicRequest(request, endpoint); return this.httpClient.execute(req, r -> this.handleResponse(endpoint, req.getMethod(), req.getRequestUri(), r)); } @@ -183,8 +185,11 @@ private ResponseT _handleResponse(Endpoint endpoint, S } @Override - public void close() throws IOException { + public void close() throws Exception { httpClient.close(); httpClientAsync.close(CloseMode.GRACEFUL); + if (authInterceptor != null) { + authInterceptor.close(); + } } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java index 6e1e33760..a3af57b88 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java @@ -8,6 +8,10 @@ public interface Endpoint { String requestUrl(RequestT request); + default String requestUrl(RestTransportOptions transportOptions, RequestT request) { + return transportOptions.baseUrl() + requestUrl(request); + } + String body(RequestT request); Map queryParameters(RequestT request); diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java b/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java index 4776d4afe..29cf803d7 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/EndpointBase.java @@ -59,6 +59,9 @@ public boolean isError(int statusCode) { @Override public String deserializeError(int statusCode, String responseBody) { + if (responseBody.isBlank()) { + return responseBody; + } { var response = JSON.deserialize(responseBody, ErrorResponse1.class); if (response.errors != null && !response.errors.isEmpty()) { diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/ExternalEndpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/ExternalEndpoint.java new file mode 100644 index 000000000..0331fdd53 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/ExternalEndpoint.java @@ -0,0 +1,23 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; + +public class ExternalEndpoint extends SimpleEndpoint { + + public ExternalEndpoint( + Function method, + Function requestUrl, + Function> queryParameters, + Function body, + BiFunction deserializeResponse) { + super(method, requestUrl, queryParameters, body, deserializeResponse); + } + + /** Returns {@code requestUrl} without {@code baseUrl} prefix. */ + @Override + public String requestUrl(RestTransportOptions __, RequestT request) { + return requestUrl(request); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java index 1fc25836b..da26c9f12 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java @@ -1,10 +1,9 @@ package io.weaviate.client6.v1.internal.rest; -import java.io.Closeable; import java.io.IOException; import java.util.concurrent.CompletableFuture; -public interface RestTransport extends Closeable { +public interface RestTransport extends AutoCloseable { ResponseT performRequest(RequestT request, Endpoint endpoint) throws IOException; diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java index 963cf4e3a..c9b8d441f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/SimpleEndpoint.java @@ -4,6 +4,8 @@ import java.util.function.BiFunction; import java.util.function.Function; +import io.weaviate.client6.v1.internal.json.JSON; + public class SimpleEndpoint extends EndpointBase implements JsonEndpoint { private static final BiFunction NULL_RESPONSE = (__code, __body) -> null; @@ -14,6 +16,10 @@ protected static BiFunction nullResponse() { return NULL_RESPONSE; } + protected static BiFunction deserializeClass(Class cls) { + return (statusCode, response) -> JSON.deserialize(response, cls); + } + public static SimpleEndpoint noBody( Function method, Function requestUrl, @@ -22,6 +28,14 @@ public static SimpleEndpoint noBody( return new SimpleEndpoint<>(method, requestUrl, queryParameters, nullBody(), deserializeResponse); } + public static SimpleEndpoint noBody( + Function method, + Function requestUrl, + Function> queryParameters, + Class cls) { + return new SimpleEndpoint<>(method, requestUrl, queryParameters, nullBody(), deserializeClass(cls)); + } + public static SimpleEndpoint sideEffect( Function method, Function requestUrl, diff --git a/src/test/java/io/weaviate/client6/v1/api/AuthorizationTest.java b/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java similarity index 58% rename from src/test/java/io/weaviate/client6/v1/api/AuthorizationTest.java rename to src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java index 8c4d375e0..ea8531d48 100644 --- a/src/test/java/io/weaviate/client6/v1/api/AuthorizationTest.java +++ b/src/test/java/io/weaviate/client6/v1/api/AuthenticationTest.java @@ -3,6 +3,7 @@ import java.io.IOException; import java.util.Collections; +import org.assertj.core.api.Assertions; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -10,11 +11,13 @@ import org.mockserver.model.HttpRequest; import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; -import io.weaviate.client6.v1.internal.rest.OptionalEndpoint; +import io.weaviate.client6.v1.internal.rest.RestTransport; import io.weaviate.client6.v1.internal.rest.RestTransportOptions; +import io.weaviate.client6.v1.internal.rest.SimpleEndpoint; -public class AuthorizationTest { +public class AuthenticationTest { private ClientAndServer mockServer; + private RestTransport noAuthTransport; @Before public void startMockServer() throws IOException { @@ -27,17 +30,26 @@ public void startMockServer() throws IOException { // if another webserver is listening to that port. // We use 0 to let the underlying system find an available port. mockServer = ClientAndServer.startClientAndServer(0); + noAuthTransport = new DefaultRestTransport( + new RestTransportOptions( + "http", "localhost", mockServer.getLocalPort(), + Collections.emptyMap(), null, null)); } @Test - public void testAuthorization_apiKey() throws IOException { + public void testAuthentication_apiKey() throws Exception { + var authz = Authentication.apiKey("my-api-key"); var transportOptions = new RestTransportOptions( "http", "localhost", mockServer.getLocalPort(), - Collections.emptyMap(), Authorization.apiKey("my-api-key"), null); + Collections.emptyMap(), authz.getTokenProvider(noAuthTransport), null); try (final var restClient = new DefaultRestTransport(transportOptions)) { - restClient.performRequest(null, OptionalEndpoint.noBodyOptional( - request -> "GET", request -> "/", request -> null, (code, response) -> null)); + restClient.performRequest(null, SimpleEndpoint.sideEffect( + request -> "GET", request -> "/", request -> null)); + } catch (WeaviateApiException ex) { + if (ex.httpStatusCode() != 404) { + Assertions.fail("unexpected error", ex); + } } mockServer.verify( @@ -48,7 +60,8 @@ public void testAuthorization_apiKey() throws IOException { } @After - public void stopMockServer() { + public void stopMockServer() throws Exception { mockServer.stop(); + noAuthTransport.close(); } } diff --git a/src/test/java/io/weaviate/client6/v1/internal/TokenTest.java b/src/test/java/io/weaviate/client6/v1/internal/TokenTest.java new file mode 100644 index 000000000..789ffca1a --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/internal/TokenTest.java @@ -0,0 +1,43 @@ +package io.weaviate.client6.v1.internal; + +import java.time.Instant; + +import org.assertj.core.api.Assertions; +import org.junit.Test; +import org.junit.runner.RunWith; + +import com.jparams.junit4.JParamsTestRunner; +import com.jparams.junit4.data.DataMethod; + +import io.weaviate.client6.v1.internal.TokenProvider.Token; + +@RunWith(JParamsTestRunner.class) +public class TokenTest { + + public static Object[][] testCaseTokens() { + return new Object[][] { + { Token.expireNever("access_token"), true }, + { Token.expireAfter("access_token", "refresh_token", 100), true }, + { Token.expireAfter("access_token", 100), true }, + { new Token("access_token", "refresh_token", Instant.now().minusSeconds(10), 5, 0), false }, + { Token.expireAfter("access_token", 0), false }, + { Token.expireAfter("access_token", 5).withExpiryDelta(10), false }, + { Token.expireAfter("access_token", 100).withExpiryDelta(10), true }, + }; + } + + @DataMethod(source = TokenTest.class, method = "testCaseTokens") + @Test + public void test_isValid(Token token, boolean wantValid) { + Assertions.assertThat(token.isValid()) + .as(token.toString()) + .isEqualTo(wantValid); + } + + @Test + public void test_expiryDeltaNonNegative() { + var t = Token.expireAfter("access_token", 90L); + var expireLater = t.withExpiryDelta(-10L); + Assertions.assertThat(expireLater.expiryDelta()).as("expiryDelta must be >= 0").isEqualTo(0L); + } +} diff --git a/src/test/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransportTest.java b/src/test/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransportTest.java index cf6de61a8..28909df31 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransportTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransportTest.java @@ -73,7 +73,7 @@ public void testCustomTrustStore_async() throws IOException, ExecutionException, } @After - public void tearDown() throws IOException { + public void tearDown() throws Exception { mockServer.stop(); transport.close(); }