Skip to content

Commit

Permalink
Can select poller
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Nov 30, 2023
1 parent df20609 commit f8b1d39
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 28 deletions.
7 changes: 6 additions & 1 deletion client/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,13 @@ int main(int argc, char** argv) {

}
}
#ifdef __linux__
using TPoller = NNet::TEPoll;
#else
using TPoller = NNet::TPoll;
#endif
std::shared_ptr<ITimeSource> timeSource = std::make_shared<TTimeSource>();
NNet::TLoop<NNet::TPoll> loop;
NNet::TLoop<TPoller> loop;
Client(loop.Poller(), NNet::TAddress{hosts[0].Address, hosts[0].Port});
loop.Loop();
return 0;
Expand Down
10 changes: 8 additions & 2 deletions server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@ int main(int argc, char** argv) {
}
}

#ifdef __linux__
using TPoller = NNet::TEPoll;
#else
using TPoller = NNet::TPoll;
#endif

std::shared_ptr<ITimeSource> timeSource = std::make_shared<TTimeSource>();
NNet::TLoop<NNet::TPoll> loop;
NNet::TLoop<TPoller> loop;

for (auto& host : hosts) {
if (host.Id == id) {
myHost = host;
} else {
nodes[host.Id] = std::make_shared<TNode>(
nodes[host.Id] = std::make_shared<TNode<TPoller>>(
loop.Poller(),
std::to_string(host.Id),
NNet::TAddress{host.Address, host.Port},
Expand Down
43 changes: 30 additions & 13 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
#include "server.h"
#include "messages.h"

TPromise<void>::TTask TWriter::Write(TMessageHolder<TMessage> message) {
template<typename TSocket>
TPromise<void>::TTask TWriter<TSocket>::Write(TMessageHolder<TMessage> message) {
auto payload = std::move(message.Payload);
char* p = (char*)message.Mes; // TODO: const char
uint32_t len = message->Len;
Expand All @@ -30,7 +31,8 @@ TPromise<void>::TTask TWriter::Write(TMessageHolder<TMessage> message) {
co_return;
}

TPromise<TMessageHolder<TMessage>>::TTask TReader::Read() {
template<typename TSocket>
TPromise<TMessageHolder<TMessage>>::TTask TReader<TSocket>::Read() {
decltype(TMessage::Type) type;
decltype(TMessage::Len) len;
auto s = co_await Socket.ReadSome((char*)&type, sizeof(type));
Expand Down Expand Up @@ -66,11 +68,13 @@ TPromise<TMessageHolder<TMessage>>::TTask TReader::Read() {
co_return mes;
}

void TNode::Send(TMessageHolder<TMessage> message) {
template<typename TPoller>
void TNode<TPoller>::Send(TMessageHolder<TMessage> message) {
Messages.emplace_back(std::move(message));
}

void TNode::Drain() {
template<typename TPoller>
void TNode<TPoller>::Drain() {
if (!Connected) {
Connect();
return;
Expand All @@ -83,7 +87,8 @@ void TNode::Drain() {
}
}

NNet::TTestTask TNode::DoDrain() {
template<typename TPoller>
NNet::TTestTask TNode<TPoller>::DoDrain() {
auto tosend = std::move(Messages);
try {
for (auto&& m : tosend) {
Expand All @@ -97,7 +102,8 @@ NNet::TTestTask TNode::DoDrain() {
co_return;
}

void TNode::Connect() {
template<typename TPoller>
void TNode<TPoller>::Connect() {
if (Address && (!Connector || Connector.done())) {
if (Connector && Connector.done()) {
Connector.destroy();
Expand All @@ -109,7 +115,8 @@ void TNode::Connect() {
}
}

NNet::TTestTask TNode::DoConnect() {
template<typename TPoller>
NNet::TTestTask TNode<TPoller>::DoConnect() {
std::cout << "Connecting " << Name << "\n";
while (!Connected) {
try {
Expand All @@ -127,9 +134,10 @@ NNet::TTestTask TNode::DoConnect() {
co_return;
}

NNet::TSimpleTask TRaftServer::InboundConnection(NNet::TSocket socket) {
template<typename TPoller>
NNet::TSimpleTask TRaftServer<TPoller>::InboundConnection(NNet::TSocket socket) {
try {
auto client = std::make_shared<TNode>(
auto client = std::make_shared<TNode<TPoller>>(
Poller, "client", std::move(socket), TimeSource
);
Nodes.insert(client);
Expand All @@ -146,18 +154,21 @@ NNet::TSimpleTask TRaftServer::InboundConnection(NNet::TSocket socket) {
co_return;
}

void TRaftServer::Serve() {
template<typename TPoller>
void TRaftServer<TPoller>::Serve() {
Idle();
InboundServe();
}

void TRaftServer::DrainNodes() {
template<typename TPoller>
void TRaftServer<TPoller>::DrainNodes() {
for (const auto& node : Nodes) {
node->Drain();
}
}

NNet::TSimpleTask TRaftServer::InboundServe() {
template<typename TPoller>
NNet::TSimpleTask TRaftServer<TPoller>::InboundServe() {
std::cout << "Bind\n";
Socket.Bind();
std::cout << "Listen\n";
Expand All @@ -170,7 +181,8 @@ NNet::TSimpleTask TRaftServer::InboundServe() {
co_return;
}

NNet::TSimpleTask TRaftServer::Idle() {
template<typename TPoller>
NNet::TSimpleTask TRaftServer<TPoller>::Idle() {
auto t0 = TimeSource->Now();
auto dt = std::chrono::milliseconds(2000);
auto sleep = std::chrono::milliseconds(100);
Expand All @@ -186,3 +198,8 @@ NNet::TSimpleTask TRaftServer::Idle() {
}
co_return;
}

template class TRaftServer<NNet::TPoll>;
#ifdef __linux__
template class TRaftServer<NNet::TEPoll>;
#endif
28 changes: 16 additions & 12 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,28 +117,30 @@ struct TPromise<void>
std::coroutine_handle<> Caller = std::noop_coroutine();
};

template<typename TSocket>
class TReader {
public:
TReader(NNet::TPoll::TSocket& socket)
TReader(TSocket& socket)
: Socket(socket)
{ }

TPromise<TMessageHolder<TMessage>>::TTask Read();

private:
NNet::TPoll::TSocket& Socket;
TSocket& Socket;
};

template<typename TSocket>
class TWriter {
public:
TWriter(NNet::TPoll::TSocket& socket)
TWriter(TSocket& socket)
: Socket(socket)
{ }

TPromise<void>::TTask Write(TMessageHolder<TMessage> message);

private:
NNet::TPoll::TSocket& Socket;
TSocket& Socket;
};

struct THost {
Expand Down Expand Up @@ -166,16 +168,17 @@ struct THost {
}
};

template<typename TPoller>
class TNode: public INode {
public:
TNode(NNet::TPoll& poller, const std::string& name, NNet::TAddress address, const std::shared_ptr<ITimeSource>& ts)
TNode(TPoller& poller, const std::string& name, NNet::TAddress address, const std::shared_ptr<ITimeSource>& ts)
: Poller(poller)
, Name(name)
, Address(address)
, TimeSource(ts)
{ }

TNode(NNet::TPoll& poller, const std::string& name, NNet::TSocket socket, const std::shared_ptr<ITimeSource>& ts)
TNode(TPoller& poller, const std::string& name, typename TPoller::TSocket socket, const std::shared_ptr<ITimeSource>& ts)
: Poller(poller)
, Name(name)
, Socket(std::move(socket))
Expand All @@ -185,7 +188,7 @@ class TNode: public INode {

void Send(TMessageHolder<TMessage> message) override;
void Drain() override;
NNet::TSocket& Sock() {
typename TPoller::TSocket& Sock() {
return Socket;
}

Expand All @@ -195,11 +198,11 @@ class TNode: public INode {
NNet::TTestTask DoDrain();
NNet::TTestTask DoConnect();

NNet::TPoll& Poller;
TPoller& Poller;
std::string Name;
std::optional<NNet::TAddress> Address;
std::shared_ptr<ITimeSource> TimeSource;
NNet::TPoll::TSocket Socket;
typename TPoller::TSocket Socket;
bool Connected = false;

std::coroutine_handle<> Drainer;
Expand All @@ -208,10 +211,11 @@ class TNode: public INode {
std::vector<TMessageHolder<TMessage>> Messages;
};

template<typename TPoller>
class TRaftServer {
public:
TRaftServer(
NNet::TPoll& poller,
TPoller& poller,
NNet::TAddress address,
const std::shared_ptr<TRaft>& raft,
const TNodeDict& nodes,
Expand All @@ -234,8 +238,8 @@ class TRaftServer {
NNet::TSimpleTask Idle();
void DrainNodes();

NNet::TPoll& Poller;
NNet::TPoll::TSocket Socket;
TPoller& Poller;
typename TPoller::TSocket Socket;
std::shared_ptr<TRaft> Raft;
std::unordered_set<std::shared_ptr<INode>> Nodes;
std::shared_ptr<ITimeSource> TimeSource;
Expand Down

0 comments on commit f8b1d39

Please sign in to comment.