Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Nov 30, 2023
1 parent 1ea3fdd commit df20609
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 181 deletions.
176 changes: 47 additions & 129 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,12 @@ TRaft::TRaft(int node, const TNodeDict& nodes)
}
}

std::unique_ptr<TResult> TRaft::OnRequestVote(ITimeSource::Time now, TMessageHolder<TRequestVoteRequest> message) {
void TRaft::OnRequestVote(ITimeSource::Time now, TMessageHolder<TRequestVoteRequest> message) {
if (message->Term < State->CurrentTerm) {
auto reply = NewHoldedMessage<TRequestVoteResponse>();
reply->Src = Id;
reply->Dst = message->Src;
reply->Term = State->CurrentTerm;
reply->VoteGranted = false;
return std::make_unique<TResult>(TResult {
.Message = reply
});
auto reply = NewHoldedMessage(
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm},
TRequestVoteResponse {.VoteGranted = false});
Nodes[reply->Dst]->Send(std::move(reply));
} else if (message->Term == State->CurrentTerm) {
bool accept = false;
if (State->VotedFor == 0 || State->VotedFor == message->CandidateId) {
Expand All @@ -125,49 +121,29 @@ std::unique_ptr<TResult> TRaft::OnRequestVote(ITimeSource::Time now, TMessageHol
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm},
TRequestVoteResponse {.VoteGranted = accept});

decltype(VolatileState) nextVolatileState = nullptr;
if (accept) {
nextVolatileState = std::make_unique<TVolatileState>(*VolatileState);
nextVolatileState->ElectionDue = MakeElection(now);
VolatileState->ElectionDue = MakeElection(now);
State->VotedFor = message->CandidateId;
}

return std::make_unique<TResult>(TResult {
.NextState = accept ? std::make_unique<TState>(TState{
.CurrentTerm = State->CurrentTerm,
.VotedFor = message->CandidateId,
.Log = State->Log
}) : nullptr,
.NextVolatileState = std::move(nextVolatileState),
.Message = reply,
});
Nodes[reply->Dst]->Send(std::move(reply));
}

return nullptr;
}

std::unique_ptr<TResult> TRaft::OnRequestVote(TMessageHolder<TRequestVoteResponse> message) {
void TRaft::OnRequestVote(TMessageHolder<TRequestVoteResponse> message) {
if (message->VoteGranted && message->Term == State->CurrentTerm) {
auto votes = VolatileState->Votes;
votes.insert(message->Src);

auto nextVolatileState = *VolatileState;
nextVolatileState
(*VolatileState)
.SetVotes(votes)
.MergeRpcDue({{message->Src, ITimeSource::Max}});
return std::make_unique<TResult>(TResult {
.NextVolatileState = std::make_unique<TVolatileState>(
nextVolatileState
),
});
}

return {};
}

std::unique_ptr<TResult> TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntriesRequest> message) {
void TRaft::OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntriesRequest> message) {
if (message->Term < State->CurrentTerm) {
auto nextVolatileState = *VolatileState;
nextVolatileState.ElectionDue = MakeElection(now);
VolatileState->ElectionDue = MakeElection(now);

auto reply = NewHoldedMessage(
TMessageEx {
Expand All @@ -179,31 +155,27 @@ std::unique_ptr<TResult> TRaft::OnAppendEntries(ITimeSource::Time now, TMessageH
.MatchIndex = 0,
.Success = false,
});
return std::make_unique<TResult>(TResult {
.NextVolatileState = std::make_unique<TVolatileState>(nextVolatileState),
.Message = reply,
});
Nodes[reply->Dst]->Send(std::move(reply));
return;
}

assert(message->Term == State->CurrentTerm);

uint64_t matchIndex = 0;
uint64_t commitIndex = VolatileState->CommitIndex;
bool success = false;
std::unique_ptr<TState> state;
if (message->PrevLogIndex == 0 ||
(message->PrevLogIndex <= State->Log.size()
&& State->LogTerm(message->PrevLogIndex) == message->PrevLogTerm))
{
success = true;
auto index = message->PrevLogIndex;
state = std::make_unique<TState>(*State);
auto& log = state->Log;
auto& log = State->Log;
for (auto& data : message.Payload) {
auto entry = data.Cast<TLogEntry>();
index++;
// replace or append log entries
if (state->LogTerm(index) != entry->Term) {
if (State->LogTerm(index) != entry->Term) {
while (log.size() > index-1) {
log.pop_back();
}
Expand All @@ -219,42 +191,30 @@ std::unique_ptr<TResult> TRaft::OnAppendEntries(ITimeSource::Time now, TMessageH
TMessageEx {.Src = Id, .Dst = message->Src, .Term = State->CurrentTerm},
TAppendEntriesResponse {.MatchIndex = matchIndex, .Success = success});

auto nextVolatileState = *VolatileState;
nextVolatileState.SetCommitIndex(commitIndex);
nextVolatileState.ElectionDue = MakeElection(now);
return std::make_unique<TResult>(TResult {
.NextState = std::move(state),
.NextVolatileState = std::make_unique<TVolatileState>(nextVolatileState),
.NextStateName = EState::FOLLOWER,
.Message = reply,
});
VolatileState->SetCommitIndex(commitIndex);
VolatileState->ElectionDue = MakeElection(now);
Become(EState::FOLLOWER);
Nodes[reply->Dst]->Send(std::move(reply));
}

std::unique_ptr<TResult> TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message) {
void TRaft::OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message) {
if (message->Term != State->CurrentTerm) {
return nullptr;
return;
}

auto nodeId = message->Src;
if (message->Success) {
auto matchIndex = std::max(VolatileState->MatchIndex[nodeId], message->MatchIndex);
auto nextVolatileState = *VolatileState;
nextVolatileState
(*VolatileState)
.MergeMatchIndex({{nodeId, matchIndex}})
.MergeNextIndex({{nodeId, message->MatchIndex+1}})
.CommitAdvance(Nservers, State->Log.size(), *State)
.MergeRpcDue({{nodeId, ITimeSource::Time{}}});
return std::make_unique<TResult>(TResult {
.NextVolatileState = std::make_unique<TVolatileState>(nextVolatileState)
});
} else {
auto nextVolatileState = *VolatileState;
nextVolatileState
(*VolatileState)
.MergeNextIndex({{nodeId, std::max((uint64_t)1, VolatileState->NextIndex[nodeId]-1)}})
//.MergeNextIndex({{nodeId, 1}})
.MergeRpcDue({{nodeId, ITimeSource::Time{}}});
return std::make_unique<TResult>(TResult {
.NextVolatileState = std::make_unique<TVolatileState>(nextVolatileState)
});
}
}

Expand Down Expand Up @@ -291,6 +251,9 @@ TMessageHolder<TAppendEntriesRequest> TRaft::CreateAppendEntries(uint32_t nodeId
for (auto i = prevIndex; i < lastIndex; i++) {
payload.push_back(State->Log[i]);
}
if (!payload.empty()) {
std::cout << "Send " << payload.size() << " entries to " << nodeId << "\n";
}
mes.Payload = std::move(payload);
return mes;
}
Expand All @@ -303,57 +266,47 @@ std::vector<TMessageHolder<TAppendEntriesRequest>> TRaft::CreateAppendEntries()
return res;
}

std::unique_ptr<TResult> TRaft::Follower(ITimeSource::Time now, TMessageHolder<TMessage> message) {
void TRaft::Follower(ITimeSource::Time now, TMessageHolder<TMessage> message) {
if (auto maybeRequestVote = message.Maybe<TRequestVoteRequest>()) {
return OnRequestVote(now, std::move(maybeRequestVote.Cast()));
OnRequestVote(now, std::move(maybeRequestVote.Cast()));
} else if (auto maybeAppendEntries = message.Maybe<TAppendEntriesRequest>()) {
return OnAppendEntries(now, maybeAppendEntries.Cast());
OnAppendEntries(now, maybeAppendEntries.Cast());
}
return nullptr;
}

std::unique_ptr<TResult> TRaft::Candidate(ITimeSource::Time now, TMessageHolder<TMessage> message) {
void TRaft::Candidate(ITimeSource::Time now, TMessageHolder<TMessage> message) {
if (auto maybeResponseVote = message.Maybe<TRequestVoteResponse>()) {
return OnRequestVote(std::move(maybeResponseVote.Cast()));
OnRequestVote(std::move(maybeResponseVote.Cast()));
} else if (auto maybeRequestVote = message.Maybe<TRequestVoteRequest>()) {
return OnRequestVote(now, std::move(maybeRequestVote.Cast()));
OnRequestVote(now, std::move(maybeRequestVote.Cast()));
} else if (auto maybeAppendEntries = message.Maybe<TAppendEntriesRequest>()) {
return OnAppendEntries(now, maybeAppendEntries.Cast());
OnAppendEntries(now, maybeAppendEntries.Cast());
}
return nullptr;
}

std::unique_ptr<TResult> TRaft::Leader(ITimeSource::Time now, TMessageHolder<TMessage> message) {
void TRaft::Leader(ITimeSource::Time now, TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo) {
if (auto maybeAppendEntries = message.Maybe<TAppendEntriesResponse>()) {
return OnAppendEntries(maybeAppendEntries.Cast());
OnAppendEntries(maybeAppendEntries.Cast());
} else if (auto maybeCommandRequest = message.Maybe<TCommandRequest>()) {
auto command = maybeCommandRequest.Cast();
auto log = State->Log;
auto& log = State->Log;
auto dataSize = command->Len - sizeof(TCommandRequest);
auto entry = NewHoldedMessage<TLogEntry>(sizeof(TLogEntry)+dataSize);
memcpy(entry->Data, command->Data, dataSize);
entry->Term = State->CurrentTerm;
log.push_back(entry);
auto index = log.size()-1;
auto nextState = std::make_unique<TState>(TState {
.CurrentTerm = State->CurrentTerm,
.VotedFor = State->VotedFor,
.Log = std::move(log),
});
auto mes = NewHoldedMessage(TCommandResponse {.Index = index});
return std::make_unique<TResult>(TResult {
.NextState = std::move(nextState),
.Message = mes
});
if (replyTo) {
auto mes = NewHoldedMessage(TCommandResponse {.Index = index});
waiting.emplace(TWaiting{mes->Index, mes, replyTo});
}
} else if (auto maybeVoteRequest = message.Maybe<TRequestVoteRequest>()) {
return OnRequestVote(now, std::move(maybeVoteRequest.Cast()));
OnRequestVote(now, std::move(maybeVoteRequest.Cast()));
} else if (auto maybeVoteResponse = message.Maybe<TRequestVoteResponse>()) {
// skip additional votes
} else if (auto maybeAppendEntries = message.Maybe<TAppendEntriesRequest>()) {
return OnAppendEntries(now, maybeAppendEntries.Cast());
OnAppendEntries(now, maybeAppendEntries.Cast());
}

return nullptr;
}

void TRaft::Become(EState newStateName) {
Expand All @@ -374,54 +327,19 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder<TMessage> message, con
}
}
}
std::unique_ptr<TResult> result;
switch (StateName) {
case EState::FOLLOWER:
result = Follower(now, std::move(message));
Follower(now, std::move(message));
break;
case EState::CANDIDATE:
result = Candidate(now, std::move(message));
Candidate(now, std::move(message));
break;
case EState::LEADER:
result = Leader(now, std::move(message));
Leader(now, std::move(message), replyTo);
break;
default:
throw std::logic_error("Unknown state");
}

ApplyResult(now, std::move(result), replyTo);
}

void TRaft::ApplyResult(ITimeSource::Time now, std::unique_ptr<TResult> result, const std::shared_ptr<INode>& replyTo) {
if (!result) {
return;
}
if (result->NextState) {
State = std::move(result->NextState);
}
if (result->NextVolatileState) {
VolatileState = std::move(result->NextVolatileState);
}
if (result->Message) {
if (auto reply = result->Message.Maybe<TCommandResponse>()) {
if (replyTo) {
auto res = reply.Cast();
waiting.emplace(TWaiting{res->Index, res, replyTo});
}
} else {
auto messageEx = result->Message.Cast<TMessageEx>();
if (messageEx->Dst == 0) {
for (auto& [id, v] : Nodes) {
v->Send(std::move(messageEx));
}
} else {
Nodes[messageEx->Dst]->Send(std::move(messageEx));
}
}
}
if (result->NextStateName != EState::NONE) {
StateName = result->NextStateName;
}
}

void TRaft::ProcessWaiting() {
Expand Down
25 changes: 8 additions & 17 deletions src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,11 @@ enum class EState: int {
LEADER = 3,
};

struct TResult {
std::unique_ptr<TState> NextState;
std::unique_ptr<TVolatileState> NextVolatileState;
EState NextStateName = EState::NONE;
TMessageHolder<TMessage> Message;
};

class TRaft {
public:
TRaft(int node, const TNodeDict& nodes);

void Process(ITimeSource::Time now, TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo = {});
void ApplyResult(ITimeSource::Time now, std::unique_ptr<TResult> result, const std::shared_ptr<INode>& replyTo = {});
void ProcessTimeout(ITimeSource::Time now);

// ut
Expand All @@ -100,16 +92,15 @@ class TRaft {
return Id;
}

std::unique_ptr<TResult> Candidate(ITimeSource::Time now, TMessageHolder<TMessage> message);

private:
std::unique_ptr<TResult> Follower(ITimeSource::Time now, TMessageHolder<TMessage> message);
std::unique_ptr<TResult> Leader(ITimeSource::Time now, TMessageHolder<TMessage> message);

std::unique_ptr<TResult> OnRequestVote(ITimeSource::Time now, TMessageHolder<TRequestVoteRequest> message);
std::unique_ptr<TResult> OnRequestVote(TMessageHolder<TRequestVoteResponse> message);
std::unique_ptr<TResult> OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntriesRequest> message);
std::unique_ptr<TResult> OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message);
void Candidate(ITimeSource::Time now, TMessageHolder<TMessage> message);
void Follower(ITimeSource::Time now, TMessageHolder<TMessage> message);
void Leader(ITimeSource::Time now, TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo);

void OnRequestVote(ITimeSource::Time now, TMessageHolder<TRequestVoteRequest> message);
void OnRequestVote(TMessageHolder<TRequestVoteResponse> message);
void OnAppendEntries(ITimeSource::Time now, TMessageHolder<TAppendEntriesRequest> message);
void OnAppendEntries(TMessageHolder<TAppendEntriesResponse> message);

void LeaderTimeout(ITimeSource::Time now);
void CandidateTimeout(ITimeSource::Time now);
Expand Down
2 changes: 1 addition & 1 deletion src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ NNet::TSimpleTask TRaftServer::Idle() {
DrainNodes();
auto t1 = TimeSource->Now();
if (t1 > t0 + dt) {
std::cout << "Idle " << (uint32_t)Raft->CurrentStateName() << "\n";
std::cout << "Idle " << (uint32_t)Raft->CurrentStateName() << " " << Raft->GetState()->Log.size() << "\n";
t0 = t1;
}
co_await Poller.Sleep(sleep);
Expand Down
Loading

0 comments on commit df20609

Please sign in to comment.