KAFKA-15555: Ensure wakeups are handled correctly in poll() (apache#14746)

We need to be careful when aborting a long poll with wakeup() since the
consumer might never return records if the poll is interrupted after the
consumer position has been updated but before the records have been
returned to the caller of poll().

This PR avoids wake-ups during this critical period.
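
To illustrate the pattern this protects, here is a hypothetical application-side
sketch of the canonical wakeup-driven shutdown loop (topic, group id, and broker
address are assumed, not taken from this change):

import java.time.Duration;
import java.util.Collections;
import java.util.Properties;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.errors.WakeupException;
import org.apache.kafka.common.serialization.StringDeserializer;

public class WakeupShutdownSketch {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // assumed
        props.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group");           // assumed
        props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());

        final KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
        consumer.subscribe(Collections.singletonList("foo"));

        // wakeup() is the one thread-safe KafkaConsumer method; it is the standard
        // way to abort a long poll() from another thread.
        Runtime.getRuntime().addShutdownHook(new Thread(consumer::wakeup));

        try {
            while (true) {
                // The race fixed here: a wakeup() landing after poll() has advanced
                // the consumer position, but before poll() returns, would silently
                // drop the fetched records.
                ConsumerRecords<String, String> records = consumer.poll(Duration.ofMinutes(1));
                records.forEach(r -> System.out.println(r.key() + " -> " + r.value()));
            }
        } catch (WakeupException e) {
            // expected: poll() throws this once the pending wakeup is honored
        } finally {
            consumer.close();
        }
    }
}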

Reviewers: Philip Nee <pnee@confluent.io>, Kirk True <ktrue@confluent.io>, Lucas Brutschy <lbrutschy@confluent.io>
cadonna authored and yyu1993 committed Feb 15, 2024
1 parent e59177f commit e799e24
Showing 8 changed files with 271 additions and 11 deletions.
1 change: 1 addition & 0 deletions build.gradle
@@ -1423,6 +1423,7 @@ project(':clients') {
testImplementation libs.junitJupiter
testImplementation libs.log4j
testImplementation libs.mockitoCore
testImplementation libs.mockitoJunitJupiter // supports MockitoExtension

testRuntimeOnly libs.slf4jlog4j
testRuntimeOnly libs.jacksonDatabind
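
For context, libs.mockitoJunitJupiter provides MockitoExtension for JUnit 5. A generic
usage sketch (the test class and mock below are illustrative, not part of this PR):

import static org.mockito.Mockito.verify;

import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

// MockitoExtension initializes @Mock fields before each test method runs.
@ExtendWith(MockitoExtension.class)
class MockitoExtensionSketch {

    @Mock
    private List<String> list;

    @Test
    void usesInjectedMock() {
        list.add("a");
        verify(list).add("a"); // the mock records and verifies the interaction
    }
}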
clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java
@@ -407,23 +407,44 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
*
* @param timeout timeout of the poll loop
* @return ConsumerRecords. It can be empty if the timeout expires.
*
* @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this
* function is called
* @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while
* this function is called
* @throws org.apache.kafka.common.errors.RecordTooLargeException if the fetched record is larger than the maximum
* allowable size
* @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors
* @throws java.lang.IllegalStateException if the consumer is not subscribed to any topics or manually assigned any
* partitions to consume from, or if an unexpected error occurred
* @throws org.apache.kafka.clients.consumer.OffsetOutOfRangeException if the fetch position of the consumer is
* out of range and no offset reset policy is configured.
* @throws org.apache.kafka.common.errors.TopicAuthorizationException if the consumer is not authorized to read
* from a partition
* @throws org.apache.kafka.common.errors.SerializationException if the fetched records cannot be deserialized
*/
@Override
public ConsumerRecords<K, V> poll(final Duration timeout) {
Timer timer = time.timer(timeout);

acquireAndEnsureOpen();
try {
wakeupTrigger.setFetchAction(fetchBuffer);
kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());

if (subscriptions.hasNoSubscriptionOrUserAssignment()) {
throw new IllegalStateException("Consumer is not subscribed to any topics or assigned any partitions");
}

do {
// We must not allow wake-ups between polling for fetches and returning the records.
// If the polled fetches are not empty, the consumed position has already been updated while
// polling for the fetches. A wake-up between the fetches being polled and the records being
// returned would mean the records in those fetches are never returned. Thus, we trigger any
// pending wake-up before we poll for fetches.
wakeupTrigger.maybeTriggerWakeup();

updateAssignmentMetadataIfNeeded(timer);
final Fetch<K, V> fetch = pollForFetches(timer);

if (!fetch.isEmpty()) {
if (fetch.records().isEmpty()) {
log.trace("Returning empty records from `poll()` "
@@ -438,6 +459,7 @@ public ConsumerRecords<K, V> poll(final Duration timeout) {
return ConsumerRecords.empty();
} finally {
kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());
wakeupTrigger.clearTask();
release();
}
}
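
Distilled, the hunk above brackets the wakeup-sensitive window like this (a control-flow
sketch only; interceptors, metrics, and error handling are elided):

wakeupTrigger.setFetchAction(fetchBuffer);    // wakeups now only interrupt the fetch wait
try {
    do {
        wakeupTrigger.maybeTriggerWakeup();   // safe point: nothing fetched yet, may throw
        updateAssignmentMetadataIfNeeded(timer);
        final Fetch<K, V> fetch = pollForFetches(timer);
        if (!fetch.isEmpty()) {
            return new ConsumerRecords<>(fetch.records());  // records always reach the caller
        }
    } while (timer.notExpired());
    return ConsumerRecords.empty();
} finally {
    wakeupTrigger.clearTask();                // leave no stale FetchAction behind
}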
@@ -636,7 +658,7 @@ public Map<TopicPartition, OffsetAndMetadata> committed(final Set<TopicPartition
try {
return applicationEventHandler.addAndGet(event, time.timer(timeout));
} finally {
wakeupTrigger.clearActiveTask();
wakeupTrigger.clearTask();
}
} finally {
release();
@@ -922,7 +944,7 @@ public void commitSync(Map<TopicPartition, OffsetAndMetadata> offsets, Duration
offsets.forEach(this::updateLastSeenEpochIfNewer);
ConsumerUtils.getResult(commitFuture, time.timer(timeout));
} finally {
wakeupTrigger.clearActiveTask();
wakeupTrigger.clearTask();
kafkaConsumerMetrics.recordCommitSync(time.nanoseconds() - commitStart);
release();
}
clients/src/main/java/org/apache/kafka/clients/consumer/internals/FetchBuffer.java
@@ -29,6 +29,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@@ -52,6 +53,8 @@ public class FetchBuffer implements AutoCloseable {
private final Condition notEmptyCondition;
private final IdempotentCloser idempotentCloser = new IdempotentCloser();

private final AtomicBoolean wokenup = new AtomicBoolean(false);

private CompletedFetch nextInLineFetch;

public FetchBuffer(final LogContext logContext) {
@@ -166,7 +169,7 @@ void awaitNotEmpty(Timer timer) {
try {
lock.lock();

while (isEmpty()) {
while (isEmpty() && !wokenup.compareAndSet(true, false)) {
// Update the timer before we head into the loop in case it took a while to get the lock.
timer.update();

@@ -185,6 +188,16 @@ void awaitNotEmpty(Timer timer) {
}
}

void wakeup() {
wokenup.set(true);
try {
lock.lock();
notEmptyCondition.signalAll();
} finally {
lock.unlock();
}
}
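
This is the standard "flag plus condition signal" idiom. A self-contained sketch of the
same pattern under assumed names (independent of the Kafka classes):

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical class demonstrating the wakeable-wait idiom used by FetchBuffer.
class WakeableWait {
    private final Lock lock = new ReentrantLock();
    private final Condition nonEmpty = lock.newCondition();
    private final AtomicBoolean wokenup = new AtomicBoolean(false);
    private boolean hasData = false;

    void await() throws InterruptedException {
        lock.lock();
        try {
            // compareAndSet(true, false) consumes the wakeup: one wakeup() aborts at
            // most one wait, and the flag is reset for the next caller.
            while (!hasData && !wokenup.compareAndSet(true, false)) {
                nonEmpty.await();
            }
        } finally {
            lock.unlock();
        }
    }

    void deposit() {
        lock.lock();
        try {
            hasData = true;
            nonEmpty.signalAll();
        } finally {
            lock.unlock();
        }
    }

    void wakeup() {
        // Set the flag before signalling so a waiter re-checking the loop condition
        // observes it even if it has not yet parked on the condition.
        wokenup.set(true);
        lock.lock();
        try {
            nonEmpty.signalAll();
        } finally {
            lock.unlock();
        }
    }
}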

/**
* Updates the buffer to retain only the fetch data that corresponds to the given partitions. Any previously
* {@link CompletedFetch fetched data} is removed if its partition is not in the given set of partitions.
clients/src/main/java/org/apache/kafka/clients/consumer/internals/WakeupTrigger.java
@@ -21,6 +21,7 @@

import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

/**
@@ -44,6 +45,10 @@ public void wakeup() {
ActiveFuture active = (ActiveFuture) task;
active.future().completeExceptionally(new WakeupException());
return null;
} else if (task instanceof FetchAction) {
FetchAction fetchAction = (FetchAction) task;
fetchAction.fetchBuffer().wakeup();
return new WakeupFuture();
} else {
return task;
}
@@ -75,17 +80,51 @@ public <T> CompletableFuture<T> setActiveTask(final CompletableFuture<T> current
return currentTask;
}

public void clearActiveTask() {
public void setFetchAction(final FetchBuffer fetchBuffer) {
final AtomicBoolean throwWakeupException = new AtomicBoolean(false);
pendingTask.getAndUpdate(task -> {
if (task == null) {
return new FetchAction(fetchBuffer);
} else if (task instanceof WakeupFuture) {
throwWakeupException.set(true);
return null;
}
// last active state is still active
throw new IllegalStateException("Last active task is still active");
});
if (throwWakeupException.get()) {
throw new WakeupException();
}
}

public void clearTask() {
pendingTask.getAndUpdate(task -> {
if (task == null) {
return null;
} else if (task instanceof ActiveFuture || task instanceof FetchAction) {
return null;
}
return task;
});
}

public void maybeTriggerWakeup() {
final AtomicBoolean throwWakeupException = new AtomicBoolean(false);
pendingTask.getAndUpdate(task -> {
if (task == null) {
return null;
} else if (task instanceof WakeupFuture) {
throwWakeupException.set(true);
return null;
} else {
return task;
}
});
if (throwWakeupException.get()) {
throw new WakeupException();
}
}

Wakeupable getPendingTask() {
return pendingTask.get();
}
Expand All @@ -105,4 +144,17 @@ public CompletableFuture<?> future() {
}

static class WakeupFuture implements Wakeupable { }

static class FetchAction implements Wakeupable {

private final FetchBuffer fetchBuffer;

public FetchAction(FetchBuffer fetchBuffer) {
this.fetchBuffer = fetchBuffer;
}

public FetchBuffer fetchBuffer() {
return fetchBuffer;
}
}
}
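
Taken together, the methods above implement a small state machine over pendingTask.
An editorial summary of the transitions (derived from the code above, not part of the PR):

// pendingTask transitions in WakeupTrigger:
//
//   current task     on wakeup()                        on consumer-thread calls
//   ------------     -----------                        ------------------------
//   null             -> WakeupFuture                    setActiveTask()/setFetchAction()
//                                                       install a new task
//   ActiveFuture     future completed exceptionally     clearTask() -> null
//                    with WakeupException, -> null
//   FetchAction      fetchBuffer.wakeup() aborts the    maybeTriggerWakeup() leaves it in
//                    blocking wait, -> WakeupFuture     place; clearTask() -> null
//   WakeupFuture     unchanged                          setFetchAction() and
//                                                       maybeTriggerWakeup() throw
//                                                       WakeupException, -> null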
clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java
@@ -16,6 +16,7 @@
*/
package org.apache.kafka.clients.consumer.internals;

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
import org.apache.kafka.clients.consumer.OffsetCommitCallback;
@@ -51,6 +52,7 @@
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.MockedConstruction;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;

import java.time.Duration;
@@ -69,8 +71,11 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Arrays.asList;
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -80,6 +85,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mockConstruction;
@@ -90,6 +96,7 @@
public class AsyncKafkaConsumerTest {

private AsyncKafkaConsumer<?, ?> consumer;
private FetchCollector<?, ?> fetchCollector;
private ConsumerTestBuilder.AsyncKafkaConsumerTestBuilder testBuilder;
private ApplicationEventHandler applicationEventHandler;

Expand All @@ -103,6 +110,7 @@ private void setup(Optional<ConsumerTestBuilder.GroupInformation> groupInfo) {
testBuilder = new ConsumerTestBuilder.AsyncKafkaConsumerTestBuilder(groupInfo);
applicationEventHandler = testBuilder.applicationEventHandler;
consumer = testBuilder.consumer;
fetchCollector = testBuilder.fetchCollector;
}

@AfterEach
@@ -216,6 +224,82 @@ public void testCommitted_ExceptionThrown() {
}
}

@Test
public void testWakeupBeforeCallingPoll() {
final String topicName = "foo";
final int partition = 3;
final TopicPartition tp = new TopicPartition(topicName, partition);
doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));
Map<TopicPartition, OffsetAndMetadata> offsets = mkMap(mkEntry(tp, new OffsetAndMetadata(1)));
doReturn(offsets).when(applicationEventHandler).addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
consumer.assign(singleton(tp));

consumer.wakeup();

assertThrows(WakeupException.class, () -> consumer.poll(Duration.ZERO));
assertDoesNotThrow(() -> consumer.poll(Duration.ZERO));
}

@Test
public void testWakeupAfterEmptyFetch() {
final String topicName = "foo";
final int partition = 3;
final TopicPartition tp = new TopicPartition(topicName, partition);
doAnswer(invocation -> {
consumer.wakeup();
return Fetch.empty();
}).when(fetchCollector).collectFetch(any(FetchBuffer.class));
Map<TopicPartition, OffsetAndMetadata> offsets = mkMap(mkEntry(tp, new OffsetAndMetadata(1)));
doReturn(offsets).when(applicationEventHandler).addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
consumer.assign(singleton(tp));

assertThrows(WakeupException.class, () -> consumer.poll(Duration.ofMinutes(1)));
assertDoesNotThrow(() -> consumer.poll(Duration.ZERO));
}

@Test
public void testWakeupAfterNonEmptyFetch() {
final String topicName = "foo";
final int partition = 3;
final TopicPartition tp = new TopicPartition(topicName, partition);
final List<ConsumerRecord<String, String>> records = asList(
new ConsumerRecord<>(topicName, partition, 2, "key1", "value1"),
new ConsumerRecord<>(topicName, partition, 3, "key2", "value2")
);
doAnswer(invocation -> {
consumer.wakeup();
return Fetch.forPartition(tp, records, true);
}).when(fetchCollector).collectFetch(Mockito.any(FetchBuffer.class));
Map<TopicPartition, OffsetAndMetadata> offsets = mkMap(mkEntry(tp, new OffsetAndMetadata(1)));
doReturn(offsets).when(applicationEventHandler).addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
consumer.assign(singleton(tp));

// since wakeup() is called while the non-empty fetch is being returned, the wakeup should be ignored
assertDoesNotThrow(() -> consumer.poll(Duration.ofMinutes(1)));
// the wake-up that was ignored above takes effect on the next call
assertThrows(WakeupException.class, () -> consumer.poll(Duration.ZERO));
}

@Test
public void testClearWakeupTriggerAfterPoll() {
final String topicName = "foo";
final int partition = 3;
final TopicPartition tp = new TopicPartition(topicName, partition);
final List<ConsumerRecord<String, String>> records = asList(
new ConsumerRecord<>(topicName, partition, 2, "key1", "value1"),
new ConsumerRecord<>(topicName, partition, 3, "key2", "value2")
);
doReturn(Fetch.forPartition(tp, records, true))
.when(fetchCollector).collectFetch(any(FetchBuffer.class));
Map<TopicPartition, OffsetAndMetadata> offsets = mkMap(mkEntry(tp, new OffsetAndMetadata(1)));
doReturn(offsets).when(applicationEventHandler).addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
consumer.assign(singleton(tp));

consumer.poll(Duration.ZERO);

assertDoesNotThrow(() -> consumer.poll(Duration.ZERO));
}

@Test
public void testEnsureCallbackExecutedByApplicationThread() {
final String currentThread = Thread.currentThread().getName();
clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerTestBuilder.java
@@ -312,6 +312,8 @@ public static class AsyncKafkaConsumerTestBuilder extends ApplicationEventHandle

final AsyncKafkaConsumer<String, String> consumer;

final FetchCollector<String, String> fetchCollector;

public AsyncKafkaConsumerTestBuilder(Optional<GroupInformation> groupInfo) {
super(groupInfo);
String clientId = config.getString(CommonClientConfigs.CLIENT_ID_CONFIG);
@@ -320,13 +322,13 @@ public AsyncKafkaConsumerTestBuilder(Optional<GroupInformation> groupInfo) {
config.originals(Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId))
);
Deserializers<String, String> deserializers = new Deserializers<>(new StringDeserializer(), new StringDeserializer());
FetchCollector<String, String> fetchCollector = new FetchCollector<>(logContext,
this.fetchCollector = spy(new FetchCollector<>(logContext,
metadata,
subscriptions,
fetchConfig,
deserializers,
metricsManager,
time);
time));
this.consumer = spy(new AsyncKafkaConsumer<>(
logContext,
clientId,