Skip to content

Commit

Permalink
Pickup CoroutineContext saved by CoWebFilter in coRouter
Browse files Browse the repository at this point in the history
Closes gh-31793
  • Loading branch information
sdeleuze committed Dec 11, 2023
1 parent 5700742 commit aabe4d0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatusCode
import org.springframework.http.MediaType
import org.springframework.web.reactive.function.server.RouterFunctions.nest
import org.springframework.web.server.CoWebFilter
import reactor.core.publisher.Mono
import java.net.URI
import kotlin.coroutines.CoroutineContext
Expand Down Expand Up @@ -731,7 +732,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
) : HandlerFunction<T> {

override fun handle(request: ServerRequest): Mono<T> {
return handle(Dispatchers.Unconfined, request)
val context = request.attributes()[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext?
return handle(context ?: Dispatchers.Unconfined, request)
}

fun handle(context: CoroutineContext, request: ServerRequest) = asMono(request, context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ import org.springframework.http.HttpHeaders.CONTENT_TYPE
import org.springframework.http.HttpMethod.PATCH
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.server.CoWebFilter
import org.springframework.web.server.CoWebFilterChain
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.test.StepVerifier

Expand Down Expand Up @@ -204,6 +208,16 @@ class CoRouterFunctionDslTests {
.verifyComplete()
}

@Test
fun webFilterAndContext() {
val strategies = HandlerStrategies.builder().webFilter(MyCoWebFilterWithContext()).build()
val httpHandler = RouterFunctions.toHttpHandler(routerWithoutContext, strategies)
val mockRequest = get("https://example.com/").build()
val mockResponse = MockServerHttpResponse()
StepVerifier.create(httpHandler.handle(mockRequest, mockResponse)).verifyComplete()
assertThat(mockResponse.headers.getFirst("context")).contains("Filter context")
}

@Test
fun multipleContextProviders() {
assertThatIllegalStateException().isThrownBy {
Expand Down Expand Up @@ -309,6 +323,12 @@ class CoRouterFunctionDslTests {
}
}

private val routerWithoutContext = coRouter {
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}

private val otherRouter = router {
"/other" {
ok().build()
Expand Down Expand Up @@ -369,3 +389,12 @@ class CoRouterFunctionDslTests {

@Suppress("UNUSED_PARAMETER")
private suspend fun handle(req: ServerRequest) = ServerResponse.ok().buildAndAwait()


private class MyCoWebFilterWithContext : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
withContext(CoroutineName("Filter context")) {
chain.filter(exchange)
}
}
}

0 comments on commit aabe4d0

Please sign in to comment.