Skip to content

Commit

Permalink
Merge branch '5.8.x'
Browse files Browse the repository at this point in the history
Closes gh-12133
  • Loading branch information
jzheaux committed Nov 2, 2022
2 parents 983f1d4 + 3192618 commit fc8e20b
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 13 deletions.
Expand Up @@ -2023,6 +2023,8 @@ public final class HttpBasicSpec {

private ServerAuthenticationEntryPoint entryPoint;

private ServerAuthenticationFailureHandler authenticationFailureHandler;

private HttpBasicSpec() {
List<DelegateEntry> entryPoints = new ArrayList<>();
entryPoints
Expand Down Expand Up @@ -2071,6 +2073,13 @@ public HttpBasicSpec authenticationEntryPoint(ServerAuthenticationEntryPoint aut
return this;
}

public HttpBasicSpec authenticationFailureHandler(
ServerAuthenticationFailureHandler authenticationFailureHandler) {
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
this.authenticationFailureHandler = authenticationFailureHandler;
return this;
}

/**
* Allows method chaining to continue configuring the {@link ServerHttpSecurity}
* @return the {@link ServerHttpSecurity} to continue configuring
Expand Down Expand Up @@ -2102,13 +2111,19 @@ protected void configure(ServerHttpSecurity http) {
Arrays.asList(this.xhrMatcher, restNotHtmlMatcher));
ServerHttpSecurity.this.defaultEntryPoints.add(new DelegateEntry(preferredMatcher, this.entryPoint));
AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(this.authenticationManager);
authenticationFilter
.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
authenticationFilter.setAuthenticationFailureHandler(authenticationFailureHandler());
authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter());
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC);
}

private ServerAuthenticationFailureHandler authenticationFailureHandler() {
if (this.authenticationFailureHandler != null) {
return this.authenticationFailureHandler;
}
return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint);
}

}

/**
Expand Down Expand Up @@ -3996,6 +4011,8 @@ public class OAuth2ResourceServerSpec {

private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint();

private ServerAuthenticationFailureHandler authenticationFailureHandler;

private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler();

private ServerAuthenticationConverter bearerTokenConverter = new ServerBearerTokenAuthenticationConverter();
Expand Down Expand Up @@ -4038,6 +4055,12 @@ public OAuth2ResourceServerSpec authenticationEntryPoint(ServerAuthenticationEnt
return this;
}

public OAuth2ResourceServerSpec authenticationFailureHandler(
ServerAuthenticationFailureHandler authenticationFailureHandler) {
this.authenticationFailureHandler = authenticationFailureHandler;
return this;
}

/**
* Configures the {@link ServerAuthenticationConverter} to use for requests
* authenticating with
Expand Down Expand Up @@ -4127,8 +4150,7 @@ protected void configure(ServerHttpSecurity http) {
if (this.authenticationManagerResolver != null) {
AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(this.authenticationManagerResolver);
oauth2.setServerAuthenticationConverter(this.bearerTokenConverter);
oauth2.setAuthenticationFailureHandler(
new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
}
else if (this.jwt != null) {
Expand Down Expand Up @@ -4181,6 +4203,13 @@ private void registerDefaultCsrfOverride(ServerHttpSecurity http) {
}
}

private ServerAuthenticationFailureHandler authenticationFailureHandler() {
if (this.authenticationFailureHandler != null) {
return this.authenticationFailureHandler;
}
return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint);
}

public ServerHttpSecurity and() {
return ServerHttpSecurity.this;
}
Expand Down Expand Up @@ -4262,8 +4291,7 @@ protected void configure(ServerHttpSecurity http) {
ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter);
oauth2.setAuthenticationFailureHandler(
new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint));
oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
}

Expand Down Expand Up @@ -4398,8 +4426,7 @@ protected void configure(ServerHttpSecurity http) {
ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter);
oauth2.setAuthenticationFailureHandler(
new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint));
oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
}

Expand Down
Expand Up @@ -21,6 +21,7 @@ import org.springframework.security.core.Authentication
import org.springframework.security.core.context.SecurityContext
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter
import org.springframework.security.web.server.ServerAuthenticationEntryPoint
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
import org.springframework.security.web.server.context.ReactorContextWebFilter
import org.springframework.security.web.server.context.ServerSecurityContextRepository

Expand All @@ -42,6 +43,7 @@ import org.springframework.security.web.server.context.ServerSecurityContextRepo
class ServerHttpBasicDsl {
var authenticationManager: ReactiveAuthenticationManager? = null
var securityContextRepository: ServerSecurityContextRepository? = null
var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null
var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null

private var disabled = false
Expand All @@ -57,6 +59,7 @@ class ServerHttpBasicDsl {
return { httpBasic ->
authenticationManager?.also { httpBasic.authenticationManager(authenticationManager) }
securityContextRepository?.also { httpBasic.securityContextRepository(securityContextRepository) }
authenticationFailureHandler?.also { httpBasic.authenticationFailureHandler(authenticationFailureHandler) }
authenticationEntryPoint?.also { httpBasic.authenticationEntryPoint(authenticationEntryPoint) }
if (disabled) {
httpBasic.disable()
Expand Down
Expand Up @@ -19,6 +19,7 @@ package org.springframework.security.config.web.server
import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver
import org.springframework.security.web.server.ServerAuthenticationEntryPoint
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler
import org.springframework.web.server.ServerWebExchange

Expand All @@ -38,6 +39,7 @@ import org.springframework.web.server.ServerWebExchange
@ServerSecurityMarker
class ServerOAuth2ResourceServerDsl {
var accessDeniedHandler: ServerAccessDeniedHandler? = null
var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null
var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null
var bearerTokenConverter: ServerAuthenticationConverter? = null
var authenticationManagerResolver: ReactiveAuthenticationManagerResolver<ServerWebExchange>? = null
Expand Down Expand Up @@ -109,6 +111,7 @@ class ServerOAuth2ResourceServerDsl {
internal fun get(): (ServerHttpSecurity.OAuth2ResourceServerSpec) -> Unit {
return { oauth2ResourceServer ->
accessDeniedHandler?.also { oauth2ResourceServer.accessDeniedHandler(accessDeniedHandler) }
authenticationFailureHandler?.also { oauth2ResourceServer.authenticationFailureHandler(authenticationFailureHandler) }
authenticationEntryPoint?.also { oauth2ResourceServer.authenticationEntryPoint(authenticationEntryPoint) }
bearerTokenConverter?.also { oauth2ResourceServer.bearerTokenConverter(bearerTokenConverter) }
authenticationManagerResolver?.also { oauth2ResourceServer.authenticationManagerResolver(authenticationManagerResolver!!) }
Expand Down
Expand Up @@ -51,6 +51,7 @@
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver;
import org.springframework.security.authentication.TestingAuthenticationToken;
Expand All @@ -73,6 +74,7 @@
import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint;
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping;
Expand Down Expand Up @@ -347,6 +349,25 @@ public void getWhenUsingCustomAuthenticationManagerResolverThenUsesItAccordingly
// @formatter:on
}

@Test
public void getWhenUsingCustomAuthenticationFailureHandlerThenUsesIsAccordingly() {
this.spring.register(CustomAuthenticationFailureHandlerConfig.class).autowire();
ServerAuthenticationFailureHandler handler = this.spring.getContext()
.getBean(ServerAuthenticationFailureHandler.class);
ReactiveAuthenticationManager authenticationManager = this.spring.getContext()
.getBean(ReactiveAuthenticationManager.class);
given(authenticationManager.authenticate(any()))
.willReturn(Mono.error(() -> new BadCredentialsException("bad")));
given(handler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty());
// @formatter:off
this.client.get()
.headers((headers) -> headers.setBearerAuth(this.messageReadToken))
.exchange()
.expectStatus().isOk();
// @formatter:on
verify(handler).onAuthenticationFailure(any(), any());
}

@Test
public void postWhenSignedThenReturnsOk() {
this.spring.register(PublicKeyConfig.class, RootController.class).autowire();
Expand Down Expand Up @@ -903,6 +924,35 @@ ReactiveAuthenticationManager authenticationManager() {
}

@Configuration
@EnableWebFlux
@EnableWebFluxSecurity
static class CustomAuthenticationFailureHandlerConfig {

@Bean
SecurityWebFilterChain springSecurity(ServerHttpSecurity http) {
// @formatter:off
http
.authorizeExchange((authorize) -> authorize.anyExchange().authenticated())
.oauth2ResourceServer((oauth2) -> oauth2
.authenticationFailureHandler(authenticationFailureHandler())
.jwt((jwt) -> jwt.authenticationManager(authenticationManager()))
);
// @formatter:on
return http.build();
}

@Bean
ReactiveAuthenticationManager authenticationManager() {
return mock(ReactiveAuthenticationManager.class);
}

@Bean
ServerAuthenticationFailureHandler authenticationFailureHandler() {
return mock(ServerAuthenticationFailureHandler.class);
}

}

@EnableWebFlux
@EnableWebFluxSecurity
static class CustomBearerTokenServerAuthenticationConverter {
Expand Down
Expand Up @@ -35,6 +35,7 @@
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.ReactiveAuthenticationManager;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
Expand All @@ -57,6 +58,7 @@
import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint;
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
Expand Down Expand Up @@ -218,6 +220,27 @@ public void basicWhenXHRRequestThenUnauthorized() {
verify(authenticationEntryPoint).commence(any(), any());
}

@Test
public void basicWhenCustomAuthenticationFailureHandlerThenUses() {
ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class);
ServerAuthenticationFailureHandler authenticationFailureHandler = mock(
ServerAuthenticationFailureHandler.class);
this.http.httpBasic().authenticationFailureHandler(authenticationFailureHandler);
this.http.httpBasic().authenticationManager(authenticationManager);
this.http.authorizeExchange().anyExchange().authenticated();
given(authenticationManager.authenticate(any()))
.willReturn(Mono.error(() -> new BadCredentialsException("bad")));
given(authenticationFailureHandler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty());
WebTestClient client = buildClient();
// @formatter:off
client.get().uri("/")
.headers((headers) -> headers.setBasicAuth("user", "password"))
.exchange()
.expectStatus().isOk();
// @formatter:on
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any());
}

@Test
public void buildWhenServerWebExchangeFromContextThenFound() {
SecurityWebFilterChain filter = this.http.build();
Expand Down
Expand Up @@ -19,7 +19,6 @@ package org.springframework.security.config.web.server
import io.mockk.every
import io.mockk.mockkObject
import io.mockk.verify
import java.util.Base64
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.beans.factory.annotation.Autowired
Expand All @@ -38,13 +37,15 @@ import org.springframework.security.core.userdetails.User
import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.ServerAuthenticationEntryPoint
import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint
import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
import org.springframework.security.web.server.context.ServerSecurityContextRepository
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController
import org.springframework.web.reactive.config.EnableWebFlux
import reactor.core.publisher.Mono
import java.util.*

/**
* Tests for [ServerHttpBasicDsl]
Expand Down Expand Up @@ -228,6 +229,43 @@ class ServerHttpBasicDslTests {
}
}

@Test
fun `http basic when custom authentication failure handler then failure handler used`() {
this.spring.register(CustomAuthenticationFailureHandlerConfig::class.java, UserDetailsConfig::class.java).autowire()
mockkObject(CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER)
every {
CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any())
} returns Mono.empty()

this.client.get()
.uri("/")
.header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:wrong".toByteArray()))
.exchange()

verify(exactly = 1) { CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) }
}

@EnableWebFluxSecurity
@EnableWebFlux
open class CustomAuthenticationFailureHandlerConfig {

companion object {
val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() }
}

@Bean
open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
return http {
authorizeExchange {
authorize(anyExchange, authenticated)
}
httpBasic {
authenticationFailureHandler = FAILURE_HANDLER
}
}
}
}

@Configuration
open class UserDetailsConfig {
@Bean
Expand Down

0 comments on commit fc8e20b

Please sign in to comment.