Skip to content

Commit

Permalink
Verify ReactorContext when using Virtual Threads
Browse files Browse the repository at this point in the history
Closes gh-12791
  • Loading branch information
sjohnr committed Sep 25, 2023
1 parent d6fac11 commit ff37493
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,14 +17,20 @@
package org.springframework.security.config.annotation.web.configuration;

import java.net.URI;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.api.extension.ExtendWith;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.BaseSubscriber;
Expand All @@ -35,6 +41,8 @@

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
Expand All @@ -46,6 +54,7 @@
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
Expand Down Expand Up @@ -271,6 +280,58 @@ public void createPublisherWhenCustomSecurityContextHolderStrategyThenUses() {
verify(strategy, times(2)).getContext();
}

@Test
public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception {
this.spring.register(SecurityConfig.class).autowire();

ThreadFactory threadFactory = Executors.defaultThreadFactory();
assertContextAttributesAvailable(threadFactory);
}

@Test
@DisabledOnJre(JRE.JAVA_17)
public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception {
this.spring.register(SecurityConfig.class).autowire();

ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
assertContextAttributesAvailable(threadFactory);
}

private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception {
Map<Object, Object> expectedContextAttributes = new HashMap<>();
expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
expectedContextAttributes.put(Authentication.class, this.authentication);

try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) {
Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes);
assertThat(future.get()).isEqualTo(expectedContextAttributes);
}
}

private Map<Object, Object> propagateRequestAttributes() {
RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse);
RequestContextHolder.setRequestAttributes(requestAttributes);

SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(this.authentication);
SecurityContextHolder.setContext(securityContext);

// @formatter:off
return Mono.deferContextual(Mono::just)
.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
.map((attributes) -> {
Map<Object, Object> map = new HashMap<>();
// Copy over items from lazily loaded map
Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class)
.forEach((key) -> map.put(key, attributes.get(key)));
return map;
})
.block();
// @formatter:on
}

@Configuration
@EnableWebSecurity
static class SecurityConfig {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,10 +16,17 @@

package org.springframework.security.core.context;

import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;

import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;

Expand Down Expand Up @@ -99,4 +106,53 @@ public void setAuthenticationAndGetContextThenEmitsContext() {
// @formatter:on
}

@Test
public void getContextWhenThreadFactoryIsPlatformThenPropagated() {
verifySecurityContextIsPropagated(Executors.defaultThreadFactory());
}

@Test
@DisabledOnJre(JRE.JAVA_17)
public void getContextWhenThreadFactoryIsVirtualThenPropagated() {
verifySecurityContextIsPropagated(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
}

private static void verifySecurityContextIsPropagated(ThreadFactory threadFactory) {
Authentication authentication = new TestingAuthenticationToken("user", null);

// @formatter:off
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscribeOn(Schedulers.newSingle(threadFactory));
// @formatter:on

StepVerifier.create(publisher).expectNext(authentication).verifyComplete();
}

@Test
public void clearContextWhenThreadFactoryIsPlatformThenCleared() {
verifySecurityContextIsCleared(Executors.defaultThreadFactory());
}

@Test
@DisabledOnJre(JRE.JAVA_17)
public void clearContextWhenThreadFactoryIsVirtualThenCleared() {
verifySecurityContextIsCleared(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
}

private static void verifySecurityContextIsCleared(ThreadFactory threadFactory) {
Authentication authentication = new TestingAuthenticationToken("user", null);

// @formatter:off
Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.contextWrite(ReactiveSecurityContextHolder.clearContext())
.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscribeOn(Schedulers.newSingle(threadFactory));
// @formatter:on

StepVerifier.create(publisher).verifyComplete();
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,17 +17,23 @@
package org.springframework.security.web.server.context;

import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;
import reactor.test.publisher.TestPublisher;
import reactor.util.context.Context;

import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.core.Authentication;
Expand Down Expand Up @@ -117,4 +123,32 @@ public void filterWhenMainContextThenDoesNotOverride() {
StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete();
}

@Test
public void filterWhenThreadFactoryIsPlatformThenSecurityContextLoaded() {
ThreadFactory threadFactory = Executors.defaultThreadFactory();
assertSecurityContextLoaded(threadFactory);
}

@Test
@DisabledOnJre(JRE.JAVA_17)
public void filterWhenThreadFactoryIsVirtualThenSecurityContextLoaded() {
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
assertSecurityContextLoaded(threadFactory);
}

private void assertSecurityContextLoaded(ThreadFactory threadFactory) {
SecurityContextImpl context = new SecurityContextImpl(this.principal);
given(this.repository.load(any())).willReturn(Mono.just(context));
// @formatter:off
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
.subscribeOn(Schedulers.newSingle(threadFactory));
WebFilter assertSecurityContext = (exchange, chain) -> ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.doOnSuccess((authentication) -> assertThat(authentication).isSameAs(this.principal))
.then(chain.filter(exchange));
// @formatter:on
this.handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, assertSecurityContext);
this.handler.exchange(this.exchange);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,17 +17,25 @@
package org.springframework.security.web.server.context;

import java.util.Collections;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;

import org.springframework.core.task.VirtualThreadTaskExecutor;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.test.web.reactive.server.WebTestHandler;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.handler.DefaultWebFilterChain;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -80,4 +88,31 @@ public void filterWhenPrincipalNullThenContextEmpty() {
StepVerifier.create(result).verifyComplete();
}

@Test
public void filterWhenThreadFactoryIsPlatformThenContextPopulated() {
ThreadFactory threadFactory = Executors.defaultThreadFactory();
assertPrincipalPopulated(threadFactory);
}

@Test
@DisabledOnJre(JRE.JAVA_17)
public void filterWhenThreadFactoryIsVirtualThenContextPopulated() {
ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
assertPrincipalPopulated(threadFactory);
}

private void assertPrincipalPopulated(ThreadFactory threadFactory) {
// @formatter:off
WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.principal))
.subscribeOn(Schedulers.newSingle(threadFactory));
WebFilter assertPrincipal = (exchange, chain) -> exchange.getPrincipal()
.doOnSuccess((principal) -> assertThat(principal).isSameAs(this.principal))
.then(chain.filter(exchange));
// @formatter:on
WebTestHandler handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter,
assertPrincipal);
handler.exchange(this.exchange);
}

}

0 comments on commit ff37493

Please sign in to comment.