Skip to content

Commit

Permalink
Revisit bt request stream api
Browse files Browse the repository at this point in the history
- commands are now registered and methods are executed automatically
- now only commands; providing a callback makes it implicitly a request
  • Loading branch information
dr7ana committed Sep 26, 2023
1 parent ba19583 commit 9ed5a59
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 92 deletions.
30 changes: 13 additions & 17 deletions include/quic/btstream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace oxen::quic
public:
message(BTRequestStream& bp, std::string req, bool is_error = false);

void respond(int64_t rid, std::string body, bool error = false);
void respond(std::string body, bool error = false);

// To be used to determine if the message was a result of an error as such:
//
Expand All @@ -56,13 +56,16 @@ namespace oxen::quic
std::string_view type() const { return req_type; }
std::string_view endpoint() const { return ep; }
std::string_view body() const { return req_body; }
std::string endpoint_str() const { return std::string{ep}; }
std::string body_str() const { return std::string{req_body}; }
};

struct sent_request
{
// parsed request data
int64_t req_id;
std::string data;
std::function<void(message)> cb;
BTRequestStream& return_sender;

// total length of the request; is at the beginning of the request
Expand All @@ -73,7 +76,8 @@ namespace oxen::quic

bool is_empty() const { return data.empty() && total_len == 0; }

explicit sent_request(BTRequestStream& bp, std::string_view d, int64_t rid);
explicit sent_request(
BTRequestStream& bp, std::string_view d, int64_t rid, std::function<void(message)> f = nullptr);

bool is_expired(std::chrono::steady_clock::time_point tp) const { return timeout < tp; }

Expand All @@ -89,6 +93,8 @@ namespace oxen::quic
// outgoing requests awaiting response
std::deque<std::shared_ptr<sent_request>> sent_reqs;

std::unordered_map<std::string, std::function<void(message)>> func_map;

std::string buf;
std::string size_buf;

Expand All @@ -97,7 +103,6 @@ namespace oxen::quic
std::atomic<int64_t> next_rid{0};

friend struct sent_request;
std::function<void(message)> recv_callback;

public:
template <typename... Opt>
Expand All @@ -113,9 +118,7 @@ namespace oxen::quic
return std::dynamic_pointer_cast<BTRequestStream>(shared_from_this());
}

void request(std::string endpoint, std::string body);

void command(std::string endpoint, std::string body);
void command(std::string endpoint, std::string body, std::function<void(message)> = nullptr);

void respond(int64_t rid, std::string body, bool error = false);

Expand All @@ -125,28 +128,21 @@ namespace oxen::quic

void closed(uint64_t app_code) override;

private:
void handle_bp_opt(std::function<void(message)> recv_cb)
{
log::debug(bp_cat, "Bparser set user-provided recv callback!");
recv_callback = std::move(recv_cb);
}
void register_function(std::string endpoint, std::function<void(message)>);

private:
void handle_bp_opt(std::function<void(Stream&, uint64_t)> close_cb)
{
log::debug(bp_cat, "Bparser set user-provided close callback!");
close_callback = std::move(close_cb);
}

bool match(int64_t rid);

void handle_input(message msg);

void process_incoming(std::string_view req);

std::shared_ptr<sent_request> make_request(std::string endpoint, std::string body);

std::optional<sent_request> make_command(std::string endpoint, std::string body);
std::shared_ptr<sent_request> make_command(
std::string endpoint, std::string body, std::function<void(message)> = nullptr);

std::optional<sent_request> make_response(int64_t rid, std::string body, bool error = false);

Expand Down
105 changes: 41 additions & 64 deletions src/btstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,45 @@ namespace oxen::quic
req_type = btlc.consume_string_view();
req_id = btlc.consume_integer<int64_t>();

if (req_type == "Q" || req_type == "C")
if (req_type == "C")
ep = btlc.consume_string_view();

req_body = btlc.consume_string_view();
}

sent_request::sent_request(BTRequestStream& bp, std::string_view d, int64_t rid) : req_id{rid}, return_sender{bp}
sent_request::sent_request(BTRequestStream& bp, std::string_view d, int64_t rid, std::function<void(message)> f) :
req_id{rid}, cb{std::move(f)}, return_sender{bp}
{
total_len = d.length();
data = oxenc::bt_serialize(d);
req_time = get_time();
timeout = req_time + TIMEOUT;
}

void message::respond(int64_t rid, std::string body, bool error)
void message::respond(std::string body, bool error)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

return_sender.lock()->respond(rid, std::move(body), error);
return_sender.lock()->respond(req_id, std::move(body), error);
}

void BTRequestStream::request(std::string endpoint, std::string body)
void BTRequestStream::command(std::string endpoint, std::string body, std::function<void(message)> func)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

auto req = make_request(std::move(endpoint), std::move(body));
send(req->view());

sent_reqs.push_back(std::move(req));
}

void BTRequestStream::command(std::string endpoint, std::string body)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

auto req = make_command(std::move(endpoint), std::move(body));
auto req = make_command(std::move(endpoint), std::move(body), std::move(func));

if (req)
send(std::move(*req).payload());
{
// if we have a cb, then this is a request; else, it is a command
if (req->cb)
{
send(req->view());
sent_reqs.push_back(std::move(req));
}
else
send(std::move(*req).payload());
}
else
throw std::invalid_argument{"Invalid command!"};
}
Expand All @@ -75,7 +75,7 @@ namespace oxen::quic

if (f->is_expired(now))
{
recv_callback(f->to_message(true));
f->cb(f->to_message(true));
sent_reqs.pop_front();
}
else
Expand Down Expand Up @@ -108,25 +108,9 @@ namespace oxen::quic
close_callback(*this, app_code);
}

bool BTRequestStream::match(int64_t rid)
void BTRequestStream::register_function(std::string endpoint, std::function<void(message)> func)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

// Iterate using forward iterators, s.t. we go highest (newest) rids to lowest (oldest) rids.
// As a result, our comparator checks if the sent request ID is greater thanthan the target rid
auto itr = std::lower_bound(
sent_reqs.begin(), sent_reqs.end(), rid, [](const std::shared_ptr<sent_request>& sr, int64_t rid) {
return sr->req_id > rid;
});

if (itr != sent_reqs.end() and itr->get()->req_id == rid)
{
log::debug(bp_cat, "Successfully matched response to sent request!");
sent_reqs.erase(itr);
return true;
}

return false;
func_map.emplace(std::move(endpoint), std::move(func));
}

void BTRequestStream::handle_input(message msg)
Expand All @@ -135,14 +119,28 @@ namespace oxen::quic

if (msg.req_type == "R" || msg.req_type == "E")
{
if (auto b = match(msg.req_id); not b)
// Iterate using forward iterators, s.t. we go highest (newest) rids to lowest (oldest) rids.
// As a result, our comparator checks if the sent request ID is greater thanthan the target rid
auto itr = std::lower_bound(
sent_reqs.begin(),
sent_reqs.end(),
msg.req_id,
[](const std::shared_ptr<sent_request>& sr, int64_t rid) { return sr->req_id > rid; });

if (itr != sent_reqs.end() and itr->get()->req_id == msg.req_id)
{
log::warning(bp_cat, "Error: could not match orphaned response!");
log::debug(bp_cat, "Successfully matched response to sent request!");
itr->get()->cb(msg);
sent_reqs.erase(itr);
return;
}
}

recv_callback(std::move(msg));
if (auto itr = func_map.find(msg.endpoint_str()); itr != func_map.end())
{
log::debug(bp_cat, "Executing request endpoint {}", msg.endpoint());
itr->second(std::move(msg));
}
}

void BTRequestStream::process_incoming(std::string_view req)
Expand Down Expand Up @@ -211,29 +209,8 @@ namespace oxen::quic
}
}

std::shared_ptr<sent_request> BTRequestStream::make_request(std::string endpoint, std::string body)
{
oxenc::bt_list_producer btlp;
auto rid = ++next_rid;

try
{
btlp.append("Q");
btlp.append(rid);
btlp.append(endpoint);
btlp.append(body);

return std::make_shared<sent_request>(*this, std::move(btlp).str(), rid);
}
catch (...)
{
log::critical(bp_cat, "Invalid outgoing request encoding!");
}

return nullptr;
}

std::optional<sent_request> BTRequestStream::make_command(std::string endpoint, std::string body)
std::shared_ptr<sent_request> BTRequestStream::make_command(
std::string endpoint, std::string body, std::function<void(message)> func)
{
oxenc::bt_list_producer btlp;
auto rid = ++next_rid;
Expand All @@ -245,14 +222,14 @@ namespace oxen::quic
btlp.append(endpoint);
btlp.append(body);

return sent_request{*this, btlp.view(), rid};
return std::make_shared<sent_request>(*this, std::move(btlp).str(), rid, func);
}
catch (...)
{
log::critical(bp_cat, "Invalid outgoing command encoding!");
}

return std::nullopt;
return nullptr;
}

std::optional<sent_request> BTRequestStream::make_response(int64_t rid, std::string body, bool error)
Expand Down
28 changes: 17 additions & 11 deletions tests/002-send-receive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ namespace oxen::quic::test
}};

stream_constructor_callback server_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, server_bp_cb);
auto s = std::make_shared<BTRequestStream>(c, e);
s->register_function("test_endpoint"s, server_bp_cb);
return s;
};

auto server_endpoint = test_net.endpoint(server_local);
Expand All @@ -193,24 +195,26 @@ namespace oxen::quic::test
if (msg)
{
log::info(log_cat, "Server bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

auto client_bp_cb = callback_waiter{[&](message msg) {
if (msg)
{
log::info(log_cat, "Client bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

stream_constructor_callback server_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, server_bp_cb);
auto s = std::make_shared<BTRequestStream>(c, e);
s->register_function("test_endpoint"s, server_bp_cb);
return s;
};

stream_constructor_callback client_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, client_bp_cb);
return std::make_shared<BTRequestStream>(c, e);
};

auto server_endpoint = test_net.endpoint(server_local);
Expand All @@ -223,7 +227,7 @@ namespace oxen::quic::test

std::shared_ptr<BTRequestStream> client_bp = conn_interface->get_new_stream<BTRequestStream>();

client_bp->request("test_endpoint"s, "test_request_body"s);
client_bp->command("test_endpoint"s, "test_request_body"s, client_bp_cb);

REQUIRE(server_bp_cb.wait());
REQUIRE(client_bp_cb.wait());
Expand All @@ -235,20 +239,22 @@ namespace oxen::quic::test
if (msg)
{
log::info(log_cat, "Server bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

auto client_bp_cb = callback_waiter{[&](message msg) {
if (msg)
{
log::info(log_cat, "Client bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

stream_constructor_callback server_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, server_bp_cb);
auto s = std::make_shared<BTRequestStream>(c, e);
s->register_function("test_endpoint"s, server_bp_cb);
return s;
};

auto server_endpoint = test_net.endpoint(server_local);
Expand All @@ -259,9 +265,9 @@ namespace oxen::quic::test
auto client_endpoint = test_net.endpoint(client_local);
auto conn_interface = client_endpoint->connect(client_remote, client_tls);

auto client_bp = conn_interface->get_new_stream<BTRequestStream>(client_bp_cb);
auto client_bp = conn_interface->get_new_stream<BTRequestStream>();

client_bp->request("test_endpoint"s, "test_request_body"s);
client_bp->command("test_endpoint"s, "test_request_body"s, client_bp_cb);

REQUIRE(server_bp_cb.wait());
REQUIRE(client_bp_cb.wait());
Expand Down

0 comments on commit 9ed5a59

Please sign in to comment.