diff --git a/spring-integration-core/src/main/java/org/springframework/integration/handler/AbstractMessageProducingHandler.java b/spring-integration-core/src/main/java/org/springframework/integration/handler/AbstractMessageProducingHandler.java index f92e9d6270d..07689cfca12 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/handler/AbstractMessageProducingHandler.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/handler/AbstractMessageProducingHandler.java @@ -321,25 +321,42 @@ private void doProduceOutput(Message requestMessage, MessageHeaders requestHe replyChannel = getOutputChannel(); } + Object replyPayload = reply; + Message replyMessage = reply instanceof Message message ? message : null; + + if (replyMessage != null) { + replyPayload = replyMessage.getPayload(); + } + if (this.async) { - boolean isFutureReply = reply instanceof CompletableFuture; + boolean isFutureReply = replyPayload instanceof CompletableFuture; ReactiveAdapter reactiveAdapter = null; if (!isFutureReply) { - reactiveAdapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(null, reply); + reactiveAdapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(null, replyPayload); } if (isFutureReply || reactiveAdapter != null) { if (replyChannel instanceof ReactiveStreamsSubscribableChannel reactiveStreamsSubscribableChannel) { - Publisher reactiveReply = toPublisherReply(reply, reactiveAdapter); + Publisher reactiveReply = toPublisherReply(replyPayload, reactiveAdapter); reactiveStreamsSubscribableChannel .subscribeTo( Flux.from(reactiveReply) .doOnError((ex) -> sendErrorMessage(requestMessage, ex)) - .map(result -> createOutputMessage(result, requestHeaders))); + .map(result -> { + if (replyMessage != null) { + return getMessageBuilderFactory() + .withPayload(result) + .copyHeaders(replyMessage.getHeaders()) + .build(); + } + else { + return createOutputMessage(result, requestHeaders); + } + })); } else { - CompletableFuture futureReply = toFutureReply(reply, reactiveAdapter); + CompletableFuture futureReply = toFutureReply(replyPayload, replyMessage, reactiveAdapter); futureReply.whenComplete(new ReplyFutureCallback(requestMessage, replyChannel)); } @@ -359,8 +376,12 @@ private Publisher toPublisherReply(Object reply, @Nullable ReactiveAdapter re } } - @SuppressWarnings("try") - private CompletableFuture toFutureReply(Object reply, @Nullable ReactiveAdapter reactiveAdapter) { + @SuppressWarnings({"try", "unchecked"}) + private CompletableFuture toFutureReply(Object reply, @Nullable Message replyMessage, + @Nullable ReactiveAdapter reactiveAdapter) { + + CompletableFuture replyFuture; + if (reactiveAdapter != null) { Mono reactiveReply; Publisher publisher = reactiveAdapter.toPublisher(reply); @@ -371,7 +392,7 @@ private CompletableFuture toFutureReply(Object reply, @Nullable ReactiveAdapt reactiveReply = Mono.from(publisher); } - CompletableFuture replyFuture = new CompletableFuture<>(); + replyFuture = new CompletableFuture<>(); reactiveReply /* @@ -379,7 +400,7 @@ private CompletableFuture toFutureReply(Object reply, @Nullable ReactiveAdapt and it does not suppose to, since there is no guarantee how this Future is going to be handled downstream. However, in our case we process it directly in this class in the doProduceOutput() - via whenComplete() callback. So, when value is set into the Future, it is available + via whenComplete() callback. So, when the value is set into the Future, it is available in the callback in the same thread immediately. */ .doOnEach((signal) -> { @@ -400,12 +421,20 @@ via whenComplete() callback. So, when value is set into the Future, it is availa }) .contextCapture() .subscribe(); - - return replyFuture; } else { - return (CompletableFuture) reply; + replyFuture = (CompletableFuture) reply; } + + if (replyMessage == null) { + return replyFuture; + } + + return replyFuture.thenApply(result -> + getMessageBuilderFactory() + .withPayload(result) + .copyHeaders(replyMessage.getHeaders()) + .build()); } private AbstractIntegrationMessageBuilder addRoutingSlipHeader(Object reply, List routingSlip, diff --git a/spring-integration-core/src/test/java/org/springframework/integration/dsl/transformers/TransformerTests.java b/spring-integration-core/src/test/java/org/springframework/integration/dsl/transformers/TransformerTests.java index 1a761c6cf5b..7caa7c78757 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/dsl/transformers/TransformerTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/dsl/transformers/TransformerTests.java @@ -18,11 +18,15 @@ import java.io.InputStream; import java.io.OutputStream; +import java.time.Duration; import java.util.Collections; import java.util.Date; import java.util.Map; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -32,6 +36,7 @@ import org.springframework.integration.annotation.Transformer; import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.channel.FixedSubscriberChannel; +import org.springframework.integration.channel.FluxMessageChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.codec.Codec; import org.springframework.integration.config.EnableIntegration; @@ -273,6 +278,41 @@ public void testFailedTransformWithRequestHeadersCopy() { .isEqualTo("transform failed"); } + @Autowired + @Qualifier("asyncTransformerFlow.input") + MessageChannel asyncTransformerFlowInput; + + @Test + void asyncTransformerReplyIsProcessed() { + QueueChannel replyChannel = new QueueChannel(); + this.asyncTransformerFlowInput.send( + MessageBuilder.withPayload("test") + .setReplyChannel(replyChannel) + .build()); + + Message receive = replyChannel.receive(10_000); + + assertThat(receive).extracting(Message::getPayload).isEqualTo("test async"); + + } + + @Test + void reactiveTransformerReplyIsProcessed() { + FluxMessageChannel replyChannel = new FluxMessageChannel(); + this.asyncTransformerFlowInput.send( + MessageBuilder.withPayload("test") + .setReplyChannel(replyChannel) + .build()); + + StepVerifier.create( + Flux.from(replyChannel) + .map(Message::getPayload) + .cast(String.class)) + .expectNext("test async") + .thenCancel() + .verify(Duration.ofSeconds(10)); + } + @Configuration @EnableIntegration public static class ContextConfiguration { @@ -465,6 +505,15 @@ public IntegrationFlow transformFlowWithError() { .log(); } + @Bean + public IntegrationFlow asyncTransformerFlow() { + return f -> f + .transformWith(endpoint -> endpoint + .>transformer(payload -> + CompletableFuture.completedFuture(payload + " async")) + .async(true)); + } + } private static final class TestPojo {