Skip to content

Commit

Permalink
Propagate CoroutineContext in coRouter filters
Browse files Browse the repository at this point in the history
  • Loading branch information
sdeleuze committed Aug 29, 2023
1 parent bcf11e8 commit 8e77a30
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,14 +17,18 @@
package org.springframework.web.reactive.function.server

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.coroutines.reactor.mono
import org.springframework.core.io.Resource
import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatusCode
import org.springframework.http.MediaType
import org.springframework.web.reactive.function.server.RouterFunctions.nest
import reactor.core.publisher.Mono
import java.net.URI
import kotlin.coroutines.CoroutineContext

/**
* Allow to create easily a WebFlux.fn [RouterFunction] with a [Coroutines router Kotlin DSL][CoRouterFunctionDsl].
Expand Down Expand Up @@ -532,7 +536,12 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
builder.filter { serverRequest, handlerFunction ->
mono(Dispatchers.Unconfined) {
filterFunction(serverRequest) { handlerRequest ->
handlerFunction.handle(handlerRequest).awaitSingle()
if (handlerFunction is CoroutineContextAwareHandlerFunction<*>) {
handlerFunction.handle(currentCoroutineContext().minusKey(Job.Key), handlerRequest).awaitSingle()
}
else {
handlerFunction.handle(handlerRequest).awaitSingle()
}
}
}
}
Expand Down Expand Up @@ -618,11 +627,8 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
return builder.build()
}

private fun asHandlerFunction(init: suspend (ServerRequest) -> ServerResponse) = HandlerFunction {
mono(Dispatchers.Unconfined) {
init(it)
}
}
private fun <T : ServerResponse> asHandlerFunction(handler: suspend (ServerRequest) -> T) =
CoroutineContextAwareHandlerFunction(handler)

/**
* @see ServerResponse.from
Expand Down Expand Up @@ -691,6 +697,18 @@ class CoRouterFunctionDsl internal constructor (private val init: (CoRouterFunct
*/
fun status(status: Int) = ServerResponse.status(status)

private class CoroutineContextAwareHandlerFunction<T : ServerResponse>(private val handler: suspend (ServerRequest) -> T) : HandlerFunction<T> {

override fun handle(request: ServerRequest): Mono<T> {
return handle(Dispatchers.Unconfined, request)
}

fun handle(context: CoroutineContext, request: ServerRequest): Mono<T> = mono(context) {
handler(request)
}

}

}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,17 +16,20 @@

package org.springframework.web.reactive.function.server

import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpHeaders.ACCEPT
import org.springframework.http.HttpHeaders.CONTENT_TYPE
import org.springframework.http.HttpMethod.PATCH
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.*
import org.springframework.web.testfixture.server.MockServerWebExchange
import org.springframework.web.reactive.function.server.AttributesTestVisitor
import reactor.test.StepVerifier

/**
Expand Down Expand Up @@ -165,6 +168,17 @@ class CoRouterFunctionDslTests {
.verifyComplete()
}

@Test
fun filteringWithContext() {
val mockRequest = get("https://example.com/").build()
val request = DefaultServerRequest(MockServerWebExchange.from(mockRequest), emptyList())
StepVerifier.create(filterRouterWithContext.route(request).flatMap { it.handle(request) })
.expectNextMatches { response ->
response.headers().getFirst("context")!!.contains("Filter context")
}
.verifyComplete()
}

@Test
fun attributes() {
val visitor = AttributesTestVisitor()
Expand Down Expand Up @@ -226,6 +240,17 @@ class CoRouterFunctionDslTests {
}
}

private val filterRouterWithContext = coRouter {
filter { request, next ->
withContext(CoroutineName("Filter context")) {
next(request)
}
}
GET("/") {
ok().header("context", currentCoroutineContext().toString()).buildAndAwait()
}
}

private val otherRouter = router {
"/other" {
ok().build()
Expand Down

0 comments on commit 8e77a30

Please sign in to comment.