diff --git a/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java b/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java index f3eecb2eb43..7364e85fb21 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java @@ -16,7 +16,11 @@ package org.springframework.integration.scattergather; +import java.time.Duration; + import org.jspecify.annotations.Nullable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import org.springframework.aop.support.AopUtils; import org.springframework.beans.factory.BeanFactory; @@ -32,6 +36,7 @@ import org.springframework.integration.endpoint.PollingConsumer; import org.springframework.integration.endpoint.ReactiveStreamsConsumer; import org.springframework.integration.handler.AbstractReplyProducingMessageHandler; +import org.springframework.integration.support.AbstractIntegrationMessageBuilder; import org.springframework.integration.support.management.ManageableLifecycle; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; @@ -48,6 +53,9 @@ /** * The {@link MessageHandler} implementation for the * Scatter-Gather EIP pattern. + *

+ * When {@link #setAsync(boolean)} is {@code true}, the {@link ScatterGatherHandler} produces + * a {@link Mono} as a reply based on the gather result. * * @author Artem Bilan * @author Abdul Zaheer @@ -146,11 +154,11 @@ public Message preSend(Message message, MessageChannel channel) { } }); - if (this.gatherChannel instanceof SubscribableChannel) { - this.gatherEndpoint = new EventDrivenConsumer((SubscribableChannel) this.gatherChannel, this.gatherer); + if (this.gatherChannel instanceof SubscribableChannel subscribableChannel) { + this.gatherEndpoint = new EventDrivenConsumer(subscribableChannel, this.gatherer); } - else if (this.gatherChannel instanceof PollableChannel) { - this.gatherEndpoint = new PollingConsumer((PollableChannel) this.gatherChannel, this.gatherer); + else if (this.gatherChannel instanceof PollableChannel pollableChannel) { + this.gatherEndpoint = new PollingConsumer(pollableChannel, this.gatherer); ((PollingConsumer) this.gatherEndpoint).setReceiveTimeout(this.gatherTimeout); } else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) { @@ -191,7 +199,18 @@ private Message enhanceScatterReplyMessage(Message message) { @Override protected @Nullable Object handleRequestMessage(Message requestMessage) { MessageHeaders requestMessageHeaders = requestMessage.getHeaders(); - PollableChannel gatherResultChannel = new QueueChannel(); + boolean async = isAsync(); + MessageChannel gatherResultChannel; + Sinks.One> replyMono; + + if (async) { + replyMono = Sinks.one(); + gatherResultChannel = (message, timeout) -> replyMono.tryEmitValue(message).isSuccess(); + } + else { + replyMono = null; + gatherResultChannel = new QueueChannel(); + } Message scatterMessage = getMessageBuilderFactory() @@ -204,17 +223,28 @@ private Message enhanceScatterReplyMessage(Message message) { this.messagingTemplate.send(this.scatterChannel, scatterMessage); - Message gatherResult = gatherResultChannel.receive(this.gatherTimeout); - if (gatherResult != null) { - return getMessageBuilderFactory() - .fromMessage(gatherResult) - .removeHeaders(GATHER_RESULT_CHANNEL, ORIGINAL_ERROR_CHANNEL, - MessageHeaders.REPLY_CHANNEL, MessageHeaders.ERROR_CHANNEL); + if (replyMono != null) { + return replyMono.asMono() + .map(this::replyFromGatherResult) + .timeout(Duration.ofMillis(this.gatherTimeout), Mono.empty()); + } + else { + Message gatherResult = ((PollableChannel) gatherResultChannel).receive(this.gatherTimeout); + if (gatherResult != null) { + return replyFromGatherResult(gatherResult); + } } return null; } + private AbstractIntegrationMessageBuilder replyFromGatherResult(Message gatherResult) { + return getMessageBuilderFactory() + .fromMessage(gatherResult) + .removeHeaders(GATHER_RESULT_CHANNEL, ORIGINAL_ERROR_CHANNEL, + MessageHeaders.REPLY_CHANNEL, MessageHeaders.ERROR_CHANNEL); + } + @Override public void start() { if (this.gatherEndpoint != null) { @@ -240,8 +270,8 @@ private static void checkClass(Class gathererClass, String className, String Assert.isAssignable(clazz, gathererClass, () -> "the '" + type + "' must be an " + className + " " + "instance"); } - catch (ClassNotFoundException e) { - throw new IllegalStateException("The class for '" + className + "' cannot be loaded", e); + catch (ClassNotFoundException ex) { + throw new IllegalStateException("The class for '" + className + "' cannot be loaded", ex); } } diff --git a/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java b/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java index 7094a8f72d8..c0944ba1cc2 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java @@ -16,6 +16,7 @@ package org.springframework.integration.dsl.routers; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -463,15 +464,21 @@ public void testRouterAsNonLastComponent() { @Test public void testScatterGather() { QueueChannel replyChannel = new QueueChannel(); - Message request = MessageBuilder.withPayload("foo") + Message request = MessageBuilder.withPayload("test") .setReplyChannel(replyChannel) .build(); this.scatterGatherFlowInput.send(request); Message bestQuoteMessage = replyChannel.receive(10000); - assertThat(bestQuoteMessage).isNotNull(); - Object payload = bestQuoteMessage.getPayload(); - assertThat(payload).isInstanceOf(List.class); - assertThat(((List) payload).size()).isGreaterThanOrEqualTo(1); + assertThat(bestQuoteMessage) + .extracting(Message::getPayload) + .asInstanceOf(InstanceOfAssertFactories.LIST) + .hasSizeGreaterThanOrEqualTo(1) + .first() + .asInstanceOf(InstanceOfAssertFactories.type(Message.class)) + .extracting(Message::getHeaders) + .asInstanceOf(InstanceOfAssertFactories.MAP) + .extractingByKey("gatherResultChannel") + .isNotInstanceOf(PollableChannel.class); } @Autowired @@ -859,9 +866,11 @@ public IntegrationFlow scatterGatherFlow() { group.size() == 3 || group.getMessages() .stream() - .anyMatch(m -> (Double) m.getPayload() > 5)), + .anyMatch(m -> (Double) m.getPayload() > 5)) + .outputProcessor(group -> new ArrayList<>(group.getMessages())), scatterGather -> scatterGather - .gatherTimeout(10_000)); + .gatherTimeout(10_000) + .async(true)); } @Bean diff --git a/src/reference/antora/modules/ROOT/pages/scatter-gather.adoc b/src/reference/antora/modules/ROOT/pages/scatter-gather.adoc index 0e56a93ce66..271684e7dc2 100644 --- a/src/reference/antora/modules/ROOT/pages/scatter-gather.adoc +++ b/src/reference/antora/modules/ROOT/pages/scatter-gather.adoc @@ -154,6 +154,10 @@ Mutually exclusive with `scatter-channel` attribute. <13> The `` options. Required. +NOTE: Starting with version `6.5.3`, when a `ScatterGatherHandler` is configured for the `async = true` option, the request message handling thread is not blocked anymore waiting for a gather result on an internal `((PollableChannel) gatherResultChannel).receive(this.gatherTimeout)` operation. +Instead, a `reactor.core.publisher.Mono` is returned as a reply object based on a gather result eventually produced from the `gatherResultChannel`. +Such a `Mono` is handled then according to the xref:reactive-streams.adoc#reactive-reply-payload[Reactive Streams support] in the framework. + [[scatter-gather-error-handling]] == Error Handling