diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt index 17151948a6b4..eb63a9f9339e 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDsl.kt @@ -27,6 +27,7 @@ import org.springframework.http.HttpMethod import org.springframework.http.HttpStatusCode import org.springframework.http.MediaType import org.springframework.web.reactive.function.server.RouterFunctions.nest +import org.springframework.web.server.CoWebFilter import reactor.core.publisher.Mono import java.net.URI import kotlin.coroutines.CoroutineContext @@ -731,7 +732,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct ) : HandlerFunction { override fun handle(request: ServerRequest): Mono { - return handle(Dispatchers.Unconfined, request) + val context = request.attributes()[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext? + return handle(context ?: Dispatchers.Unconfined, request) } fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) { diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt index b9fc1b9b121f..c2e48f4883aa 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/server/CoRouterFunctionDslTests.kt @@ -25,7 +25,11 @@ import org.springframework.http.HttpHeaders.CONTENT_TYPE import org.springframework.http.HttpMethod.PATCH import org.springframework.http.HttpStatus import org.springframework.http.MediaType.* +import org.springframework.web.server.CoWebFilter +import org.springframework.web.server.CoWebFilterChain +import org.springframework.web.server.ServerWebExchange import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.* +import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse import org.springframework.web.testfixture.server.MockServerWebExchange import reactor.test.StepVerifier @@ -204,6 +208,16 @@ class CoRouterFunctionDslTests { .verifyComplete() } + @Test + fun webFilterAndContext() { + val strategies = HandlerStrategies.builder().webFilter(MyCoWebFilterWithContext()).build() + val httpHandler = RouterFunctions.toHttpHandler(routerWithoutContext, strategies) + val mockRequest = get("https://example.com/").build() + val mockResponse = MockServerHttpResponse() + StepVerifier.create(httpHandler.handle(mockRequest, mockResponse)).verifyComplete() + assertThat(mockResponse.headers.getFirst("context")).contains("Filter context") + } + @Test fun multipleContextProviders() { assertThatIllegalStateException().isThrownBy { @@ -309,6 +323,12 @@ class CoRouterFunctionDslTests { } } + private val routerWithoutContext = coRouter { + GET("/") { + ok().header("context", currentCoroutineContext().toString()).buildAndAwait() + } + } + private val otherRouter = router { "/other" { ok().build() @@ -369,3 +389,12 @@ class CoRouterFunctionDslTests { @Suppress("UNUSED_PARAMETER") private suspend fun handle(req: ServerRequest) = ServerResponse.ok().buildAndAwait() + + +private class MyCoWebFilterWithContext : CoWebFilter() { + override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { + withContext(CoroutineName("Filter context")) { + chain.filter(exchange) + } + } +}