Skip to content

Commit

Permalink
Polish "Support custom token validators for OAuth2"
Browse files Browse the repository at this point in the history
  • Loading branch information
wilkinsona committed Jul 5, 2023
1 parent 7500dab commit 4feaa28
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ static class JwtConfiguration {

private final OAuth2ResourceServerProperties.Jwt properties;

private final List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators;
private final List<OAuth2TokenValidator<Jwt>> additionalValidators;

JwtConfiguration(OAuth2ResourceServerProperties properties,
List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators) {
ObjectProvider<OAuth2TokenValidator<Jwt>> additionalValidators) {
this.properties = properties.getJwt();
this.customOAuth2TokenValidators = customOAuth2TokenValidators;
this.additionalValidators = additionalValidators.orderedStream().toList();
}

@Bean
Expand All @@ -102,17 +102,17 @@ private void jwsAlgorithms(Set<SignatureAlgorithm> signatureAlgorithms) {
}

private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) {
List<String> audiences = this.properties.getAudiences();
if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) {
return defaultValidator;
}
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
validators.add(defaultValidator);
validators.addAll(this.customOAuth2TokenValidators);
List<String> audiences = this.properties.getAudiences();
if (!CollectionUtils.isEmpty(audiences)) {
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
}
if (validators.size() == 1) {
return validators.get(0);
}
validators.addAll(this.additionalValidators);
return new DelegatingOAuth2TokenValidator<>(validators);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ static class JwtDecoderConfiguration {

private final OAuth2ResourceServerProperties.Jwt properties;

private final List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators;
private final List<OAuth2TokenValidator<Jwt>> additionalValidators;

JwtDecoderConfiguration(OAuth2ResourceServerProperties properties,
List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators) {
ObjectProvider<OAuth2TokenValidator<Jwt>> additionalValidators) {
this.properties = properties.getJwt();
this.customOAuth2TokenValidators = customOAuth2TokenValidators;
this.additionalValidators = additionalValidators.orderedStream().toList();
}

@Bean
Expand All @@ -102,17 +102,17 @@ private void jwsAlgorithms(Set<SignatureAlgorithm> signatureAlgorithms) {
}

private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) {
List<String> audiences = this.properties.getAudiences();
if (CollectionUtils.isEmpty(audiences) && this.additionalValidators.isEmpty()) {
return defaultValidator;
}
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
validators.add(defaultValidator);
validators.addAll(this.customOAuth2TokenValidators);
List<String> audiences = this.properties.getAudiences();
if (!CollectionUtils.isEmpty(audiences)) {
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
}
if (validators.size() == 1) {
return validators.get(0);
}
validators.addAll(this.additionalValidators);
return new DelegatingOAuth2TokenValidator<>(validators);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
package org.springframework.boot.autoconfigure.security.oauth2.resource.reactive;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.function.Consumer;
import java.util.stream.Stream;

import com.fasterxml.jackson.core.JsonProcessingException;
Expand All @@ -35,6 +36,7 @@
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.assertj.core.api.ThrowingConsumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
Expand Down Expand Up @@ -441,7 +443,6 @@ void autoConfigurationWhenIntrospectionUriAvailableShouldBeConditionalOnClass()
.run((context) -> assertThat(context).doesNotHaveBean(ReactiveOpaqueTokenIntrospector.class));
}

@SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureResourceServerUsingJwkSetUriAndIssuerUri() throws Exception {
this.server = new MockWebServer();
Expand All @@ -457,15 +458,11 @@ void autoConfigurationShouldConfigureResourceServerUsingJwkSetUriAndIssuerUri()
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(reactiveJwtDecoder, "jwtValidator");
Collection<OAuth2TokenValidator<Jwt>> tokenValidators = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
assertThat(tokenValidators).hasAtLeastOneElementOfType(JwtIssuerValidator.class);
validate(jwt().claim("iss", issuer), reactiveJwtDecoder,
(validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class));
});
}

@SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldNotConfigureIssuerUriAndAudienceJwtValidatorIfPropertyNotConfigured() throws Exception {
this.server = new MockWebServer();
Expand All @@ -479,13 +476,8 @@ void autoConfigurationShouldNotConfigureIssuerUriAndAudienceJwtValidatorIfProper
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(reactiveJwtDecoder, "jwtValidator");
Collection<OAuth2TokenValidator<Jwt>> tokenValidators = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
assertThat(tokenValidators).hasExactlyElementsOfTypes(JwtTimestampValidator.class);
assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtClaimValidator.class);
assertThat(tokenValidators).doesNotHaveAnyElementsOfTypes(JwtIssuerValidator.class);
validate(jwt(), reactiveJwtDecoder, (validators) -> assertThat(validators).singleElement()
.isInstanceOf(JwtTimestampValidator.class));
});
}

Expand All @@ -505,112 +497,82 @@ void autoConfigurationShouldConfigureIssuerAndAudienceJwtValidatorIfPropertyProv
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
validate(issuerUri, reactiveJwtDecoder, null);
validate(
jwt().claim("iss", URI.create(issuerUri).toURL())
.claim("aud", List.of("https://test-audience.com")),
reactiveJwtDecoder,
(validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)
.satisfiesOnlyOnce(audClaimValidator()));
});
}

@SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureAudienceAndCustomValidatorsIfPropertyProvidedAndIssuerUri() throws Exception {
void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
.withUserConfiguration(CustomTokenValidatorsConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
assertThat(context).hasBean("customJwtClaimValidator");
OAuth2TokenValidator<Jwt> customValidator = (OAuth2TokenValidator<Jwt>) context
.getBean("customJwtClaimValidator");
validate(issuerUri, reactiveJwtDecoder, customValidator);
SupplierReactiveJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierReactiveJwtDecoder.class);
Mono<ReactiveJwtDecoder> jwtDecoderSupplier = (Mono<ReactiveJwtDecoder>) ReflectionTestUtils
.getField(supplierJwtDecoderBean, "jwtDecoderMono");
ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block();
validate(
jwt().claim("iss", URI.create(issuerUri).toURL())
.claim("aud", List.of("https://test-audience.com")),
jwtDecoder,
(validators) -> assertThat(validators).hasAtLeastOneElementOfType(JwtIssuerValidator.class)
.satisfiesOnlyOnce(audClaimValidator()));
});
}

@SuppressWarnings("unchecked")
private void validate(String issuerUri, ReactiveJwtDecoder jwtDecoder, OAuth2TokenValidator<Jwt> customValidator)
throws MalformedURLException {
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com"));
if (issuerUri != null) {
builder.claim("iss", new URL(issuerUri));
}
if (customValidator != null) {
builder.claim("custom_claim", "custom_claim_value");
}
Jwt jwt = builder.build();
assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse();
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
validateDelegates(issuerUri, delegates, customValidator);
}

private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates,
OAuth2TokenValidator<Jwt> customValidator) {
assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class);
OAuth2TokenValidator<Jwt> delegatingValidator = delegates.stream()
.filter((v) -> v instanceof DelegatingOAuth2TokenValidator)
.findFirst()
.get();
if (issuerUri != null) {
assertThat(delegatingValidator).extracting("tokenValidators")
.asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class))
.hasAtLeastOneElementOfType(JwtIssuerValidator.class);
}
List<OAuth2TokenValidator<Jwt>> claimValidators = delegates.stream()
.filter((d) -> d instanceof JwtClaimValidator<?>)
.collect(Collectors.toList());
assertThat(claimValidators).anyMatch((v) -> "aud".equals(ReflectionTestUtils.getField(v, "claim")));
if (customValidator != null) {
assertThat(claimValidators)
.anyMatch((v) -> "custom_claim".equals(ReflectionTestUtils.getField(v, "claim")));
}
}

@SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception {
void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndPublicKey() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
.run((context) -> {
SupplierReactiveJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierReactiveJwtDecoder.class);
Mono<ReactiveJwtDecoder> jwtDecoderSupplier = (Mono<ReactiveJwtDecoder>) ReflectionTestUtils
.getField(supplierJwtDecoderBean, "jwtDecoderMono");
ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block();
validate(issuerUri, jwtDecoder, null);
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
validate(jwt().claim("aud", List.of("https://test-audience.com")), jwtDecoder,
(validators) -> assertThat(validators).satisfiesOnlyOnce(audClaimValidator()));
});
}

@SuppressWarnings("unchecked")
@Test
void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndPublicKey() throws Exception {
void autoConfigurationShouldConfigureCustomValidators() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri)
.withUserConfiguration(CustomJwtClaimValidatorConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
validate(null, jwtDecoder, null);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
OAuth2TokenValidator<Jwt> customValidator = (OAuth2TokenValidator<Jwt>) context
.getBean("customJwtClaimValidator");
validate(jwt().claim("iss", URI.create(issuerUri).toURL()).claim("custom_claim", "custom_claim_value"),
reactiveJwtDecoder, (validators) -> assertThat(validators).contains(customValidator)
.hasAtLeastOneElementOfType(JwtIssuerValidator.class));
});
}

Expand Down Expand Up @@ -640,6 +602,30 @@ void audienceValidatorWhenAudienceInvalid() throws Exception {
});
}

@SuppressWarnings("unchecked")
@Test
void customValidatorWhenInvalid() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri)
.withUserConfiguration(CustomJwtClaimValidatorConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
Jwt jwt = jwt().claim("iss", new URL(issuerUri)).claim("custom_claim", "invalid_value").build();
assertThat(jwtValidator.validate(jwt).hasErrors()).isTrue();
});
}

private void assertFilterConfiguredWithJwtAuthenticationManager(AssertableReactiveWebApplicationContext context) {
MatcherSecurityWebFilterChain filterChain = (MatcherSecurityWebFilterChain) context
.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
Expand Down Expand Up @@ -723,6 +709,37 @@ static Jwt.Builder jwt() {
.subject("mock-test-subject");
}

@SuppressWarnings("unchecked")
private void validate(Jwt.Builder builder, ReactiveJwtDecoder jwtDecoder,
ThrowingConsumer<List<OAuth2TokenValidator<Jwt>>> validatorsConsumer) {
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
assertThat(jwtValidator.validate(builder.build()).hasErrors()).isFalse();
validatorsConsumer.accept(extractValidators(jwtValidator));
}

@SuppressWarnings("unchecked")
private List<OAuth2TokenValidator<Jwt>> extractValidators(DelegatingOAuth2TokenValidator<Jwt> delegatingValidator) {
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(delegatingValidator, "tokenValidators");
List<OAuth2TokenValidator<Jwt>> extracted = new ArrayList<>();
for (OAuth2TokenValidator<Jwt> delegate : delegates) {
if (delegate instanceof DelegatingOAuth2TokenValidator<Jwt> delegatingDelegate) {
extracted.addAll(extractValidators(delegatingDelegate));
}
else {
extracted.add(delegate);
}
}
return extracted;
}

private Consumer<OAuth2TokenValidator<Jwt>> audClaimValidator() {
return (validator) -> assertThat(validator).isInstanceOf(JwtClaimValidator.class)
.extracting("claim")
.isEqualTo("aud");
}

@EnableWebFluxSecurity
static class TestConfig {

Expand Down Expand Up @@ -781,7 +798,7 @@ SecurityWebFilterChain testSpringSecurityFilterChain(ServerHttpSecurity http) {
}

@Configuration(proxyBeanMethods = false)
static class CustomTokenValidatorsConfig {
static class CustomJwtClaimValidatorConfig {

@Bean
JwtClaimValidator<String> customJwtClaimValidator() {
Expand Down
Loading

0 comments on commit 4feaa28

Please sign in to comment.