From 31926182206d8b821b470944ead628308d7ff0a3 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Wed, 2 Nov 2022 15:34:54 -0600 Subject: [PATCH] Add authenticationFailureHandler - To ServerHttpSecurity#httpBasic - To ServerHttpSecurity#oauthResourceServer Closes gh-12132 --- .../config/web/server/ServerHttpSecurity.java | 43 +++++++++++++--- .../config/web/server/ServerHttpBasicDsl.kt | 3 ++ .../server/ServerOAuth2ResourceServerDsl.kt | 3 ++ .../server/OAuth2ResourceServerSpecTests.java | 50 +++++++++++++++++++ .../web/server/ServerHttpSecurityTests.java | 23 +++++++++ .../web/server/ServerHttpBasicDslTests.kt | 40 ++++++++++++++- .../ServerOAuth2ResourceServerDslTests.kt | 50 +++++++++++++++++-- 7 files changed, 199 insertions(+), 13 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 3da6c116d9c..3316fc8ffa4 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -2023,6 +2023,8 @@ public final class HttpBasicSpec { private ServerAuthenticationEntryPoint entryPoint; + private ServerAuthenticationFailureHandler authenticationFailureHandler; + private HttpBasicSpec() { List entryPoints = new ArrayList<>(); entryPoints @@ -2071,6 +2073,13 @@ public HttpBasicSpec authenticationEntryPoint(ServerAuthenticationEntryPoint aut return this; } + public HttpBasicSpec authenticationFailureHandler( + ServerAuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + return this; + } + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring @@ -2102,13 +2111,19 @@ protected void configure(ServerHttpSecurity http) { Arrays.asList(this.xhrMatcher, restNotHtmlMatcher)); ServerHttpSecurity.this.defaultEntryPoints.add(new DelegateEntry(preferredMatcher, this.entryPoint)); AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(this.authenticationManager); - authenticationFilter - .setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint)); + authenticationFilter.setAuthenticationFailureHandler(authenticationFailureHandler()); authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter()); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC); } + private ServerAuthenticationFailureHandler authenticationFailureHandler() { + if (this.authenticationFailureHandler != null) { + return this.authenticationFailureHandler; + } + return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint); + } + } /** @@ -3996,6 +4011,8 @@ public class OAuth2ResourceServerSpec { private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint(); + private ServerAuthenticationFailureHandler authenticationFailureHandler; + private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler(); private ServerAuthenticationConverter bearerTokenConverter = new ServerBearerTokenAuthenticationConverter(); @@ -4038,6 +4055,12 @@ public OAuth2ResourceServerSpec authenticationEntryPoint(ServerAuthenticationEnt return this; } + public OAuth2ResourceServerSpec authenticationFailureHandler( + ServerAuthenticationFailureHandler authenticationFailureHandler) { + this.authenticationFailureHandler = authenticationFailureHandler; + return this; + } + /** * Configures the {@link ServerAuthenticationConverter} to use for requests * authenticating with @@ -4127,8 +4150,7 @@ protected void configure(ServerHttpSecurity http) { if (this.authenticationManagerResolver != null) { AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(this.authenticationManagerResolver); oauth2.setServerAuthenticationConverter(this.bearerTokenConverter); - oauth2.setAuthenticationFailureHandler( - new ServerAuthenticationEntryPointFailureHandler(this.entryPoint)); + oauth2.setAuthenticationFailureHandler(authenticationFailureHandler()); http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); } else if (this.jwt != null) { @@ -4181,6 +4203,13 @@ private void registerDefaultCsrfOverride(ServerHttpSecurity http) { } } + private ServerAuthenticationFailureHandler authenticationFailureHandler() { + if (this.authenticationFailureHandler != null) { + return this.authenticationFailureHandler; + } + return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint); + } + public ServerHttpSecurity and() { return ServerHttpSecurity.this; } @@ -4262,8 +4291,7 @@ protected void configure(ServerHttpSecurity http) { ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter); - oauth2.setAuthenticationFailureHandler( - new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint)); + oauth2.setAuthenticationFailureHandler(authenticationFailureHandler()); http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); } @@ -4398,8 +4426,7 @@ protected void configure(ServerHttpSecurity http) { ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter); - oauth2.setAuthenticationFailureHandler( - new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint)); + oauth2.setAuthenticationFailureHandler(authenticationFailureHandler()); http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION); } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt index 91b157c2644..7aa73ff0ed3 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt @@ -21,6 +21,7 @@ import org.springframework.security.core.Authentication import org.springframework.security.core.context.SecurityContext import org.springframework.security.web.authentication.www.BasicAuthenticationFilter import org.springframework.security.web.server.ServerAuthenticationEntryPoint +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.context.ReactorContextWebFilter import org.springframework.security.web.server.context.ServerSecurityContextRepository @@ -42,6 +43,7 @@ import org.springframework.security.web.server.context.ServerSecurityContextRepo class ServerHttpBasicDsl { var authenticationManager: ReactiveAuthenticationManager? = null var securityContextRepository: ServerSecurityContextRepository? = null + var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null private var disabled = false @@ -57,6 +59,7 @@ class ServerHttpBasicDsl { return { httpBasic -> authenticationManager?.also { httpBasic.authenticationManager(authenticationManager) } securityContextRepository?.also { httpBasic.securityContextRepository(securityContextRepository) } + authenticationFailureHandler?.also { httpBasic.authenticationFailureHandler(authenticationFailureHandler) } authenticationEntryPoint?.also { httpBasic.authenticationEntryPoint(authenticationEntryPoint) } if (disabled) { httpBasic.disable() diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt index ee48923469b..66dac8ec158 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt @@ -19,6 +19,7 @@ package org.springframework.security.config.web.server import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver import org.springframework.security.web.server.ServerAuthenticationEntryPoint import org.springframework.security.web.server.authentication.ServerAuthenticationConverter +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler import org.springframework.web.server.ServerWebExchange @@ -38,6 +39,7 @@ import org.springframework.web.server.ServerWebExchange @ServerSecurityMarker class ServerOAuth2ResourceServerDsl { var accessDeniedHandler: ServerAccessDeniedHandler? = null + var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null var bearerTokenConverter: ServerAuthenticationConverter? = null var authenticationManagerResolver: ReactiveAuthenticationManagerResolver? = null @@ -107,6 +109,7 @@ class ServerOAuth2ResourceServerDsl { internal fun get(): (ServerHttpSecurity.OAuth2ResourceServerSpec) -> Unit { return { oauth2ResourceServer -> accessDeniedHandler?.also { oauth2ResourceServer.accessDeniedHandler(accessDeniedHandler) } + authenticationFailureHandler?.also { oauth2ResourceServer.authenticationFailureHandler(authenticationFailureHandler) } authenticationEntryPoint?.also { oauth2ResourceServer.authenticationEntryPoint(authenticationEntryPoint) } bearerTokenConverter?.also { oauth2ResourceServer.bearerTokenConverter(bearerTokenConverter) } authenticationManagerResolver?.also { oauth2ResourceServer.authenticationManagerResolver(authenticationManagerResolver!!) } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java index 2076ca28dec..51be0c1dda0 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java @@ -51,6 +51,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver; import org.springframework.security.authentication.TestingAuthenticationToken; @@ -73,6 +74,7 @@ import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.web.reactive.server.WebTestClient; @@ -348,6 +350,25 @@ public void getWhenUsingCustomAuthenticationManagerResolverThenUsesItAccordingly // @formatter:on } + @Test + public void getWhenUsingCustomAuthenticationFailureHandlerThenUsesIsAccordingly() { + this.spring.register(CustomAuthenticationFailureHandlerConfig.class).autowire(); + ServerAuthenticationFailureHandler handler = this.spring.getContext() + .getBean(ServerAuthenticationFailureHandler.class); + ReactiveAuthenticationManager authenticationManager = this.spring.getContext() + .getBean(ReactiveAuthenticationManager.class); + given(authenticationManager.authenticate(any())) + .willReturn(Mono.error(() -> new BadCredentialsException("bad"))); + given(handler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty()); + // @formatter:off + this.client.get() + .headers((headers) -> headers.setBearerAuth(this.messageReadToken)) + .exchange() + .expectStatus().isOk(); + // @formatter:on + verify(handler).onAuthenticationFailure(any(), any()); + } + @Test public void postWhenSignedThenReturnsOk() { this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); @@ -893,6 +914,35 @@ ReactiveAuthenticationManager authenticationManager() { } + @EnableWebFlux + @EnableWebFluxSecurity + static class CustomAuthenticationFailureHandlerConfig { + + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + // @formatter:off + http + .authorizeExchange((authorize) -> authorize.anyExchange().authenticated()) + .oauth2ResourceServer((oauth2) -> oauth2 + .authenticationFailureHandler(authenticationFailureHandler()) + .jwt((jwt) -> jwt.authenticationManager(authenticationManager())) + ); + // @formatter:on + return http.build(); + } + + @Bean + ReactiveAuthenticationManager authenticationManager() { + return mock(ReactiveAuthenticationManager.class); + } + + @Bean + ServerAuthenticationFailureHandler authenticationFailureHandler() { + return mock(ServerAuthenticationFailureHandler.class); + } + + } + @EnableWebFlux @EnableWebFluxSecurity static class CustomBearerTokenServerAuthenticationConverter { diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 23660f14d9f..149ba09b299 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -35,6 +35,7 @@ import org.springframework.http.HttpStatus; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; @@ -57,6 +58,7 @@ import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint; +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; @@ -218,6 +220,27 @@ public void basicWhenXHRRequestThenUnauthorized() { verify(authenticationEntryPoint).commence(any(), any()); } + @Test + public void basicWhenCustomAuthenticationFailureHandlerThenUses() { + ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class); + ServerAuthenticationFailureHandler authenticationFailureHandler = mock( + ServerAuthenticationFailureHandler.class); + this.http.httpBasic().authenticationFailureHandler(authenticationFailureHandler); + this.http.httpBasic().authenticationManager(authenticationManager); + this.http.authorizeExchange().anyExchange().authenticated(); + given(authenticationManager.authenticate(any())) + .willReturn(Mono.error(() -> new BadCredentialsException("bad"))); + given(authenticationFailureHandler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty()); + WebTestClient client = buildClient(); + // @formatter:off + client.get().uri("/") + .headers((headers) -> headers.setBasicAuth("user", "password")) + .exchange() + .expectStatus().isOk(); + // @formatter:on + verify(authenticationFailureHandler).onAuthenticationFailure(any(), any()); + } + @Test public void buildWhenServerWebExchangeFromContextThenFound() { SecurityWebFilterChain filter = this.http.build(); diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt index 52a3fe3ec97..d4c3b1bce78 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt @@ -19,7 +19,6 @@ package org.springframework.security.config.web.server import io.mockk.every import io.mockk.mockkObject import io.mockk.verify -import java.util.Base64 import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired @@ -36,6 +35,7 @@ import org.springframework.security.core.userdetails.MapReactiveUserDetailsServi import org.springframework.security.core.userdetails.User import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.ServerAuthenticationEntryPoint +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.context.ServerSecurityContextRepository import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository import org.springframework.test.web.reactive.server.WebTestClient @@ -43,6 +43,7 @@ import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RestController import org.springframework.web.reactive.config.EnableWebFlux import reactor.core.publisher.Mono +import java.util.* /** * Tests for [ServerHttpBasicDsl] @@ -216,6 +217,43 @@ class ServerHttpBasicDslTests { } } + @Test + fun `http basic when custom authentication failure handler then failure handler used`() { + this.spring.register(CustomAuthenticationFailureHandlerConfig::class.java, UserDetailsConfig::class.java).autowire() + mockkObject(CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER) + every { + CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) + } returns Mono.empty() + + this.client.get() + .uri("/") + .header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:wrong".toByteArray())) + .exchange() + + verify(exactly = 1) { CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class CustomAuthenticationFailureHandlerConfig { + + companion object { + val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() } + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + authorizeExchange { + authorize(anyExchange, authenticated) + } + httpBasic { + authenticationFailureHandler = FAILURE_HANDLER + } + } + } + } + @Configuration open class UserDetailsConfig { @Bean diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt index 868226d971d..82b781765df 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt @@ -19,10 +19,6 @@ package org.springframework.security.config.web.server import io.mockk.every import io.mockk.mockkObject import io.mockk.verify -import java.math.BigInteger -import java.security.KeyFactory -import java.security.interfaces.RSAPublicKey -import java.security.spec.RSAPublicKeySpec import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.springframework.beans.factory.annotation.Autowired @@ -36,11 +32,16 @@ import org.springframework.security.config.test.SpringTestContextExtension import org.springframework.security.oauth2.server.resource.web.server.authentication.ServerBearerTokenAuthenticationConverter import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint +import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.server.ServerWebExchange import reactor.core.publisher.Mono +import java.math.BigInteger +import java.security.KeyFactory +import java.security.interfaces.RSAPublicKey +import java.security.spec.RSAPublicKeySpec /** * Tests for [ServerOAuth2ResourceServerDsl] @@ -125,6 +126,47 @@ class ServerOAuth2ResourceServerDslTests { } } + @Test + fun `http basic when custom authentication failure handler then failure handler used`() { + this.spring.register(AuthenticationFailureHandlerConfig::class.java).autowire() + mockkObject(AuthenticationFailureHandlerConfig.FAILURE_HANDLER) + every { + AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) + } returns Mono.empty() + + this.client.get() + .uri("/") + .header("Authorization", "Bearer token") + .exchange() + .expectStatus().isOk + + verify(exactly = 1) { AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class AuthenticationFailureHandlerConfig { + + companion object { + val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() } + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + authorizeExchange { + authorize(anyExchange, authenticated) + } + oauth2ResourceServer { + authenticationFailureHandler = FAILURE_HANDLER + jwt { + publicKey = publicKey() + } + } + } + } + } + @Test fun `request when custom bearer token converter configured then custom converter used`() { this.spring.register(BearerTokenConverterConfig::class.java).autowire()