Skip to content

Commit

Permalink
GH-2348: Custom Correlation Consumer Side
Browse files Browse the repository at this point in the history
Resolves #2348

The replying template supports a custom header for correlation for cases
when the consumer side is not a Spring app and uses a different header.

Support a custom header name on the consumer side, for cases where the
client side is not Spring and uses a different header.
  • Loading branch information
garyrussell authored and artembilan committed Sep 22, 2022
1 parent 9066598 commit 8a23e47
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 35 deletions.
3 changes: 3 additions & 0 deletions spring-kafka-docs/src/main/asciidoc/kafka.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,9 @@ These header names are used by the `@KafkaListener` infrastructure to route the
Starting with version 2.3, you can customize the header names - the template has 3 properties `correlationHeaderName`, `replyTopicHeaderName`, and `replyPartitionHeaderName`.
This is useful if your server is not a Spring application (or does not use the `@KafkaListener`).

NOTE: Conversely, if the requesting application is not a spring application and puts correlation information in a different header, starting with version 3.0, you can configure a custom `correlationHeaderName` on the listener container factory and that header will be echoed back.
Previously, the listener had to echo custom correlation headers.

[[exchanging-messages]]
====== Request/Reply with `Message<?>` s

Expand Down
6 changes: 6 additions & 0 deletions spring-kafka-docs/src/main/asciidoc/whats-new.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@ See <<kafka-template>>.

The futures returned by this class are now `CompletableFuture` s instead of `ListenableFuture` s.
See <<replying-template>> and <<exchanging-messages>>.

[[x30-listener]]
==== `@KafkaListener` Changes

You can now use a custom correlation header which will be echoed in any reply message.
See the note at the end of <<replying-template>> for more information.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ public abstract class AbstractKafkaListenerContainerFactory<C extends AbstractMe

private ContainerCustomizer<K, V, C> containerCustomizer;

private String correlationHeaderName;

@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
Expand Down Expand Up @@ -321,6 +323,17 @@ public void setContainerCustomizer(ContainerCustomizer<K, V, C> containerCustomi
this.containerCustomizer = containerCustomizer;
}

/**
* Set a custom header name for the correlation id. Default
* {@link org.springframework.kafka.support.KafkaHeaders#CORRELATION_ID}. This header
* will be echoed back in any reply message.
* @param correlationHeaderName the header name.
* @since 3.0
*/
public void setCorrelationHeaderName(String correlationHeaderName) {
this.correlationHeaderName = correlationHeaderName;
}

@SuppressWarnings("deprecation")
@Override
public void afterPropertiesSet() {
Expand Down Expand Up @@ -363,7 +376,8 @@ private void configureEndpoint(AbstractKafkaListenerEndpoint<K, V> aklEndpoint)
.acceptIfNotNull(this.ackDiscarded, aklEndpoint::setAckDiscarded)
.acceptIfNotNull(this.replyTemplate, aklEndpoint::setReplyTemplate)
.acceptIfNotNull(this.replyHeadersConfigurer, aklEndpoint::setReplyHeadersConfigurer)
.acceptIfNotNull(this.batchToRecordAdapter, aklEndpoint::setBatchToRecordAdapter);
.acceptIfNotNull(this.batchToRecordAdapter, aklEndpoint::setBatchToRecordAdapter)
.acceptIfNotNull(this.correlationHeaderName, aklEndpoint::setCorrelationHeaderName);
if (aklEndpoint.getBatchListener() == null) {
JavaUtils.INSTANCE
.acceptIfNotNull(this.batchListener, aklEndpoint::setBatchListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.springframework.kafka.listener.adapter.MessagingMessageListenerAdapter;
import org.springframework.kafka.listener.adapter.RecordFilterStrategy;
import org.springframework.kafka.listener.adapter.ReplyHeadersConfigurer;
import org.springframework.kafka.support.JavaUtils;
import org.springframework.kafka.support.TopicPartitionOffset;
import org.springframework.kafka.support.converter.MessageConverter;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -117,6 +118,8 @@ public abstract class AbstractKafkaListenerEndpoint<K, V>

private byte[] listenerInfo;

private String correlationHeaderName;

@Override
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
this.beanFactory = beanFactory;
Expand Down Expand Up @@ -445,6 +448,16 @@ public void setBatchToRecordAdapter(BatchToRecordAdapter<K, V> batchToRecordAdap
this.batchToRecordAdapter = batchToRecordAdapter;
}

/**
* Set a custom header name for the correlation id. Default
* {@link org.springframework.kafka.support.KafkaHeaders#CORRELATION_ID}. This header
* will be echoed back in any reply message.
* @param correlationHeaderName the header name.
* @since 3.0
*/
public void setCorrelationHeaderName(String correlationHeaderName) {
this.correlationHeaderName = correlationHeaderName;
}

@Override
public void afterPropertiesSet() {
Expand Down Expand Up @@ -485,9 +498,9 @@ private void setupMessageListener(MessageListenerContainer container,
@Nullable MessageConverter messageConverter) {

MessagingMessageListenerAdapter<K, V> adapter = createMessageListener(container, messageConverter);
if (this.replyHeadersConfigurer != null) {
adapter.setReplyHeadersConfigurer(this.replyHeadersConfigurer);
}
JavaUtils.INSTANCE
.acceptIfNotNull(this.replyHeadersConfigurer, adapter::setReplyHeadersConfigurer)
.acceptIfNotNull(this.correlationHeaderName, adapter::setCorrelationHeaderName);
adapter.setSplitIterables(this.splitIterables);
Object messageListener = adapter;
boolean isBatchListener = isBatchListener();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,30 @@ public abstract class MessagingMessageListenerAdapter<K, V> implements ConsumerS

private boolean splitIterables = true;

private String correlationHeaderName = KafkaHeaders.CORRELATION_ID;

/**
* Create an instance with the provided bean and method.
* @param bean the bean.
* @param method the method.
*/
public MessagingMessageListenerAdapter(Object bean, Method method) {
this.bean = bean;
this.inferredType = determineInferredType(method); // NOSONAR = intentionally not final
}

/**
* Set a custom header name for the correlation id. Default
* {@link KafkaHeaders#CORRELATION_ID}. This header will be echoed back in any reply
* message.
* @param correlationHeaderName the header name.
* @since 3.0
*/
public void setCorrelationHeaderName(String correlationHeaderName) {
Assert.notNull(correlationHeaderName, "'correlationHeaderName' cannot be null");
this.correlationHeaderName = correlationHeaderName;
}

/**
* Set the MessageConverter.
* @param messageConverter the converter.
Expand Down Expand Up @@ -478,7 +497,7 @@ private Message<?> checkHeaders(Object result, String topic, Object source) { //
MessageHeaders headers = reply.getHeaders();
boolean needsTopic = headers.get(KafkaHeaders.TOPIC) == null;
boolean sourceIsMessage = source instanceof Message;
boolean needsCorrelation = headers.get(KafkaHeaders.CORRELATION_ID) == null && sourceIsMessage;
boolean needsCorrelation = headers.get(this.correlationHeaderName) == null && sourceIsMessage;
boolean needsPartition = headers.get(KafkaHeaders.PARTITION) == null && sourceIsMessage
&& getReplyPartition((Message<?>) source) != null;
if (needsTopic || needsCorrelation || needsPartition) {
Expand All @@ -487,8 +506,8 @@ private Message<?> checkHeaders(Object result, String topic, Object source) { //
builder.setHeader(KafkaHeaders.TOPIC, topic);
}
if (needsCorrelation && sourceIsMessage) {
builder.setHeader(KafkaHeaders.CORRELATION_ID,
((Message<?>) source).getHeaders().get(KafkaHeaders.CORRELATION_ID));
builder.setHeader(this.correlationHeaderName,
((Message<?>) source).getHeaders().get(this.correlationHeaderName));
}
if (sourceIsMessage && reply.getHeaders().get(KafkaHeaders.REPLY_PARTITION) == null) {
setPartition(builder, (Message<?>) source);
Expand All @@ -503,8 +522,8 @@ private void sendSingleResult(Object result, String topic, @Nullable Object sour
byte[] correlationId = null;
boolean sourceIsMessage = source instanceof Message;
if (sourceIsMessage
&& ((Message<?>) source).getHeaders().get(KafkaHeaders.CORRELATION_ID) != null) {
correlationId = ((Message<?>) source).getHeaders().get(KafkaHeaders.CORRELATION_ID, byte[].class);
&& ((Message<?>) source).getHeaders().get(this.correlationHeaderName) != null) {
correlationId = ((Message<?>) source).getHeaders().get(this.correlationHeaderName, byte[].class);
}
if (sourceIsMessage) {
sendReplyForMessageSource(result, topic, source, correlationId);
Expand All @@ -515,15 +534,15 @@ private void sendSingleResult(Object result, String topic, @Nullable Object sour
}

@SuppressWarnings("unchecked")
private void sendReplyForMessageSource(Object result, String topic, Object source, byte[] correlationId) {
private void sendReplyForMessageSource(Object result, String topic, Object source, @Nullable byte[] correlationId) {
MessageBuilder<Object> builder = MessageBuilder.withPayload(result)
.setHeader(KafkaHeaders.TOPIC, topic);
if (this.replyHeadersConfigurer != null) {
Map<String, Object> headersToCopy = ((Message<?>) source).getHeaders().entrySet().stream()
.filter(e -> {
String key = e.getKey();
return !key.equals(MessageHeaders.ID) && !key.equals(MessageHeaders.TIMESTAMP)
&& !key.equals(KafkaHeaders.CORRELATION_ID)
&& !key.equals(this.correlationHeaderName)
&& !key.startsWith(KafkaHeaders.RECEIVED);
})
.filter(e -> this.replyHeadersConfigurer.shouldCopy(e.getKey(), e.getValue()))
Expand All @@ -537,7 +556,7 @@ private void sendReplyForMessageSource(Object result, String topic, Object sourc
}
}
if (correlationId != null) {
builder.setHeader(KafkaHeaders.CORRELATION_ID, correlationId);
builder.setHeader(this.correlationHeaderName, correlationId);
}
setPartition(builder, ((Message<?>) source));
this.replyTemplate.send(builder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ public void testGoodWithSimpleMapper() throws Exception {
@Test
public void testAggregateNormal() throws Exception {
AggregatingReplyingKafkaTemplate<Integer, String, String> template = aggregatingTemplate(
new TopicPartitionOffset(D_REPLY, 0), 2, new AtomicInteger());
new TopicPartitionOffset(D_REPLY, 0), 3, new AtomicInteger());
try {
template.setCorrelationHeaderName("customCorrelation");
template.setDefaultReplyTimeout(Duration.ofSeconds(30));
Expand All @@ -477,13 +477,17 @@ public void testAggregateNormal() throws Exception {
future.getSendFuture().get(10, TimeUnit.SECONDS); // send ok
ConsumerRecord<Integer, Collection<ConsumerRecord<Integer, String>>> consumerRecord =
future.get(30, TimeUnit.SECONDS);
assertThat(consumerRecord.value().size()).isEqualTo(2);
assertThat(consumerRecord.value().size()).isEqualTo(3);
Iterator<ConsumerRecord<Integer, String>> iterator = consumerRecord.value().iterator();
String value1 = iterator.next().value();
assertThat(value1).isIn("fOO", "FOO");
assertThat(value1).isIn("fOO", "FOO", "Foo");
String value2 = iterator.next().value();
assertThat(value2).isIn("fOO", "FOO");
assertThat(value2).isIn("fOO", "FOO", "Foo");
assertThat(value2).isNotSameAs(value1);
String value3 = iterator.next().value();
assertThat(value3).isIn("fOO", "FOO", "Foo");
assertThat(value3).isNotSameAs(value1);
assertThat(value3).isNotSameAs(value2);
assertThat(consumerRecord.topic()).isEqualTo(AggregatingReplyingKafkaTemplate.AGGREGATED_RESULTS_TOPIC);
}
finally {
Expand All @@ -495,7 +499,7 @@ public void testAggregateNormal() throws Exception {
@Test
public void testAggregateNormalStringCorrelation() throws Exception {
AggregatingReplyingKafkaTemplate<Integer, String, String> template = aggregatingTemplate(
new TopicPartitionOffset(D_REPLY, 0), 2, new AtomicInteger());
new TopicPartitionOffset(D_REPLY, 0), 3, new AtomicInteger());
try {
template.setCorrelationHeaderName("customCorrelation");
template.setBinaryCorrelation(false);
Expand All @@ -506,13 +510,17 @@ public void testAggregateNormalStringCorrelation() throws Exception {
future.getSendFuture().get(10, TimeUnit.SECONDS); // send ok
ConsumerRecord<Integer, Collection<ConsumerRecord<Integer, String>>> consumerRecord =
future.get(30, TimeUnit.SECONDS);
assertThat(consumerRecord.value().size()).isEqualTo(2);
assertThat(consumerRecord.value().size()).isEqualTo(3);
Iterator<ConsumerRecord<Integer, String>> iterator = consumerRecord.value().iterator();
String value1 = iterator.next().value();
assertThat(value1).isIn("fOO", "FOO");
assertThat(value1).isIn("fOO", "FOO", "Foo");
String value2 = iterator.next().value();
assertThat(value2).isIn("fOO", "FOO");
assertThat(value2).isIn("fOO", "FOO", "Foo");
assertThat(value2).isNotSameAs(value1);
String value3 = iterator.next().value();
assertThat(value3).isIn("fOO", "FOO", "Foo");
assertThat(value3).isNotSameAs(value1);
assertThat(value3).isNotSameAs(value2);
assertThat(consumerRecord.topic()).isEqualTo(AggregatingReplyingKafkaTemplate.AGGREGATED_RESULTS_TOPIC);
}
finally {
Expand All @@ -526,7 +534,7 @@ public void testAggregateNormalStringCorrelation() throws Exception {
@Disabled("time sensitive")
public void testAggregateTimeout() throws Exception {
AggregatingReplyingKafkaTemplate<Integer, String, String> template = aggregatingTemplate(
new TopicPartitionOffset(E_REPLY, 0), 3, new AtomicInteger());
new TopicPartitionOffset(E_REPLY, 0), 4, new AtomicInteger());
try {
template.setDefaultReplyTimeout(Duration.ofSeconds(5));
template.setCorrelationHeaderName("customCorrelation");
Expand Down Expand Up @@ -561,7 +569,7 @@ public void testAggregateTimeout() throws Exception {
public void testAggregateTimeoutPartial() throws Exception {
AtomicInteger releaseCount = new AtomicInteger();
AggregatingReplyingKafkaTemplate<Integer, String, String> template = aggregatingTemplate(
new TopicPartitionOffset(F_REPLY, 0), 3, releaseCount);
new TopicPartitionOffset(F_REPLY, 0), 4, releaseCount);
template.setReturnPartialOnTimeout(true);
try {
template.setDefaultReplyTimeout(Duration.ofSeconds(5));
Expand All @@ -572,16 +580,20 @@ public void testAggregateTimeoutPartial() throws Exception {
future.getSendFuture().get(10, TimeUnit.SECONDS); // send ok
ConsumerRecord<Integer, Collection<ConsumerRecord<Integer, String>>> consumerRecord =
future.get(30, TimeUnit.SECONDS);
assertThat(consumerRecord.value().size()).isEqualTo(2);
assertThat(consumerRecord.value().size()).isEqualTo(3);
Iterator<ConsumerRecord<Integer, String>> iterator = consumerRecord.value().iterator();
String value1 = iterator.next().value();
assertThat(value1).isIn("fOO", "FOO");
assertThat(value1).isIn("fOO", "FOO", "Foo");
String value2 = iterator.next().value();
assertThat(value2).isIn("fOO", "FOO");
assertThat(value2).isIn("fOO", "FOO", "Foo");
assertThat(value2).isNotSameAs(value1);
String value3 = iterator.next().value();
assertThat(value3).isIn("fOO", "FOO", "Foo");
assertThat(value3).isNotSameAs(value1);
assertThat(value3).isNotSameAs(value2);
assertThat(consumerRecord.topic())
.isEqualTo(AggregatingReplyingKafkaTemplate.PARTIAL_RESULTS_AFTER_TIMEOUT_TOPIC);
assertThat(releaseCount.get()).isEqualTo(3);
assertThat(releaseCount.get()).isEqualTo(4);
}
finally {
template.stop();
Expand All @@ -593,7 +605,7 @@ public void testAggregateTimeoutPartial() throws Exception {
public void testAggregateTimeoutPartialStringCorrelation() throws Exception {
AtomicInteger releaseCount = new AtomicInteger();
AggregatingReplyingKafkaTemplate<Integer, String, String> template = aggregatingTemplate(
new TopicPartitionOffset(F_REPLY, 0), 3, releaseCount);
new TopicPartitionOffset(F_REPLY, 0), 4, releaseCount);
template.setReturnPartialOnTimeout(true);
template.setBinaryCorrelation(false);
try {
Expand All @@ -605,16 +617,20 @@ public void testAggregateTimeoutPartialStringCorrelation() throws Exception {
future.getSendFuture().get(10, TimeUnit.SECONDS); // send ok
ConsumerRecord<Integer, Collection<ConsumerRecord<Integer, String>>> consumerRecord =
future.get(30, TimeUnit.SECONDS);
assertThat(consumerRecord.value().size()).isEqualTo(2);
assertThat(consumerRecord.value().size()).isEqualTo(3);
Iterator<ConsumerRecord<Integer, String>> iterator = consumerRecord.value().iterator();
String value1 = iterator.next().value();
assertThat(value1).isIn("fOO", "FOO");
assertThat(value1).isIn("fOO", "FOO", "Foo");
String value2 = iterator.next().value();
assertThat(value2).isIn("fOO", "FOO");
assertThat(value2).isIn("fOO", "FOO", "Foo");
assertThat(value2).isNotSameAs(value1);
String value3 = iterator.next().value();
assertThat(value3).isIn("fOO", "FOO", "Foo");
assertThat(value3).isNotSameAs(value1);
assertThat(value3).isNotSameAs(value2);
assertThat(consumerRecord.topic())
.isEqualTo(AggregatingReplyingKafkaTemplate.PARTIAL_RESULTS_AFTER_TIMEOUT_TOPIC);
assertThat(releaseCount.get()).isEqualTo(3);
assertThat(releaseCount.get()).isEqualTo(4);
}
finally {
template.stop();
Expand Down Expand Up @@ -865,6 +881,17 @@ public ConcurrentKafkaListenerContainerFactory<Integer, String> kafkaListenerCon
return factory;
}

@Bean
public ConcurrentKafkaListenerContainerFactory<Integer, String> customListenerContainerFactory() {
ConcurrentKafkaListenerContainerFactory<Integer, String> factory =
new ConcurrentKafkaListenerContainerFactory<>();
factory.setConsumerFactory(cf());
factory.setReplyTemplate(template());
factory.setCorrelationHeaderName("customCorrelation");
factory.setMissingTopicsFatal(false);
return factory;
}

@Bean
public ConcurrentKafkaListenerContainerFactory<Integer, String> simpleMapperFactory() {
ConcurrentKafkaListenerContainerFactory<Integer, String> factory =
Expand Down Expand Up @@ -921,22 +948,30 @@ public HandlerReturn handlerReturn() {
return new HandlerReturn();
}

@KafkaListener(id = "def1", topics = { D_REQUEST, E_REQUEST, F_REQUEST })
@KafkaListener(id = "def1", topics = { D_REQUEST, E_REQUEST, F_REQUEST },
containerFactory = "customListenerContainerFactory")
@SendTo // default REPLY_TOPIC header
public Message<String> dListener1(String in, @Header("customCorrelation") byte[] correlation) {
return MessageBuilder.withPayload(in.toUpperCase())
.setHeader("customCorrelation", correlation)
.build();
}

@KafkaListener(id = "def2", topics = { D_REQUEST, E_REQUEST, F_REQUEST })
@KafkaListener(id = "def2", topics = { D_REQUEST, E_REQUEST, F_REQUEST },
containerFactory = "customListenerContainerFactory")
@SendTo // default REPLY_TOPIC header
public Message<String> dListener2(String in, @Header("customCorrelation") byte[] correlation) {
public Message<String> dListener2(String in) {
return MessageBuilder.withPayload(in.substring(0, 1) + in.substring(1).toUpperCase())
.setHeader("customCorrelation", correlation)
.build();
}

@KafkaListener(id = "def3", topics = { D_REQUEST, E_REQUEST, F_REQUEST },
containerFactory = "customListenerContainerFactory")
@SendTo // default REPLY_TOPIC header
public String dListener3(String in) {
return in.substring(0, 1).toUpperCase() + in.substring(1);
}

@KafkaListener(id = G_REQUEST, topics = G_REQUEST)
public void gListener(Message<String> in) {
String replyTopic = new String(in.getHeaders().get("custom.reply.to", byte[].class));
Expand Down

0 comments on commit 8a23e47

Please sign in to comment.