diff --git a/docs/src/main/asciidoc/security-oidc-bearer-token-authentication.adoc b/docs/src/main/asciidoc/security-oidc-bearer-token-authentication.adoc index a636c7dc392d8..16261ebf9deb5 100644 --- a/docs/src/main/asciidoc/security-oidc-bearer-token-authentication.adoc +++ b/docs/src/main/asciidoc/security-oidc-bearer-token-authentication.adoc @@ -1110,6 +1110,12 @@ xref:security-openid-connect-multitenancy.adoc#tenant-config-resolver[Dynamic te Authentication that requires dynamic tenant will fail. ==== +[[oidc-request-filters]] +== OIDC request filters + +You can filter OIDC requests made by Quarkus to the OIDC provider by registering one or more `OidcRequestFiler` implementations, which can update or add new request headers, as well as log requests. +For more information, see xref:security-code-flow-authentication#oidc-request-filters[OIDC request filters]. + == References * xref:security-oidc-configuration-properties-reference.adoc[OIDC configuration properties] diff --git a/docs/src/main/asciidoc/security-oidc-code-flow-authentication.adoc b/docs/src/main/asciidoc/security-oidc-code-flow-authentication.adoc index 96bafe7ccd471..d0e37b9d320d4 100644 --- a/docs/src/main/asciidoc/security-oidc-code-flow-authentication.adoc +++ b/docs/src/main/asciidoc/security-oidc-code-flow-authentication.adoc @@ -279,11 +279,70 @@ quarkus.oidc.introspection-credentials.name=introspection-user-name quarkus.oidc.introspection-credentials.secret=introspection-user-secret ---- -[[oidc-client-filters]] -==== OIDC request customization +[[oidc-request-filters]] +==== OIDC request filters -You can customize OIDC requests made by Quarkus to the OIDC provider by registering one or more `OidcRequestFiler` implementations, which can update or add new request headers. -For more information, see xref:security-openid-connect-client-reference#oidc-client-filters[Client request customization]. +You can filter OIDC requests made by Quarkus to the OIDC provider by registering one or more `OidcRequestFiler` implementations, which can update or add new request headers, as well as log requests. + +For example: + +[source,java] +---- +package io.quarkus.it.keycloak; + +import jakarta.enterprise.context.ApplicationScoped; + +import io.quarkus.arc.Unremovable; +import io.quarkus.oidc.common.OidcRequestContextProperties; +import io.quarkus.oidc.common.OidcRequestFilter; +import io.vertx.mutiny.core.buffer.Buffer; +import io.vertx.mutiny.ext.web.client.HttpRequest; + +@ApplicationScoped +@Unremovable +public class OidcTokenRequestCustomizer implements OidcRequestFilter { + @Override + public void filter(HttpRequest request, Buffer buffer, OidcRequestContextProperties contextProps) { + OidcConfigurationMetadata metadata = contextProps.get(OidcConfigurationMetadata.class.getName()); <1> + // Metadata URI is absolute, request URI value is relative + if (metadata.getTokenUri().endsWith(request.uri())) { <2> + request.putHeader("TokenGrantDigest", calculateDigest(buffer.toString())); + } + } + private String calculateDigest(String bodyString) { + // Apply the required digest algorithm to the body string + } +} +---- +<1> Get `OidcConfigurationMetadata` which contains all supported OIDC endpoint addresses. +<2> Use `OidcConfigurationMetadata` to filter requests to the OIDC token endpoint only. + +Alternatively, you can use `OidcRequestFilter.Endpoint` enum to make sure this filter is applied to the token endpoint requests only: + +[source,java] +---- +import jakarta.enterprise.context.ApplicationScoped; + +import io.quarkus.arc.Unremovable; +import io.quarkus.oidc.common.OidcEndpoint; +import io.quarkus.oidc.common.OidcEndpoint.Type; +import io.quarkus.oidc.common.OidcRequestContextProperties; +import io.quarkus.oidc.common.OidcRequestFilter; +import io.vertx.mutiny.core.buffer.Buffer; +import io.vertx.mutiny.ext.web.client.HttpRequest; + +@ApplicationScoped +@Unremovable +@OidcEndpoint(value = Type.DISCOVERY) <1> +public class OidcDiscoveryRequestCustomizer implements OidcRequestFilter { + + @Override + public void filter(HttpRequest request, Buffer buffer, OidcRequestContextProperties contextProps) { + request.putHeader("Discovery", "OK"); + } +} +---- +<1> Restrict this filter to requests targeting the OIDC discovery endpoint only. ==== Redirecting to and from the OIDC provider diff --git a/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc b/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc index 692575db26ef7..c547625a21a67 100644 --- a/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc +++ b/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc @@ -872,10 +872,10 @@ quarkus.log.category."io.quarkus.oidc.client.runtime.OidcClientRecorder".level=T quarkus.log.category."io.quarkus.oidc.client.runtime.OidcClientRecorder".min-level=TRACE ---- -[[oidc-client-filters]] -== OIDC request customization +[[oidc-request-filters]] +== OIDC request filters -You can customize OIDC requests made by Quarkus to the OIDC provider by registering one or more `OidcRequestFiler` implementations which can update or add new request headers, for example, a filter can analyze the request body and add its digest as a new header value: +You can filter OIDC requests made by Quarkus to the OIDC provider by registering one or more `OidcRequestFiler` implementations which can update or add new request headers, for example, a filter can analyze the request body and add its digest as a new header value: [source,java] ---- diff --git a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java index 0f0252a3a003e..8dcf143c6cadb 100644 --- a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java +++ b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientImpl.java @@ -17,6 +17,8 @@ import io.quarkus.oidc.client.OidcClientConfig; import io.quarkus.oidc.client.OidcClientException; import io.quarkus.oidc.client.Tokens; +import io.quarkus.oidc.common.OidcEndpoint; +import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonUtils; import io.quarkus.oidc.common.runtime.OidcConstants; @@ -46,12 +48,12 @@ public class OidcClientImpl implements OidcClient { private final String clientSecretBasicAuthScheme; private final Key clientJwtKey; private final OidcClientConfig oidcConfig; - private final List filters; + private final Map> filters; private volatile boolean closed; public OidcClientImpl(WebClient client, String tokenRequestUri, String tokenRevokeUri, String grantType, MultiMap tokenGrantParams, MultiMap commonRefreshGrantParams, OidcClientConfig oidcClientConfig, - List filters) { + Map> filters) { this.client = client; this.tokenRequestUri = tokenRequestUri; this.tokenRevokeUri = tokenRevokeUri; @@ -71,7 +73,7 @@ public Uni getTokens(Map additionalGrantParameters) { throw new OidcClientException( "Only 'refresh_token' grant is supported, please call OidcClient#refreshTokens method instead"); } - return getJsonResponse(tokenGrantParams, additionalGrantParameters, false); + return getJsonResponse(OidcEndpoint.Type.TOKEN, tokenGrantParams, additionalGrantParameters, false); } @Override @@ -82,7 +84,7 @@ public Uni refreshTokens(String refreshToken, Map additi } MultiMap refreshGrantParams = copyMultiMap(commonRefreshGrantParams); refreshGrantParams.add(OidcConstants.REFRESH_TOKEN_VALUE, refreshToken); - return getJsonResponse(refreshGrantParams, additionalGrantParameters, true); + return getJsonResponse(OidcEndpoint.Type.TOKEN, refreshGrantParams, additionalGrantParameters, true); } @Override @@ -94,7 +96,8 @@ public Uni revokeAccessToken(String accessToken, Map ad if (tokenRevokeUri != null) { MultiMap tokenRevokeParams = new MultiMap(io.vertx.core.MultiMap.caseInsensitiveMultiMap()); tokenRevokeParams.set(OidcConstants.REVOCATION_TOKEN, accessToken); - return postRequest(client.postAbs(tokenRevokeUri), tokenRevokeParams, additionalParameters, false) + return postRequest(OidcEndpoint.Type.TOKEN_REVOCATION, client.postAbs(tokenRevokeUri), tokenRevokeParams, + additionalParameters, false) .transform(resp -> toRevokeResponse(resp)); } else { LOG.debugf("%s OidcClient can not revoke the access token because the revocation endpoint URL is not set"); @@ -111,20 +114,23 @@ private Boolean toRevokeResponse(HttpResponse resp) { return resp.statusCode() == 503 ? false : true; } - private Uni getJsonResponse(MultiMap formBody, Map additionalGrantParameters, boolean refresh) { + private Uni getJsonResponse(OidcEndpoint.Type endpointType, MultiMap formBody, + Map additionalGrantParameters, + boolean refresh) { //Uni needs to be lazy by default, we don't send the request unless //something has subscribed to it. This is important for the CAS state //management in TokensHelper return Uni.createFrom().deferred(new Supplier>() { @Override public Uni get() { - return postRequest(client.postAbs(tokenRequestUri), formBody, additionalGrantParameters, refresh) + return postRequest(endpointType, client.postAbs(tokenRequestUri), formBody, additionalGrantParameters, refresh) .transform(resp -> emitGrantTokens(resp, refresh)); } }); } - private UniOnItem> postRequest(HttpRequest request, MultiMap formBody, + private UniOnItem> postRequest(OidcEndpoint.Type endpointType, HttpRequest request, + MultiMap formBody, Map additionalGrantParameters, boolean refresh) { MultiMap body = formBody; @@ -165,7 +171,7 @@ private UniOnItem> postRequest(HttpRequest request, } // Retry up to three times with a one-second delay between the retries if the connection is closed Buffer buffer = OidcCommonUtils.encodeForm(body); - Uni> response = filter(request, buffer).sendBuffer(buffer) + Uni> response = filter(endpointType, request, buffer).sendBuffer(buffer) .onFailure(ConnectException.class) .retry() .atMost(oidcConfig.connectionRetryCount) @@ -259,9 +265,12 @@ private void checkClosed() { } } - private HttpRequest filter(HttpRequest request, Buffer body) { - for (OidcRequestFilter filter : filters) { - filter.filter(request, body, null); + private HttpRequest filter(OidcEndpoint.Type endpointType, HttpRequest request, Buffer body) { + if (!filters.isEmpty()) { + OidcRequestContextProperties props = new OidcRequestContextProperties(); + for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) { + filter.filter(request, body, props); + } } return request; } diff --git a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java index 88004463f2e5f..cff9f35a930cc 100644 --- a/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java +++ b/extensions/oidc-client/runtime/src/main/java/io/quarkus/oidc/client/runtime/OidcClientRecorder.java @@ -17,6 +17,7 @@ import io.quarkus.oidc.client.OidcClientException; import io.quarkus.oidc.client.OidcClients; import io.quarkus.oidc.client.Tokens; +import io.quarkus.oidc.common.OidcEndpoint; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonUtils; import io.quarkus.oidc.common.runtime.OidcConstants; @@ -122,7 +123,7 @@ protected static Uni createOidcClientUni(OidcClientConfig oidcConfig WebClient client = WebClient.create(new io.vertx.mutiny.core.Vertx(vertx.get()), options); - List clientRequestFilters = OidcCommonUtils.getClientRequestCustomizer(); + Map> oidcRequestFilters = OidcCommonUtils.getOidcRequestFilters(); Uni tokenUrisUni = null; if (OidcCommonUtils.isAbsoluteUrl(oidcConfig.tokenPath)) { @@ -137,7 +138,7 @@ protected static Uni createOidcClientUni(OidcClientConfig oidcConfig OidcCommonUtils.getOidcEndpointUrl(authServerUriString, oidcConfig.tokenPath), OidcCommonUtils.getOidcEndpointUrl(authServerUriString, oidcConfig.revokePath))); } else { - tokenUrisUni = discoverTokenUris(client, clientRequestFilters, authServerUriString.toString(), oidcConfig); + tokenUrisUni = discoverTokenUris(client, oidcRequestFilters, authServerUriString.toString(), oidcConfig); } } return tokenUrisUni.onItemOrFailure() @@ -193,7 +194,7 @@ public OidcClient apply(OidcConfigurationMetadata metadata, Throwable t) { tokenGrantParams, commonRefreshGrantParams, oidcConfig, - clientRequestFilters); + oidcRequestFilters); } }); @@ -211,10 +212,10 @@ private static void setGrantClientParams(OidcClientConfig oidcConfig, MultiMap g } private static Uni discoverTokenUris(WebClient client, - List clientRequestFilters, + Map> oidcRequestFilters, String authServerUrl, OidcClientConfig oidcConfig) { final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig); - return OidcCommonUtils.discoverMetadata(client, clientRequestFilters, authServerUrl, connectionDelayInMillisecs) + return OidcCommonUtils.discoverMetadata(client, oidcRequestFilters, authServerUrl, connectionDelayInMillisecs) .onItem().transform(json -> new OidcConfigurationMetadata(json.getString("token_endpoint"), json.getString("revocation_endpoint"))); } diff --git a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcEndpoint.java b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcEndpoint.java new file mode 100644 index 0000000000000..2707f8f3bb09c --- /dev/null +++ b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcEndpoint.java @@ -0,0 +1,52 @@ +package io.quarkus.oidc.common; + +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +/** + * Annotation that can be used to restrict {@link OidcRequestFilter} to specific OIDC endpoints + */ +@Target({ TYPE }) +@Retention(RUNTIME) +public @interface OidcEndpoint { + + enum Type { + ALL, + + /** + * Applies to OIDC discovery requests + */ + DISCOVERY, + + /** + * Applies to OIDC token endpoint requests + */ + TOKEN, + + /** + * Applies to OIDC token revocation endpoint requests + */ + TOKEN_REVOCATION, + + /** + * Applies to OIDC token introspection requests + */ + INTROSPECTION, + /** + * Applies to OIDC JSON Web Key Set endpoint requests + */ + JWKS, + /** + * Applies to OIDC UserInfo endpoint requests + */ + USERINFO + } + + /** + * Identifies an OIDC tenant to which a given feature applies. + */ + Type value() default Type.ALL; +} diff --git a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java index d7a1f620a48af..e5dee80db7fe3 100644 --- a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java +++ b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestContextProperties.java @@ -1,20 +1,28 @@ package io.quarkus.oidc.common; +import java.util.Collections; import java.util.Map; public class OidcRequestContextProperties { public static String TOKEN = "token"; public static String TOKEN_CREDENTIAL = "token_credential"; + public static String DISCOVERY_ENDPOINT = "discovery_endpoint"; private final Map properties; + public OidcRequestContextProperties() { + this(Map.of()); + } + public OidcRequestContextProperties(Map properties) { this.properties = properties; } - public Object get(String name) { - return properties.get(name); + public T get(String name) { + @SuppressWarnings("unchecked") + T value = (T) properties.get(name); + return value; } public String getString(String name) { @@ -25,4 +33,8 @@ public T get(String name, Class type) { return type.cast(get(name)); } + public Map getAll() { + return Collections.unmodifiableMap(properties); + } + } diff --git a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestFilter.java b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestFilter.java index 7318f34eff3b1..93834a53fb41e 100644 --- a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestFilter.java +++ b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/OidcRequestFilter.java @@ -5,15 +5,18 @@ /** * Request filter which can be used to customize requests such as the verification JsonWebKey set and token grant requests - * which are made from the OIDC adapter to the OIDC provider + * which are made from the OIDC adapter to the OIDC provider. + *

+ * Filter can be restricted to a specific OIDC endpoint with a {@link OidcEndpoint} annotation. */ public interface OidcRequestFilter { + /** * Filter OIDC requests * * @param request HTTP request that can have its headers customized * @param body request body, will be null for HTTP GET methods, may be null for other HTTP methods - * @param contextProperties context properties that can be available in context of some requests, can be null + * @param contextProperties context properties that can be available in context of some requests */ void filter(HttpRequest request, Buffer requestBody, OidcRequestContextProperties contextProperties); } diff --git a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java index 3b0e9dbdc8f76..2bb75cba468df 100644 --- a/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java +++ b/extensions/oidc-common/runtime/src/main/java/io/quarkus/oidc/common/runtime/OidcCommonUtils.java @@ -13,7 +13,9 @@ import java.security.KeyStore; import java.security.PrivateKey; import java.time.Duration; +import java.util.ArrayList; import java.util.Base64; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -30,6 +32,8 @@ import io.quarkus.arc.ArcContainer; import io.quarkus.credentials.CredentialsProvider; import io.quarkus.credentials.runtime.CredentialsProviderFinder; +import io.quarkus.oidc.common.OidcEndpoint; +import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonConfig.Credentials; import io.quarkus.oidc.common.runtime.OidcCommonConfig.Credentials.Provider; @@ -427,12 +431,16 @@ public static Predicate oidcEndpointNotAvailable() { || (t instanceof OidcEndpointAccessException && ((OidcEndpointAccessException) t).getErrorStatus() == 404)); } - public static Uni discoverMetadata(WebClient client, List filters, + public static Uni discoverMetadata(WebClient client, Map> filters, String authServerUrl, long connectionDelayInMillisecs) { - final String discoveryUrl = authServerUrl + OidcConstants.WELL_KNOWN_CONFIGURATION; + final String discoveryUrl = getDiscoveryUri(authServerUrl); HttpRequest request = client.getAbs(discoveryUrl); - for (OidcRequestFilter filter : filters) { - filter.filter(request, null, null); + if (!filters.isEmpty()) { + OidcRequestContextProperties requestProps = new OidcRequestContextProperties( + Map.of(OidcRequestContextProperties.DISCOVERY_ENDPOINT, discoveryUrl)); + for (OidcRequestFilter filter : getMatchingOidcRequestFilters(filters, OidcEndpoint.Type.DISCOVERY)) { + filter.filter(request, null, requestProps); + } } return request.send().onItem().transform(resp -> { if (resp.statusCode() == 200) { @@ -452,6 +460,10 @@ public static Uni discoverMetadata(WebClient client, List getClientRequestCustomizer() { + public static Map> getOidcRequestFilters() { ArcContainer container = Arc.container(); if (container != null) { - return container.listAll(OidcRequestFilter.class).stream().map(handle -> handle.get()) - .collect(Collectors.toList()); + Map> map = new HashMap<>(); + for (OidcRequestFilter filter : container.listAll(OidcRequestFilter.class).stream().map(handle -> handle.get()) + .collect(Collectors.toList())) { + OidcEndpoint endpoint = filter.getClass().getAnnotation(OidcEndpoint.class); + OidcEndpoint.Type type = endpoint != null ? endpoint.value() : OidcEndpoint.Type.ALL; + map.computeIfAbsent(type, k -> new ArrayList()).add(filter); + } + return map; + } + return Map.of(); + } + + public static List getMatchingOidcRequestFilters(Map> filters, + OidcEndpoint.Type type) { + List typeSpecific = filters.get(type); + List all = filters.get(OidcEndpoint.Type.ALL); + if (typeSpecific == null && all == null) { + return List.of(); } - return List.of(); + if (typeSpecific != null && all == null) { + return typeSpecific; + } else if (typeSpecific == null && all != null) { + return all; + } else { + List combined = new ArrayList<>(typeSpecific.size() + all.size()); + combined.addAll(typeSpecific); + combined.addAll(all); + return combined; + } + } } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcConfigurationMetadata.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcConfigurationMetadata.java index 4e7795a802431..aee8379f99360 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcConfigurationMetadata.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/OidcConfigurationMetadata.java @@ -17,6 +17,7 @@ public class OidcConfigurationMetadata { private static final String END_SESSION_ENDPOINT = "end_session_endpoint"; private static final String SCOPES_SUPPORTED = "scopes_supported"; + private final String discoveryUri; private final String tokenUri; private final String introspectionUri; private final String authorizationUri; @@ -33,6 +34,7 @@ public OidcConfigurationMetadata(String tokenUri, String userInfoUri, String endSessionUri, String issuer) { + this.discoveryUri = null; this.tokenUri = tokenUri; this.introspectionUri = introspectionUri; this.authorizationUri = authorizationUri; @@ -44,10 +46,12 @@ public OidcConfigurationMetadata(String tokenUri, } public OidcConfigurationMetadata(JsonObject wellKnownConfig) { - this(wellKnownConfig, null); + this(wellKnownConfig, null, null); } - public OidcConfigurationMetadata(JsonObject wellKnownConfig, OidcConfigurationMetadata localMetadataConfig) { + public OidcConfigurationMetadata(JsonObject wellKnownConfig, OidcConfigurationMetadata localMetadataConfig, + String discoveryUri) { + this.discoveryUri = discoveryUri; this.tokenUri = getMetadataValue(wellKnownConfig, TOKEN_ENDPOINT, localMetadataConfig == null ? null : localMetadataConfig.tokenUri); this.introspectionUri = getMetadataValue(wellKnownConfig, INTROSPECTION_ENDPOINT, @@ -69,6 +73,10 @@ private static String getMetadataValue(JsonObject wellKnownConfig, String proper return localValue != null ? localValue : wellKnownConfig.getString(propertyName); } + public String getDiscoveryUri() { + return discoveryUri; + } + public String getTokenUri() { return tokenUri; } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java index 204c38984259c..4aad502590622 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProviderClient.java @@ -4,6 +4,7 @@ import java.net.ConnectException; import java.nio.charset.StandardCharsets; import java.security.Key; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -15,6 +16,7 @@ import io.quarkus.oidc.OidcTenantConfig; import io.quarkus.oidc.TokenIntrospection; import io.quarkus.oidc.UserInfo; +import io.quarkus.oidc.common.OidcEndpoint; import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonUtils; @@ -48,13 +50,13 @@ public class OidcProviderClient implements Closeable { private final String clientSecretBasicAuthScheme; private final String introspectionBasicAuthScheme; private final Key clientJwtKey; - private final List filters; + private final Map> filters; public OidcProviderClient(WebClient client, Vertx vertx, OidcConfigurationMetadata metadata, OidcTenantConfig oidcConfig, - List filters) { + Map> filters) { this.client = client; this.vertx = vertx; this.metadata = metadata; @@ -80,13 +82,14 @@ public OidcConfigurationMetadata getMetadata() { } public Uni getJsonWebKeySet(OidcRequestContextProperties contextProperties) { - return filter(client.getAbs(metadata.getJsonWebKeySetUri()), null, contextProperties).send().onItem() + return filter(OidcEndpoint.Type.JWKS, client.getAbs(metadata.getJsonWebKeySetUri()), null, contextProperties).send() + .onItem() .transform(resp -> getJsonWebKeySet(resp)); } public Uni getUserInfo(String token) { LOG.debugf("Get UserInfo on: %s auth: %s", metadata.getUserInfoUri(), OidcConstants.BEARER_SCHEME + " " + token); - return filter(client.getAbs(metadata.getUserInfoUri()), null, null) + return filter(OidcEndpoint.Type.USERINFO, client.getAbs(metadata.getUserInfoUri()), null, null) .putHeader(AUTHORIZATION_HEADER, OidcConstants.BEARER_SCHEME + " " + token) .send().onItem().transform(resp -> getUserInfo(resp)); } @@ -168,7 +171,9 @@ private UniOnItem> getHttpResponse(String uri, MultiMap for LOG.debugf("Get token on: %s params: %s headers: %s", metadata.getTokenUri(), formBody, request.headers()); // Retry up to three times with a one-second delay between the retries if the connection is closed. Buffer buffer = OidcCommonUtils.encodeForm(formBody); - Uni> response = filter(request, buffer, null).sendBuffer(buffer) + + OidcEndpoint.Type endpoint = introspect ? OidcEndpoint.Type.INTROSPECTION : OidcEndpoint.Type.TOKEN; + Uni> response = filter(endpoint, request, buffer, null).sendBuffer(buffer) .onFailure(ConnectException.class) .retry() .atMost(oidcConfig.connectionRetryCount).onFailure().transform(t -> t.getCause()); @@ -224,10 +229,16 @@ public Key getClientJwtKey() { return clientJwtKey; } - private HttpRequest filter(HttpRequest request, Buffer body, + private HttpRequest filter(OidcEndpoint.Type endpointType, HttpRequest request, Buffer body, OidcRequestContextProperties contextProperties) { - for (OidcRequestFilter filter : filters) { - filter.filter(request, body, contextProperties); + if (!filters.isEmpty()) { + Map newProperties = contextProperties == null ? new HashMap<>() + : new HashMap<>(contextProperties.getAll()); + newProperties.put(OidcConfigurationMetadata.class.getName(), metadata); + OidcRequestContextProperties newContextProperties = new OidcRequestContextProperties(newProperties); + for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) { + filter.filter(request, body, newContextProperties); + } } return request; } diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java index 50264f617dfd5..542920ff741b7 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcRecorder.java @@ -30,6 +30,7 @@ import io.quarkus.oidc.OidcTenantConfig.TokenStateManager.Strategy; import io.quarkus.oidc.TenantConfigResolver; import io.quarkus.oidc.TenantIdentityProvider; +import io.quarkus.oidc.common.OidcEndpoint; import io.quarkus.oidc.common.OidcRequestFilter; import io.quarkus.oidc.common.runtime.OidcCommonConfig; import io.quarkus.oidc.common.runtime.OidcCommonUtils; @@ -433,7 +434,7 @@ protected static Uni createOidcClientUni(OidcTenantConfig oi WebClient client = WebClient.create(new io.vertx.mutiny.core.Vertx(vertx), options); - List clientRequestFilters = OidcCommonUtils.getClientRequestCustomizer(); + Map> oidcRequestFilters = OidcCommonUtils.getOidcRequestFilters(); Uni metadataUni = null; if (!oidcConfig.discoveryEnabled.orElse(true)) { @@ -441,12 +442,13 @@ protected static Uni createOidcClientUni(OidcTenantConfig oi } else { final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig); metadataUni = OidcCommonUtils - .discoverMetadata(client, clientRequestFilters, authServerUriString, connectionDelayInMillisecs) + .discoverMetadata(client, oidcRequestFilters, authServerUriString, connectionDelayInMillisecs) .onItem() .transform(new Function() { @Override public OidcConfigurationMetadata apply(JsonObject json) { - return new OidcConfigurationMetadata(json, createLocalMetadata(oidcConfig, authServerUriString)); + return new OidcConfigurationMetadata(json, createLocalMetadata(oidcConfig, authServerUriString), + OidcCommonUtils.getDiscoveryUri(authServerUriString)); } }); } @@ -478,7 +480,7 @@ public Uni apply(OidcConfigurationMetadata metadata, Throwab + " Use 'quarkus.oidc.user-info-path' if the discovery is disabled.")); } return Uni.createFrom() - .item(new OidcProviderClient(client, vertx, metadata, oidcConfig, clientRequestFilters)); + .item(new OidcProviderClient(client, vertx, metadata, oidcConfig, oidcRequestFilters)); } }); diff --git a/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java b/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java index b0d0f2282c034..513b1584b3a91 100644 --- a/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java +++ b/integration-tests/oidc-client-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java @@ -3,6 +3,8 @@ import jakarta.enterprise.context.ApplicationScoped; import io.quarkus.arc.Unremovable; +import io.quarkus.oidc.common.OidcEndpoint; +import io.quarkus.oidc.common.OidcEndpoint.Type; import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; import io.vertx.mutiny.core.buffer.Buffer; @@ -10,6 +12,7 @@ @ApplicationScoped @Unremovable +@OidcEndpoint(value = Type.TOKEN) public class OidcRequestCustomizer implements OidcRequestFilter { @Override diff --git a/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcDiscoveryRequestCustomizer.java b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcDiscoveryRequestCustomizer.java new file mode 100644 index 0000000000000..069450ce91ec2 --- /dev/null +++ b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcDiscoveryRequestCustomizer.java @@ -0,0 +1,22 @@ +package io.quarkus.it.keycloak; + +import jakarta.enterprise.context.ApplicationScoped; + +import io.quarkus.arc.Unremovable; +import io.quarkus.oidc.common.OidcEndpoint; +import io.quarkus.oidc.common.OidcEndpoint.Type; +import io.quarkus.oidc.common.OidcRequestContextProperties; +import io.quarkus.oidc.common.OidcRequestFilter; +import io.vertx.mutiny.core.buffer.Buffer; +import io.vertx.mutiny.ext.web.client.HttpRequest; + +@ApplicationScoped +@Unremovable +@OidcEndpoint(value = Type.DISCOVERY) +public class OidcDiscoveryRequestCustomizer implements OidcRequestFilter { + + @Override + public void filter(HttpRequest request, Buffer buffer, OidcRequestContextProperties contextProps) { + request.putHeader("Discovery", "OK"); + } +} diff --git a/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcJwksRequestCustomizer.java similarity index 74% rename from integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java rename to integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcJwksRequestCustomizer.java index 0f76995ecd0ed..2b6080cdfef15 100644 --- a/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcRequestCustomizer.java +++ b/integration-tests/oidc-wiremock/src/main/java/io/quarkus/it/keycloak/OidcJwksRequestCustomizer.java @@ -4,22 +4,25 @@ import io.quarkus.arc.Unremovable; import io.quarkus.oidc.AccessTokenCredential; +import io.quarkus.oidc.OidcConfigurationMetadata; import io.quarkus.oidc.common.OidcRequestContextProperties; import io.quarkus.oidc.common.OidcRequestFilter; -import io.vertx.core.http.HttpMethod; import io.vertx.mutiny.core.buffer.Buffer; import io.vertx.mutiny.ext.web.client.HttpRequest; @ApplicationScoped @Unremovable -public class OidcRequestCustomizer implements OidcRequestFilter { +// Or @OidcEndpoint(value = Type.JWKS) +public class OidcJwksRequestCustomizer implements OidcRequestFilter { @Override public void filter(HttpRequest request, Buffer buffer, OidcRequestContextProperties contextProps) { - HttpMethod method = request.method(); + OidcConfigurationMetadata metadata = contextProps.get(OidcConfigurationMetadata.class.getName()); + // There are many tenants in the test so the URI check is still required String uri = request.uri(); - if (method == HttpMethod.GET && uri.endsWith("/auth/azure/jwk")) { - String token = contextProps.getString(OidcRequestContextProperties.TOKEN); + if (uri.equals("/auth/azure/jwk") && + metadata.getJsonWebKeySetUri().endsWith(uri)) { + String token = contextProps.get(OidcRequestContextProperties.TOKEN); AccessTokenCredential tokenCred = contextProps.get(OidcRequestContextProperties.TOKEN_CREDENTIAL, AccessTokenCredential.class); // or @@ -37,5 +40,4 @@ public void filter(HttpRequest request, Buffer buffer, OidcRequestContex } } } - } diff --git a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java index 31f834d31d05b..0c7e47985a5a4 100644 --- a/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java +++ b/integration-tests/oidc-wiremock/src/test/java/io/quarkus/it/keycloak/WiremockTestResource.java @@ -1,6 +1,7 @@ package io.quarkus.it.keycloak; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; import static com.github.tomakehurst.wiremock.client.WireMock.get; import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; @@ -24,6 +25,7 @@ public void start() { server.stubFor( get(urlEqualTo("/auth/realms/quarkus2/.well-known/openid-configuration")) + .withHeader("Discovery", equalTo("OK")) .willReturn(aResponse() .withHeader("Content-Type", "application/json") .withBody("{\n" +