KAFKA-15100; KRaft data race with the expiration service (apache#14141)
The KRaft client uses an expiration service to complete FETCH requests that have timed out. The expiration service runs on a different thread from the KRaft polling thread, so it is unsafe for the expiration service thread to call tryCompleteFetchRequest: that method reads and updates a lot of state that is assumed to be read and updated only from the polling thread.
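
To make the hazard concrete, here is a minimal sketch of the pattern that goes wrong. The names are hypothetical and simplified, not the actual KafkaRaftClient code: a future whose timeout fires on the expiration service thread runs a callback that reads and writes state the design assumes is owned by the polling thread.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Hypothetical, simplified sketch -- not the actual KafkaRaftClient code.
final class ExpirationRaceSketch {
    private final ScheduledExecutorService expirationService =
        Executors.newSingleThreadScheduledExecutor();
    private long highWatermark = -1L; // assumed to be touched only by the polling thread

    CompletableFuture<Long> handleFetch(long fetchOffset, long maxWaitMs) {
        CompletableFuture<Long> future = new CompletableFuture<>();
        // The timeout completes the future on the expiration service thread ...
        expirationService.schedule(
            () -> future.completeExceptionally(new TimeoutException()),
            maxWaitMs,
            TimeUnit.MILLISECONDS
        );
        // ... so on timeout this callback also runs on that thread, and the call
        // below races with the polling thread's own reads and writes.
        return future.handle((offset, exception) -> tryCompleteFetchRequest(fetchOffset));
    }

    private long tryCompleteFetchRequest(long fetchOffset) {
        if (fetchOffset > highWatermark) { // unsynchronized read ...
            highWatermark = fetchOffset;   // ... and unsynchronized write
        }
        return highWatermark;
    }
}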

The KRaft client no longer calls tryCompleteFetchRequest when the FETCH request has expired. Instead, it sends the FETCH response that was computed when the FETCH request was first handled.
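
Condensed into a self-contained sketch (hypothetical names; the real logic is in handleFetchRequest in the diff below), the shape of the fix is: capture the response computed on the polling thread, and on timeout return it verbatim instead of re-reading shared state.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;

// Hypothetical, simplified sketch of the fixed control flow -- not the verbatim patch.
final class TimeoutPathSketch {
    static CompletableFuture<String> awaitFetch(
        CompletableFuture<Void> purgatoryFuture,  // completes on timeout or on new data
        String responseFromFirstHandling          // computed earlier on the polling thread
    ) {
        return purgatoryFuture.handle((nil, exception) -> {
            if (exception instanceof TimeoutException) {
                // Expiration service thread: touch no shared state; reply with the
                // response that was computed when the request was first handled.
                return responseFromFirstHandling;
            }
            // The polling thread completed the future: recomputing is safe here.
            return recomputeOnPollingThread();
        });
    }

    private static String recomputeOnPollingThread() {
        return "fresh response computed from shared state";
    }
}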

This change also fixes a bug where the KRaft client was not sending the FETCH response immediately if the response contained a diverging epoch or snapshot id.
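
The detection relies on the FETCH response's sentinel convention: an absent divergingEpoch or snapshotId is encoded as epoch = -1 and endOffset = -1, so any other value means the field is set. A small illustration follows, with a hypothetical record standing in for the generated FetchResponseData nested types; the real checks are the isPartitionDiverged and isPartitionSnapshotted helpers in the diff below.

// Sentinel check behind the new isPartitionDiverged / isPartitionSnapshotted helpers.
// EpochEndOffset here is a stand-in for the generated FetchResponseData types.
final class SentinelSketch {
    record EpochEndOffset(int epoch, long endOffset) {
        static final EpochEndOffset UNSET = new EpochEndOffset(-1, -1);
    }

    static boolean isSet(EpochEndOffset field) {
        // Anything other than (-1, -1) means the leader reported a diverging epoch
        // or a snapshot id, and the response must be sent immediately.
        return field.epoch() != -1 || field.endOffset() != -1;
    }

    public static void main(String[] args) {
        System.out.println(isSet(EpochEndOffset.UNSET));     // false: nothing to report
        System.out.println(isSet(new EpochEndOffset(1, 3))); // true: reply immediately
    }
}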

Reviewers: Jason Gustafson <jason@confluent.io>
jsancio authored and rreddy-22 committed Sep 20, 2023
1 parent 4a0bb51 commit 0473851
Showing 9 changed files with 152 additions and 37 deletions.
6 changes: 5 additions & 1 deletion raft/src/main/java/org/apache/kafka/raft/ElectionState.java
@@ -61,7 +61,7 @@ public static ElectionState withUnknownLeader(int epoch, Set<Integer> voters) {
public boolean isLeader(int nodeId) {
if (nodeId < 0)
throw new IllegalArgumentException("Invalid negative nodeId: " + nodeId);
return leaderIdOpt.orElse(-1) == nodeId;
return leaderIdOrSentinel() == nodeId;
}

public boolean isVotedCandidate(int nodeId) {
@@ -94,6 +94,10 @@ public boolean hasVoted() {
return votedIdOpt.isPresent();
}

public int leaderIdOrSentinel() {
return leaderIdOpt.orElse(-1);
}

@Override
public String toString() {
49 changes: 39 additions & 10 deletions raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -879,9 +879,9 @@ private FetchResponseData buildFetchResponse(
.setRecords(records)
.setErrorCode(error.code())
.setLogStartOffset(log.startOffset())
.setHighWatermark(highWatermark
.map(offsetMetadata -> offsetMetadata.offset)
.orElse(-1L));
.setHighWatermark(
highWatermark.map(offsetMetadata -> offsetMetadata.offset).orElse(-1L)
);

partitionData.currentLeader()
.setLeaderEpoch(quorum.epoch())
@@ -960,8 +960,9 @@ private CompletableFuture<FetchResponseData> handleFetchRequest(
|| fetchPartition.fetchOffset() < 0
|| fetchPartition.lastFetchedEpoch() < 0
|| fetchPartition.lastFetchedEpoch() > fetchPartition.currentLeaderEpoch()) {
return completedFuture(buildEmptyFetchResponse(
Errors.INVALID_REQUEST, Optional.empty()));
return completedFuture(
buildEmptyFetchResponse(Errors.INVALID_REQUEST, Optional.empty())
);
}

int replicaId = FetchRequest.replicaId(request);
@@ -971,7 +972,15 @@ private CompletableFuture<FetchResponseData> handleFetchRequest(

if (partitionResponse.errorCode() != Errors.NONE.code()
|| FetchResponse.recordsSize(partitionResponse) > 0
|| request.maxWaitMs() == 0) {
|| request.maxWaitMs() == 0
|| isPartitionDiverged(partitionResponse)
|| isPartitionSnapshotted(partitionResponse)) {
// Reply immediately if any of the following is true
// 1. The response contains an error
// 2. There are records in the response
// 3. The fetching replica doesn't want to wait for the partition to contain new data
// 4. The fetching replica needs to truncate because the log diverged
// 5. The fetching replica needs to fetch a snapshot
return completedFuture(response);
}

@@ -984,11 +993,16 @@ private CompletableFuture<FetchResponseData> handleFetchRequest(
Throwable cause = exception instanceof ExecutionException ?
exception.getCause() : exception;

// If the fetch timed out in purgatory, it means no new data is available,
// and we will complete the fetch successfully. Otherwise, if there was
// any other error, we need to return it.
Errors error = Errors.forException(cause);
if (error != Errors.REQUEST_TIMED_OUT) {
if (error == Errors.REQUEST_TIMED_OUT) {
// Note that for this case the calling thread is the expiration service thread and not the
// polling thread.
//
// If the fetch request timed out in purgatory, it means no new data is available,
// just return the original fetch response.
return response;
} else {
// If there was any error other than REQUEST_TIMED_OUT, return it.
logger.info("Failed to handle fetch from {} at {} due to {}",
replicaId, fetchPartition.fetchOffset(), error);
return buildEmptyFetchResponse(error, Optional.empty());
@@ -999,6 +1013,9 @@ private CompletableFuture<FetchResponseData> handleFetchRequest(
logger.trace("Completing delayed fetch from {} starting at offset {} at {}",
replicaId, fetchPartition.fetchOffset(), completionTimeMs);

// It is safe to call tryCompleteFetchRequest because only the polling thread completes this
// future successfully. This is true because only the polling thread appends record batches to
// the log from maybeAppendBatches.
return tryCompleteFetchRequest(replicaId, fetchPartition, time.milliseconds());
});
}
@@ -1048,6 +1065,18 @@ private FetchResponseData tryCompleteFetchRequest(
}
}

private static boolean isPartitionDiverged(FetchResponseData.PartitionData partitionResponseData) {
FetchResponseData.EpochEndOffset divergingEpoch = partitionResponseData.divergingEpoch();

return divergingEpoch.epoch() != -1 || divergingEpoch.endOffset() != -1;
}

private static boolean isPartitionSnapshotted(FetchResponseData.PartitionData partitionResponseData) {
FetchResponseData.SnapshotId snapshotId = partitionResponseData.snapshotId();

return snapshotId.epoch() != -1 || snapshotId.endOffset() != -1;
}

private static OptionalInt optionalLeaderId(int leaderIdOrNil) {
if (leaderIdOrNil < 0)
return OptionalInt.empty();
5 changes: 2 additions & 3 deletions raft/src/main/java/org/apache/kafka/raft/LeaderState.java
@@ -50,7 +50,7 @@ public class LeaderState<T> implements EpochState {
private final long epochStartOffset;
private final Set<Integer> grantingVoters;

private Optional<LogOffsetMetadata> highWatermark;
private Optional<LogOffsetMetadata> highWatermark = Optional.empty();
private final Map<Integer, ReplicaState> voterStates = new HashMap<>();
private final Map<Integer, ReplicaState> observerStates = new HashMap<>();
private final Logger log;
@@ -71,7 +71,6 @@ protected LeaderState(
this.localId = localId;
this.epoch = epoch;
this.epochStartOffset = epochStartOffset;
this.highWatermark = Optional.empty();

for (int voterId : voters) {
boolean hasAcknowledgedLeader = voterId == localId;
@@ -337,7 +336,7 @@ public DescribeQuorumResponseData.PartitionData describeQuorum(long currentTimeMs) {
.setErrorCode(Errors.NONE.code())
.setLeaderId(localId)
.setLeaderEpoch(epoch)
.setHighWatermark(highWatermark().map(offsetMetadata -> offsetMetadata.offset).orElse(-1L))
.setHighWatermark(highWatermark.map(offsetMetadata -> offsetMetadata.offset).orElse(-1L))
.setCurrentVoters(describeReplicaStates(voterStates, currentTimeMs))
.setObservers(describeReplicaStates(observerStates, currentTimeMs));
}
3 changes: 1 addition & 2 deletions raft/src/main/java/org/apache/kafka/raft/QuorumState.java
@@ -233,7 +233,7 @@ public int epoch() {
}

public int leaderIdOrSentinel() {
return leaderId().orElse(-1);
return state.election().leaderIdOrSentinel();
}

public Optional<LogOffsetMetadata> highWatermark() {
@@ -570,5 +570,4 @@ public boolean isResigned() {
public boolean isCandidate() {
return state instanceof CandidateState;
}

}
@@ -268,6 +268,38 @@ public void testListenerRenotified() throws Exception {
}
}

@Test
public void testLeaderImmediatelySendsSnapshotId() throws Exception {
int localId = 0;
int otherNodeId = 1;
Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
OffsetAndEpoch snapshotId = new OffsetAndEpoch(3, 4);

RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withUnknownLeader(snapshotId.epoch())
.appendToLog(snapshotId.epoch(), Arrays.asList("a", "b", "c"))
.appendToLog(snapshotId.epoch(), Arrays.asList("d", "e", "f"))
.appendToLog(snapshotId.epoch(), Arrays.asList("g", "h", "i"))
.withEmptySnapshot(snapshotId)
.deleteBeforeSnapshot(snapshotId)
.build();

context.becomeLeader();
int epoch = context.currentEpoch();

// Send a fetch request for an end offset and epoch which has been snapshotted
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 6, 2, 500));
context.client.poll();

// Expect that the leader replies immediately with a snapshot id
FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse();
assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
assertEquals(localId, partitionResponse.currentLeader().leaderId());
assertEquals(snapshotId.epoch(), partitionResponse.snapshotId().epoch());
assertEquals(snapshotId.offset(), partitionResponse.snapshotId().endOffset());
}

@Test
public void testFetchRequestOffsetLessThanLogStart() throws Exception {
int localId = 0;
30 changes: 30 additions & 0 deletions raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
@@ -1206,6 +1206,36 @@ public void testListenerCommitCallbackAfterLeaderWrite() throws Exception {
assertEquals(records, context.listener.commitWithLastOffset(offset));
}

@Test
public void testLeaderImmediatelySendsDivergingEpoch() throws Exception {
int localId = 0;
int otherNodeId = 1;
Set<Integer> voters = Utils.mkSet(localId, otherNodeId);

RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withUnknownLeader(5)
.appendToLog(1, Arrays.asList("a", "b", "c"))
.appendToLog(3, Arrays.asList("d", "e", "f"))
.appendToLog(5, Arrays.asList("g", "h", "i"))
.build();

// Start off as the leader
context.becomeLeader();
int epoch = context.currentEpoch();

// Send a fetch request for an end offset and epoch which has diverged
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 6, 2, 500));
context.client.poll();

// Expect that the leader replies immediately with a diverging epoch
FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse();
assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
assertEquals(localId, partitionResponse.currentLeader().leaderId());
assertEquals(1, partitionResponse.divergingEpoch().epoch());
assertEquals(3, partitionResponse.divergingEpoch().endOffset());
}

@Test
public void testCandidateIgnoreVoteRequestOnSameEpoch() throws Exception {
int localId = 0;
54 changes: 40 additions & 14 deletions raft/src/test/java/org/apache/kafka/raft/MockLog.java
@@ -85,6 +85,7 @@ public void truncateTo(long offset) {
" which is below the current high watermark " + highWatermark);
}

logger.debug("Truncating log to end offset {}", offset);
batches.removeIf(entry -> entry.lastOffset() >= offset);
epochStartOffsets.removeIf(epochStartOffset -> epochStartOffset.startOffset >= offset);
firstUnflushedOffset = Math.min(firstUnflushedOffset, endOffset().offset);
@@ -98,6 +99,8 @@ public boolean truncateToLatestSnapshot() {
(snapshotId.epoch() == logLastFetchedEpoch().orElse(0) &&
snapshotId.offset() > endOffset().offset)) {

logger.debug("Truncating to the latest snapshot at {}", snapshotId);

batches.clear();
epochStartOffsets.clear();
snapshots.headMap(snapshotId, false).clear();
@@ -278,10 +281,11 @@ public LogAppendInfo appendAsLeader(Records records, int epoch) {
return append(records, OptionalInt.of(epoch));
}

private Long appendBatch(LogBatch batch) {
private long appendBatch(LogBatch batch) {
if (batch.epoch > lastFetchedEpoch()) {
epochStartOffsets.add(new EpochStartOffset(batch.epoch, batch.firstOffset()));
}

batches.add(batch);
return batch.firstOffset();
}
@@ -311,15 +315,22 @@ private LogAppendInfo append(Records records, OptionalInt epoch) {
);
}

List<LogEntry> entries = buildEntries(batch, Record::offset);
appendBatch(
new LogBatch(
epoch.orElseGet(batch::partitionLeaderEpoch),
batch.isControlBatch(),
entries
)
LogBatch logBatch = new LogBatch(
epoch.orElseGet(batch::partitionLeaderEpoch),
batch.isControlBatch(),
buildEntries(batch, Record::offset)
);
lastOffset = entries.get(entries.size() - 1).offset;

if (logger.isDebugEnabled()) {
String nodeState = "Follower";
if (epoch.isPresent()) {
nodeState = "Leader";
}
logger.debug("{} appending to the log {}", nodeState, logBatch);
}

appendBatch(logBatch);
lastOffset = logBatch.last().offset;
}

return new LogAppendInfo(baseOffset, lastOffset);
@@ -385,13 +396,9 @@ private void verifyOffsetInRange(long offset) {

@Override
public LogFetchInfo read(long startOffset, Isolation isolation) {
OptionalLong maxOffsetOpt = isolation == Isolation.COMMITTED ?
OptionalLong.of(highWatermark.offset) :
OptionalLong.empty();

verifyOffsetInRange(startOffset);

long maxOffset = maxOffsetOpt.orElse(endOffset().offset);
long maxOffset = isolation == Isolation.COMMITTED ? highWatermark.offset : endOffset().offset;
if (startOffset >= maxOffset) {
return new LogFetchInfo(MemoryRecords.EMPTY, new LogOffsetMetadata(
startOffset, metadataForOffset(startOffset)));
@@ -401,6 +408,13 @@ public LogFetchInfo read(long startOffset, Isolation isolation) {
int batchCount = 0;
LogOffsetMetadata batchStartOffset = null;

logger.debug(
"Looking for a batch that starts at {} and ends at {} for isolation {}",
startOffset,
maxOffset,
isolation
);

for (LogBatch batch : batches) {
// Note that start offset is inclusive while max offset is exclusive. We only return
// complete batches, so batches which end at an offset larger than the max offset are
@@ -541,6 +555,7 @@ public boolean deleteBeforeSnapshot(OffsetAndEpoch snapshotId) {
if (snapshots.containsKey(snapshotId)) {
snapshots.headMap(snapshotId, false).clear();

logger.debug("Deleting batches included in the snapshot {}", snapshotId);
batches.removeIf(entry -> entry.lastOffset() < snapshotId.offset());

AtomicReference<Optional<EpochStartOffset>> last = new AtomicReference<>(Optional.empty());
@@ -566,6 +581,17 @@ public boolean deleteBeforeSnapshot(OffsetAndEpoch snapshotId) {
return updated;
}

@Override
public String toString() {
return String.format(
"MockLog(epochStartOffsets=%s, batches=%s, snapshots=%s, highWatermark=%s",
epochStartOffsets,
batches,
snapshots,
highWatermark
);
}

static class MockOffsetMetadata implements OffsetMetadata {
final long id;

@@ -385,9 +385,7 @@ LeaderAndEpoch currentLeaderAndEpoch() {
return new LeaderAndEpoch(election.leaderIdOpt, election.epoch);
}

void expectAndGrantVotes(
int epoch
) throws Exception {
void expectAndGrantVotes(int epoch) throws Exception {
pollUntilRequest();

List<RaftRequest.Outbound> voteRequests = collectVoteRequests(epoch,
@@ -406,9 +404,7 @@ private int localIdOrThrow() {
return localId.orElseThrow(() -> new AssertionError("Required local id is not defined"));
}

private void expectBeginEpoch(
int epoch
) throws Exception {
private void expectBeginEpoch(int epoch) throws Exception {
pollUntilRequest();
for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) {
BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localIdOrThrow());
@@ -38,8 +38,8 @@
import org.apache.kafka.raft.MockLog.LogEntry;
import org.apache.kafka.raft.internals.BatchMemoryPool;
import org.apache.kafka.server.common.serialization.RecordSerde;
import org.apache.kafka.snapshot.SnapshotReader;
import org.apache.kafka.snapshot.RecordsSnapshotReader;
import org.apache.kafka.snapshot.SnapshotReader;

import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
