diff --git a/spring-integration-amqp/src/main/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandler.java b/spring-integration-amqp/src/main/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandler.java index f55609d644f..ec21b52915e 100644 --- a/spring-integration-amqp/src/main/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandler.java +++ b/spring-integration-amqp/src/main/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandler.java @@ -29,6 +29,7 @@ import org.springframework.integration.amqp.support.AmqpHeaderMapper; import org.springframework.integration.amqp.support.DefaultAmqpHeaderMapper; import org.springframework.integration.amqp.support.MappingUtils; +import org.springframework.integration.context.IntegrationContextUtils; import org.springframework.integration.core.MessagingTemplate; import org.springframework.integration.handler.AbstractMessageHandler; import org.springframework.messaging.Message; @@ -45,6 +46,7 @@ * * @author Gary Russell * @author Chris Bono + * @author Ryan Riley * @since 6.0 * */ @@ -173,14 +175,14 @@ public RabbitStreamOperations getStreamOperations() { } protected @Nullable MessageChannel getSendFailureChannel() { - if (this.sendFailureChannel != null) { - return this.sendFailureChannel; - } - else if (this.sendFailureChannelName != null) { - this.sendFailureChannel = getChannelResolver().resolveDestination(this.sendFailureChannelName); - return this.sendFailureChannel; + if (this.sendFailureChannel == null && (this.sendFailureChannelName != null || !this.sync)) { + String sendFailureChannelNameToUse = this.sendFailureChannelName; + if (sendFailureChannelNameToUse == null) { + sendFailureChannelNameToUse = IntegrationContextUtils.ERROR_CHANNEL_BEAN_NAME; + } + this.sendFailureChannel = getChannelResolver().resolveDestination(sendFailureChannelNameToUse); } - return null; + return this.sendFailureChannel; } protected @Nullable MessageChannel getSendSuccessChannel() { diff --git a/spring-integration-amqp/src/test/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandlerTests.java b/spring-integration-amqp/src/test/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandlerTests.java index 8ac2af1f6b9..1435f0e7084 100644 --- a/spring-integration-amqp/src/test/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandlerTests.java +++ b/spring-integration-amqp/src/test/java/org/springframework/integration/amqp/outbound/RabbitStreamMessageHandlerTests.java @@ -16,26 +16,37 @@ package org.springframework.integration.amqp.outbound; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import com.rabbitmq.stream.Consumer; import com.rabbitmq.stream.Environment; +import com.rabbitmq.stream.Message; import com.rabbitmq.stream.OffsetSpecification; +import com.rabbitmq.stream.codec.SimpleCodec; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; import org.springframework.integration.amqp.dsl.RabbitStream; import org.springframework.integration.amqp.support.RabbitTestContainer; +import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.support.MessageBuilder; +import org.springframework.messaging.MessageHandlingException; +import org.springframework.messaging.support.ErrorMessage; import org.springframework.rabbit.stream.producer.RabbitStreamTemplate; +import org.springframework.rabbit.stream.producer.StreamSendException; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Gary Russell * @author Chris Bono * @author Artem Bilan + * @author Ryan Riley * * @since 6.0 */ @@ -117,4 +128,52 @@ void sendNative() throws InterruptedException { streamTemplate.close(); } + @Test + void errorChanelAsync() { + Environment env = Mockito.mock(Environment.class); + RabbitStreamTemplate streamTemplate = new RabbitStreamTemplate(env, "stream.stream"); + RabbitStreamTemplate spyStreamTemplate = Mockito.spy(streamTemplate); + CompletableFuture> errorFuture = new CompletableFuture<>(); + Mockito.doReturn(errorFuture).when(spyStreamTemplate).send(ArgumentMatchers.any(Message.class)); + + QueueChannel errorChannel = new QueueChannel(); + RabbitStreamMessageHandler handler = RabbitStream.outboundStreamAdapter(spyStreamTemplate) + .sync(false) + .sendFailureChannel(errorChannel) + .getObject(); + SimpleCodec codec = new SimpleCodec(); + org.springframework.messaging.Message testMessage = MessageBuilder.withPayload(codec.messageBuilder() + .addData(new byte[1]) + .build()) + .build(); + handler.handleMessage(testMessage); + StreamSendException streamException = new StreamSendException("Test Error Code", 99); + errorFuture.completeExceptionally(streamException); + ErrorMessage errorMessage = (ErrorMessage) errorChannel.receive(1000); + assertThat(errorMessage).extracting(org.springframework.messaging.Message::getPayload).isEqualTo(streamException); + } + + @Test + void errorChanelSync() { + Environment env = Mockito.mock(Environment.class); + RabbitStreamTemplate streamTemplate = new RabbitStreamTemplate(env, "stream.stream"); + RabbitStreamTemplate spyStreamTemplate = Mockito.spy(streamTemplate); + CompletableFuture> errorFuture = new CompletableFuture<>(); + errorFuture.exceptionally(ErrorMessage::new); + Mockito.doReturn(errorFuture).when(spyStreamTemplate).send(ArgumentMatchers.any(Message.class)); + + QueueChannel errorChannel = new QueueChannel(); + RabbitStreamMessageHandler handler = RabbitStream.outboundStreamAdapter(spyStreamTemplate) + .sync(true) + .sendFailureChannel(errorChannel) + .getObject(); + SimpleCodec codec = new SimpleCodec(); + org.springframework.messaging.Message testMessage = MessageBuilder.withPayload(codec.messageBuilder() + .addData(new byte[1]) + .build()) + .build(); + assertThatExceptionOfType(MessageHandlingException.class) + .isThrownBy(() -> handler.handleMessage(testMessage)); + } + } diff --git a/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandler.java b/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandler.java index 49e61efe0e1..9b2efd5d0ab 100644 --- a/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandler.java +++ b/spring-integration-kafka/src/main/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandler.java @@ -38,6 +38,7 @@ import org.springframework.expression.EvaluationContext; import org.springframework.expression.Expression; import org.springframework.integration.MessageTimeoutException; +import org.springframework.integration.context.IntegrationContextUtils; import org.springframework.integration.expression.ExpressionUtils; import org.springframework.integration.expression.FunctionExpression; import org.springframework.integration.expression.ValueExpression; @@ -89,6 +90,7 @@ * @author Marius Bogoevici * @author Biju Kunjummen * @author Tom van den Berge + * @author Ryan Riley * * @since 5.4 */ @@ -434,16 +436,15 @@ public String getComponentType() { return this.isGateway ? "kafka:outbound-gateway" : "kafka:outbound-channel-adapter"; } - @Nullable - protected MessageChannel getSendFailureChannel() { - if (this.sendFailureChannel != null) { - return this.sendFailureChannel; - } - else if (this.sendFailureChannelName != null) { - this.sendFailureChannel = getChannelResolver().resolveDestination(this.sendFailureChannelName); - return this.sendFailureChannel; + protected @Nullable MessageChannel getSendFailureChannel() { + if (this.sendFailureChannel == null && (this.sendFailureChannelName != null || !this.sync)) { + String sendFailureChannelNameToUse = this.sendFailureChannelName; + if (sendFailureChannelNameToUse == null) { + sendFailureChannelNameToUse = IntegrationContextUtils.ERROR_CHANNEL_BEAN_NAME; + } + this.sendFailureChannel = getChannelResolver().resolveDestination(sendFailureChannelNameToUse); } - return null; + return this.sendFailureChannel; } protected @Nullable MessageChannel getSendSuccessChannel() { diff --git a/spring-integration-kafka/src/test/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandlerTests.java b/spring-integration-kafka/src/test/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandlerTests.java index 3aa7e3f41ad..e68b738e1c6 100644 --- a/spring-integration-kafka/src/test/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandlerTests.java +++ b/spring-integration-kafka/src/test/java/org/springframework/integration/kafka/outbound/KafkaProducerMessageHandlerTests.java @@ -52,6 +52,7 @@ import org.springframework.expression.common.LiteralExpression; import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.integration.channel.DirectChannel; +import org.springframework.integration.channel.NullChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.expression.FunctionExpression; import org.springframework.integration.expression.ValueExpression; @@ -116,6 +117,7 @@ * @author Biju Kunjummen * @author Artem Bilan * @author Tom van den Berge + * @author Ryan Riley * * @since 5.4 */ @@ -544,6 +546,7 @@ void testConsumeAndProduceTransaction() throws Exception { DirectChannel channel = new DirectChannel(); inbound.setOutputChannel(channel); KafkaProducerMessageHandler handler = new KafkaProducerMessageHandler(template); + handler.setSendFailureChannel(new NullChannel()); handler.setMessageKeyExpression(new LiteralExpression("bar")); handler.setTopicExpression(new LiteralExpression("topic")); channel.subscribe(handler); @@ -690,6 +693,7 @@ protected Producer createTransactionalProducer(String txIdPrefix) { DirectChannel channel = new DirectChannel(); inbound.setOutputChannel(channel); KafkaProducerMessageHandler handler = new KafkaProducerMessageHandler(template); + handler.setSendFailureChannel(new NullChannel()); handler.setMessageKeyExpression(new LiteralExpression("bar")); handler.setTopicExpression(new LiteralExpression("topic")); channel.subscribe(handler);