diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java index b1d033df7a..1d0b8a6546 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java @@ -243,6 +243,8 @@ private final class ListenerConsumer implements SchedulingAwareRunnable, Consume private final Map> offsets = new HashMap<>(); + private final GenericMessageListener genericListener; + private final MessageListener listener; private final BatchMessageListener batchListener; @@ -311,6 +313,7 @@ private final class ListenerConsumer implements SchedulingAwareRunnable, Consume } this.consumer = consumer; GenericErrorHandler errHandler = this.containerProperties.getGenericErrorHandler(); + this.genericListener = listener; if (listener instanceof BatchMessageListener) { this.listener = null; this.batchListener = (BatchMessageListener) listener; @@ -391,7 +394,7 @@ public void onPartitionsAssigned(Collection partitions) { KafkaMessageListenerContainer.this.getContainerProperties().getCommitCallback()); } } - if (ListenerConsumer.this.listener instanceof ConsumerSeekAware) { + if (ListenerConsumer.this.genericListener instanceof ConsumerSeekAware) { seekPartitions(partitions, false); } if (this.consumerAwareListener != null) { @@ -431,10 +434,10 @@ public void seekToEnd(String topic, int partition) { }; if (idle) { - ((ConsumerSeekAware) ListenerConsumer.this.listener).onIdleContainer(current, callback); + ((ConsumerSeekAware) ListenerConsumer.this.genericListener).onIdleContainer(current, callback); } else { - ((ConsumerSeekAware) ListenerConsumer.this.listener).onPartitionsAssigned(current, callback); + ((ConsumerSeekAware) ListenerConsumer.this.genericListener).onPartitionsAssigned(current, callback); } } @@ -466,8 +469,8 @@ public boolean isLongLived() { @Override public void run() { - if (this.listener instanceof ConsumerSeekAware) { - ((ConsumerSeekAware) this.listener).registerSeekCallback(this); + if (this.genericListener instanceof ConsumerSeekAware) { + ((ConsumerSeekAware) this.genericListener).registerSeekCallback(this); } this.count = 0; this.last = System.currentTimeMillis(); @@ -500,7 +503,7 @@ public void run() { publishIdleContainerEvent(now - lastReceive, this.isConsumerAwareListener ? this.consumer : null); lastAlertAt = now; - if (this.listener instanceof ConsumerSeekAware) { + if (this.genericListener instanceof ConsumerSeekAware) { seekPartitions(getAssignedPartitions(), true); } } diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java index 15168222a6..fcc3a32697 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java @@ -105,9 +105,13 @@ public class KafkaMessageListenerContainerTests { private static String topic14 = "testTopic14"; + private static String topic15 = "testTopic15"; + + private static String topic16 = "testTopic16"; + @ClassRule public static KafkaEmbedded embeddedKafka = new KafkaEmbedded(1, true, topic3, topic4, topic5, - topic6, topic7, topic8, topic9, topic10, topic11, topic12, topic13, topic14); + topic6, topic7, topic8, topic9, topic10, topic11, topic12, topic13, topic14, topic15, topic16); @Rule public TestName testName = new TestName(); @@ -724,15 +728,62 @@ public void testSeekAutoCommit() throws Exception { @Test public void testSeekAutoCommitDefault() throws Exception { - Map props = KafkaTestUtils.consumerProps("test12", "true", embeddedKafka); + Map props = KafkaTestUtils.consumerProps("test15", "true", embeddedKafka); props.remove(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG); // test true by default - testSeekGuts(props, topic12, true); + testSeekGuts(props, topic15, true); + } + + @Test + public void testSeekBatch() throws Exception { + logger.info("Start seek batch seek"); + Map props = KafkaTestUtils.consumerProps("test16", "true", embeddedKafka); + DefaultKafkaConsumerFactory cf = new DefaultKafkaConsumerFactory<>(props); + ContainerProperties containerProps = new ContainerProperties(topic16); + final CountDownLatch registerLatch = new CountDownLatch(1); + final CountDownLatch assignedLatch = new CountDownLatch(1); + final CountDownLatch idleLatch = new CountDownLatch(1); + class Listener implements BatchMessageListener, ConsumerSeekAware { + + @Override + public void onMessage(List> data) { + // empty + } + + @Override + public void registerSeekCallback(ConsumerSeekCallback callback) { + registerLatch.countDown(); + } + + @Override + public void onPartitionsAssigned(Map assignments, ConsumerSeekCallback callback) { + assignedLatch.countDown(); + } + + @Override + public void onIdleContainer(Map assignments, ConsumerSeekCallback callback) { + idleLatch.countDown(); + } + + } + Listener messageListener = new Listener(); + containerProps.setMessageListener(messageListener); + containerProps.setSyncCommits(true); + containerProps.setAckOnError(false); + containerProps.setIdleEventInterval(10L); + KafkaMessageListenerContainer container = new KafkaMessageListenerContainer<>(cf, + containerProps); + container.setBeanName("testBatchSeek"); + container.start(); + assertThat(registerLatch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(assignedLatch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(idleLatch.await(10, TimeUnit.SECONDS)).isTrue(); + container.stop(); } private void testSeekGuts(Map props, String topic, boolean autoCommit) throws Exception { logger.info("Start seek " + topic); DefaultKafkaConsumerFactory cf = new DefaultKafkaConsumerFactory<>(props); - ContainerProperties containerProps = new ContainerProperties(topic11); + ContainerProperties containerProps = new ContainerProperties(topic); final AtomicReference latch = new AtomicReference<>(new CountDownLatch(6)); final AtomicBoolean seekInitial = new AtomicBoolean(); final CountDownLatch idleLatch = new CountDownLatch(1); @@ -749,10 +800,10 @@ public void onMessage(ConsumerRecord data) { messageThread = Thread.currentThread(); latch.get().countDown(); if (latch.get().getCount() == 2 && !seekInitial.get()) { - callback.seekToEnd(topic11, 0); - callback.seekToBeginning(topic11, 0); - callback.seek(topic11, 0, 1); - callback.seek(topic11, 1, 1); + callback.seekToEnd(topic, 0); + callback.seekToBeginning(topic, 0); + callback.seek(topic, 0, 1); + callback.seek(topic, 1, 1); } } @@ -792,7 +843,7 @@ public void onIdleContainer(Map assignments, ConsumerSeekC KafkaMessageListenerContainer container = new KafkaMessageListenerContainer<>(cf, containerProps); - container.setBeanName("testRecordAcks"); + container.setBeanName("testSeek" + topic); container.start(); assertThat(KafkaTestUtils.getPropertyValue(container, "listenerConsumer.autoCommit", Boolean.class)) .isEqualTo(autoCommit); @@ -801,7 +852,7 @@ public void onIdleContainer(Map assignments, ConsumerSeekC Map senderProps = KafkaTestUtils.producerProps(embeddedKafka); ProducerFactory pf = new DefaultKafkaProducerFactory<>(senderProps); KafkaTemplate template = new KafkaTemplate<>(pf); - template.setDefaultTopic(topic11); + template.setDefaultTopic(topic); template.sendDefault(0, 0, "foo"); template.sendDefault(1, 0, "bar"); template.sendDefault(0, 0, "baz"); @@ -843,11 +894,11 @@ public void publishEvent(ApplicationEvent event) { ArgumentCaptor> captor = ArgumentCaptor.forClass(Collection.class); verify(consumer).seekToBeginning(captor.capture()); TopicPartition next = captor.getValue().iterator().next(); - assertThat(next.topic()).isEqualTo(topic11); + assertThat(next.topic()).isEqualTo(topic); assertThat(next.partition()).isEqualTo(0); verify(consumer).seekToEnd(captor.capture()); next = captor.getValue().iterator().next(); - assertThat(next.topic()).isEqualTo(topic11); + assertThat(next.topic()).isEqualTo(topic); assertThat(next.partition()).isEqualTo(0); logger.info("Stop seek"); }