Skip to content

Commit

Permalink
Default to Xor CSRF tokens in CsrfWebFilter
Browse files Browse the repository at this point in the history
Closes gh-11960
  • Loading branch information
sjohnr committed Oct 13, 2022
1 parent 2a2051c commit 2407d07
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.mock.http.server.reactive.MockServerHttpRequest
import org.springframework.mock.web.server.MockServerWebExchange
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.security.config.test.SpringTestContext
import org.springframework.security.config.test.SpringTestContextExtension
Expand All @@ -39,6 +41,7 @@ import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler
import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler
import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository
import org.springframework.security.web.server.csrf.XorServerCsrfTokenRequestAttributeHandler
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.web.bind.annotation.PostMapping
Expand Down Expand Up @@ -278,14 +281,23 @@ class ServerCsrfDslTests {
MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any())
} returns Mono.just(this.token)

val csrfToken = createXorCsrfToken()
this.client.post()
.uri("/")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(fromMultipartData(this.token.parameterName, this.token.token))
.body(fromMultipartData(csrfToken.parameterName, csrfToken.token))
.exchange()
.expectStatus().isOk
}

private fun createXorCsrfToken(): CsrfToken {
val handler = XorServerCsrfTokenRequestAttributeHandler()
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"))
handler.handle(exchange, Mono.just(this.token))
val deferredCsrfToken: Mono<CsrfToken>? = exchange.getAttribute(CsrfToken::class.java.name)
return deferredCsrfToken?.block()!!
}

@Configuration
@EnableWebFluxSecurity
@EnableWebFlux
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public class CsrfWebFilter implements WebFilter {
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
HttpStatus.FORBIDDEN);

private ServerCsrfTokenRequestHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
private ServerCsrfTokenRequestHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();

public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamValidTokenThenCo
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
CsrfToken csrfToken = createXorCsrfToken();
this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body(this.token.getParameterName() + "=" + this.token.getToken()));
.body(csrfToken.getParameterName() + "=" + csrfToken.getToken()));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result).verifyComplete();
chainResult.assertWasSubscribed();
Expand All @@ -151,8 +152,9 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinue
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
CsrfToken csrfToken = createXorCsrfToken();
this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken()));
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
StepVerifier.create(result).verifyComplete();
chainResult.assertWasSubscribed();
Expand Down Expand Up @@ -181,30 +183,22 @@ public void filterWhenRequestHandlerSetThenUsed() {
}

@Test
public void filterWhenXorServerCsrfTokenRequestProcessorAndValidTokenThenSuccess() {
public void filterWhenXorServerCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
given(this.chain.filter(any())).willReturn(chainResult.mono());
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
this.csrfFilter.setRequestHandler(requestHandler);
StepVerifier.create(this.csrfFilter.filter(this.get, this.chain)).verifyComplete();
chainResult.assertWasSubscribed();

Mono<CsrfToken> csrfTokenAttribute = this.get.getAttribute(CsrfToken.class.getName());
assertThat(csrfTokenAttribute).isNotNull();
StepVerifier.create(csrfTokenAttribute)
.consumeNextWith((csrfToken) -> this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken())))
.verifyComplete();

CsrfToken csrfToken = createXorCsrfToken();
this.post = MockServerWebExchange
.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken()));
StepVerifier.create(this.csrfFilter.filter(this.post, this.chain)).verifyComplete();
chainResult.assertWasSubscribed();
}

@Test
public void filterWhenXorServerCsrfTokenRequestProcessorAndRawTokenThenAccessDeniedException() {
public void filterWhenXorServerCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() {
PublisherProbe<Void> chainResult = PublisherProbe.empty();
this.csrfFilter.setCsrfTokenRepository(this.repository);
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
Expand Down Expand Up @@ -305,6 +299,7 @@ public void filterWhenMultipartMixedAndEnabledThenNotRead() {
}

// gh-9561

@Test
public void doFilterWhenTokenIsNullThenNoNullPointer() {
this.csrfFilter.setCsrfTokenRepository(this.repository);
Expand All @@ -318,8 +313,8 @@ public void doFilterWhenTokenIsNullThenNoNullPointer() {
.bodyValue(this.token.getParameterName() + "=" + this.token.getToken()).exchange().expectStatus()
.isForbidden();
}

// gh-9113

@Test
public void filterWhenSubscribingCsrfTokenMultipleTimesThenGenerateOnlyOnce() {
PublisherProbe<CsrfToken> chainResult = PublisherProbe.empty();
Expand All @@ -334,6 +329,14 @@ public void filterWhenSubscribingCsrfTokenMultipleTimesThenGenerateOnlyOnce() {
assertThat(chainResult.subscribeCount()).isEqualTo(1);
}

private CsrfToken createXorCsrfToken() {
ServerCsrfTokenRequestHandler handler = new XorServerCsrfTokenRequestAttributeHandler();
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
handler.handle(exchange, Mono.just(this.token));
Mono<CsrfToken> csrfToken = exchange.getAttribute(CsrfToken.class.getName());
return csrfToken.block();
}

@RestController
static class OkController {

Expand Down

0 comments on commit 2407d07

Please sign in to comment.