Skip to content

Commit

Permalink
Use shared_ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Nov 26, 2023
1 parent a68f63e commit 3352323
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/raft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ void TRaft::Become(EState newStateName) {
}
}

void TRaft::Process(TMessageHolder<TMessage> message, INode* replyTo) {
void TRaft::Process(TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo) {
auto now = TimeSource->Now();

if (message.IsEx()) {
Expand Down Expand Up @@ -404,7 +404,7 @@ void TRaft::Process(TMessageHolder<TMessage> message, INode* replyTo) {
ApplyResult(now, std::move(result), replyTo);
}

void TRaft::ApplyResult(ITimeSource::Time now, std::unique_ptr<TResult> result, INode* replyTo) {
void TRaft::ApplyResult(ITimeSource::Time now, std::unique_ptr<TResult> result, const std::shared_ptr<INode>& replyTo) {
if (!result) {
return;
}
Expand Down
4 changes: 2 additions & 2 deletions src/raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class TRaft {
public:
TRaft(int node, const TNodeDict& nodes, const std::shared_ptr<ITimeSource>& ts);

void Process(TMessageHolder<TMessage> message, INode* replyTo = nullptr);
void ApplyResult(ITimeSource::Time now, std::unique_ptr<TResult> result, INode* replyTo = nullptr);
void Process(TMessageHolder<TMessage> message, const std::shared_ptr<INode>& replyTo = {});
void ApplyResult(ITimeSource::Time now, std::unique_ptr<TResult> result, const std::shared_ptr<INode>& replyTo = {});

// ut
EState CurrentStateName() const {
Expand Down
20 changes: 10 additions & 10 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ void TNode::Send(const TMessageHolder<TMessage>& message) {
}

void TNode::Drain() {
if (!Connected) {
Connect();
return;
}
if (!Drainer || Drainer.done()) {
if (Drainer && Drainer.done()) {
Drainer.destroy();
Expand All @@ -94,18 +98,13 @@ void TNode::Drain() {
}

NNet::TTestTask TNode::DoDrain() {
if (!Connected) {
Connect();
co_return;
}
auto tosend = std::move(Messages);
try {
for (auto&& m : tosend) {
co_await TWriter(Socket).Write(std::move(m));
}
} catch (const std::exception& ex) {
std::cout << "Error on write: " << ex.what() << "\n";
Connect();
}
Messages.clear();
co_return;
Expand Down Expand Up @@ -143,13 +142,14 @@ NNet::TTestTask TNode::DoConnect() {

NNet::TSimpleTask TRaftServer::InboundConnection(NNet::TSocket socket) {
try {
TClientNode client;
auto client = std::make_shared<TClientNode>();
Nodes.insert(client);
while (true) {
auto mes = co_await TReader(socket).Read();
std::cout << "Got message " << mes->Type << "\n";
Raft->Process(std::move(mes), &client);
if (!client.Messages.empty()) {
auto tosend = std::move(client.Messages); client.Messages.clear();
Raft->Process(std::move(mes), client);
if (!client->Messages.empty()) {
auto tosend = std::move(client->Messages); client->Messages.clear();
for (auto&& mes : tosend) {
co_await TWriter(socket).Write(std::move(mes));
}
Expand All @@ -168,7 +168,7 @@ void TRaftServer::Serve() {
}

void TRaftServer::DrainNodes() {
for (auto [id, node] : Nodes) {
for (const auto& node : Nodes) {
node->Drain();
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,12 @@ class TRaftServer {
: Poller(poller)
, Socket(std::move(address), Poller)
, Raft(raft)
, Nodes(nodes)
, TimeSource(ts)
{ }
{
for (const auto& [_, node] : nodes) {
Nodes.insert(node);
}
}

void Serve();

Expand All @@ -210,6 +213,6 @@ class TRaftServer {
NNet::TPoll& Poller;
NNet::TPoll::TSocket Socket;
std::shared_ptr<TRaft> Raft;
TNodeDict Nodes;
std::unordered_set<std::shared_ptr<INode>> Nodes;
std::shared_ptr<ITimeSource> TimeSource;
};

0 comments on commit 3352323

Please sign in to comment.