Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revisit bt request stream api #51

Merged
merged 1 commit into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 17 additions & 19 deletions include/quic/btstream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ namespace oxen::quic
std::string_view ep;
std::string_view req_body;
std::weak_ptr<BTRequestStream> return_sender;
bool timed_out{false};

public:
message(BTRequestStream& bp, std::string req, bool is_error = false);

void respond(int64_t rid, std::string body, bool error = false);
void respond(std::string body, bool error = false);

bool timed_out{false};
bool is_error{false};

// To be used to determine if the message was a result of an error as such:
//
Expand All @@ -48,21 +50,24 @@ namespace oxen::quic
// if (m)
// { // success logic }
// }
operator bool() const { return not timed_out; }
operator bool() const { return not timed_out && not is_error; }

std::string_view view() const { return {data}; }

int64_t rid() const { return req_id; }
std::string_view type() const { return req_type; }
std::string_view endpoint() const { return ep; }
std::string_view body() const { return req_body; }
std::string endpoint_str() const { return std::string{ep}; }
std::string body_str() const { return std::string{req_body}; }
};

struct sent_request
{
// parsed request data
int64_t req_id;
std::string data;
std::function<void(message)> cb;
BTRequestStream& return_sender;

// total length of the request; is at the beginning of the request
Expand All @@ -73,7 +78,8 @@ namespace oxen::quic

bool is_empty() const { return data.empty() && total_len == 0; }

explicit sent_request(BTRequestStream& bp, std::string_view d, int64_t rid);
explicit sent_request(
BTRequestStream& bp, std::string_view d, int64_t rid, std::function<void(message)> f = nullptr);

bool is_expired(std::chrono::steady_clock::time_point tp) const { return timeout < tp; }

Expand All @@ -89,6 +95,8 @@ namespace oxen::quic
// outgoing requests awaiting response
std::deque<std::shared_ptr<sent_request>> sent_reqs;

std::unordered_map<std::string, std::function<void(message)>> func_map;

std::string buf;
std::string size_buf;

Expand All @@ -97,7 +105,6 @@ namespace oxen::quic
std::atomic<int64_t> next_rid{0};

friend struct sent_request;
std::function<void(message)> recv_callback;

public:
template <typename... Opt>
Expand All @@ -113,9 +120,7 @@ namespace oxen::quic
return std::dynamic_pointer_cast<BTRequestStream>(shared_from_this());
}

void request(std::string endpoint, std::string body);

void command(std::string endpoint, std::string body);
void command(std::string endpoint, std::string body, std::function<void(message)> = nullptr);

void respond(int64_t rid, std::string body, bool error = false);

Expand All @@ -125,28 +130,21 @@ namespace oxen::quic

void closed(uint64_t app_code) override;

private:
void handle_bp_opt(std::function<void(message)> recv_cb)
{
log::debug(bp_cat, "Bparser set user-provided recv callback!");
recv_callback = std::move(recv_cb);
}
void register_command(std::string endpoint, std::function<void(message)>);

private:
void handle_bp_opt(std::function<void(Stream&, uint64_t)> close_cb)
{
log::debug(bp_cat, "Bparser set user-provided close callback!");
close_callback = std::move(close_cb);
}

bool match(int64_t rid);

void handle_input(message msg);

void process_incoming(std::string_view req);

std::shared_ptr<sent_request> make_request(std::string endpoint, std::string body);

std::optional<sent_request> make_command(std::string endpoint, std::string body);
std::shared_ptr<sent_request> make_command(
std::string endpoint, std::string body, std::function<void(message)> = nullptr);

std::optional<sent_request> make_response(int64_t rid, std::string body, bool error = false);

Expand Down
109 changes: 45 additions & 64 deletions src/btstream.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "btstream.hpp"

#include "endpoint.hpp"

namespace oxen::quic
{
message::message(BTRequestStream& bp, std::string req, bool is_error) :
Expand All @@ -10,45 +12,47 @@ namespace oxen::quic
req_type = btlc.consume_string_view();
req_id = btlc.consume_integer<int64_t>();

if (req_type == "Q" || req_type == "C")
if (req_type == "C")
ep = btlc.consume_string_view();
else if (req_type == "E")
is_error = true;

req_body = btlc.consume_string_view();
}

sent_request::sent_request(BTRequestStream& bp, std::string_view d, int64_t rid) : req_id{rid}, return_sender{bp}
sent_request::sent_request(BTRequestStream& bp, std::string_view d, int64_t rid, std::function<void(message)> f) :
req_id{rid}, cb{std::move(f)}, return_sender{bp}
{
total_len = d.length();
data = oxenc::bt_serialize(d);
req_time = get_time();
timeout = req_time + TIMEOUT;
}

void message::respond(int64_t rid, std::string body, bool error)
void message::respond(std::string body, bool error)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

return_sender.lock()->respond(rid, std::move(body), error);
return_sender.lock()->respond(req_id, std::move(body), error);
}

void BTRequestStream::request(std::string endpoint, std::string body)
void BTRequestStream::command(std::string endpoint, std::string body, std::function<void(message)> func)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

auto req = make_request(std::move(endpoint), std::move(body));
send(req->view());

sent_reqs.push_back(std::move(req));
}

void BTRequestStream::command(std::string endpoint, std::string body)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

auto req = make_command(std::move(endpoint), std::move(body));
auto req = make_command(std::move(endpoint), std::move(body), std::move(func));

if (req)
send(std::move(*req).payload());
{
// if we have a cb, then this is a request; else, it is a command
if (req->cb)
{
send(req->view());
sent_reqs.push_back(std::move(req));
}
else
send(std::move(*req).payload());
}
else
throw std::invalid_argument{"Invalid command!"};
}
Expand All @@ -75,7 +79,7 @@ namespace oxen::quic

if (f->is_expired(now))
{
recv_callback(f->to_message(true));
f->cb(f->to_message(true));
sent_reqs.pop_front();
}
else
Expand Down Expand Up @@ -108,25 +112,9 @@ namespace oxen::quic
close_callback(*this, app_code);
}

bool BTRequestStream::match(int64_t rid)
void BTRequestStream::register_command(std::string ep, std::function<void(message)> func)
{
log::trace(bp_cat, "{} called", __PRETTY_FUNCTION__);

// Iterate using forward iterators, s.t. we go highest (newest) rids to lowest (oldest) rids.
// As a result, our comparator checks if the sent request ID is greater thanthan the target rid
auto itr = std::lower_bound(
sent_reqs.begin(), sent_reqs.end(), rid, [](const std::shared_ptr<sent_request>& sr, int64_t rid) {
return sr->req_id > rid;
});

if (itr != sent_reqs.end() and itr->get()->req_id == rid)
{
log::debug(bp_cat, "Successfully matched response to sent request!");
sent_reqs.erase(itr);
return true;
}

return false;
endpoint.call([&]() { func_map[std::move(ep)] = std::move(func); });
}

void BTRequestStream::handle_input(message msg)
Expand All @@ -135,14 +123,28 @@ namespace oxen::quic

if (msg.req_type == "R" || msg.req_type == "E")
{
if (auto b = match(msg.req_id); not b)
// Iterate using forward iterators, s.t. we go highest (newest) rids to lowest (oldest) rids.
// As a result, our comparator checks if the sent request ID is greater thanthan the target rid
auto itr = std::lower_bound(
sent_reqs.begin(),
sent_reqs.end(),
msg.req_id,
[](const std::shared_ptr<sent_request>& sr, int64_t rid) { return sr->req_id > rid; });

if (itr != sent_reqs.end() and itr->get()->req_id == msg.req_id)
{
log::warning(bp_cat, "Error: could not match orphaned response!");
log::debug(bp_cat, "Successfully matched response to sent request!");
itr->get()->cb(msg);
sent_reqs.erase(itr);
return;
}
}

recv_callback(std::move(msg));
if (auto itr = func_map.find(msg.endpoint_str()); itr != func_map.end())
{
log::debug(bp_cat, "Executing request endpoint {}", msg.endpoint());
itr->second(std::move(msg));
}
}

void BTRequestStream::process_incoming(std::string_view req)
Expand Down Expand Up @@ -211,29 +213,8 @@ namespace oxen::quic
}
}

std::shared_ptr<sent_request> BTRequestStream::make_request(std::string endpoint, std::string body)
{
oxenc::bt_list_producer btlp;
auto rid = ++next_rid;

try
{
btlp.append("Q");
btlp.append(rid);
btlp.append(endpoint);
btlp.append(body);

return std::make_shared<sent_request>(*this, std::move(btlp).str(), rid);
}
catch (...)
{
log::critical(bp_cat, "Invalid outgoing request encoding!");
}

return nullptr;
}

std::optional<sent_request> BTRequestStream::make_command(std::string endpoint, std::string body)
std::shared_ptr<sent_request> BTRequestStream::make_command(
std::string endpoint, std::string body, std::function<void(message)> func)
{
oxenc::bt_list_producer btlp;
auto rid = ++next_rid;
Expand All @@ -245,14 +226,14 @@ namespace oxen::quic
btlp.append(endpoint);
btlp.append(body);

return sent_request{*this, btlp.view(), rid};
return std::make_shared<sent_request>(*this, std::move(btlp).str(), rid, func);
}
catch (...)
{
log::critical(bp_cat, "Invalid outgoing command encoding!");
}

return std::nullopt;
return nullptr;
}

std::optional<sent_request> BTRequestStream::make_response(int64_t rid, std::string body, bool error)
Expand Down
28 changes: 17 additions & 11 deletions tests/002-send-receive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ namespace oxen::quic::test
}};

stream_constructor_callback server_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, server_bp_cb);
auto s = std::make_shared<BTRequestStream>(c, e);
s->register_command("test_endpoint"s, server_bp_cb);
return s;
};

auto server_endpoint = test_net.endpoint(server_local);
Expand All @@ -193,24 +195,26 @@ namespace oxen::quic::test
if (msg)
{
log::info(log_cat, "Server bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

auto client_bp_cb = callback_waiter{[&](message msg) {
if (msg)
{
log::info(log_cat, "Client bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

stream_constructor_callback server_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, server_bp_cb);
auto s = std::make_shared<BTRequestStream>(c, e);
s->register_command("test_endpoint"s, server_bp_cb);
return s;
};

stream_constructor_callback client_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, client_bp_cb);
return std::make_shared<BTRequestStream>(c, e);
};

auto server_endpoint = test_net.endpoint(server_local);
Expand All @@ -223,7 +227,7 @@ namespace oxen::quic::test

std::shared_ptr<BTRequestStream> client_bp = conn_interface->get_new_stream<BTRequestStream>();

client_bp->request("test_endpoint"s, "test_request_body"s);
client_bp->command("test_endpoint"s, "test_request_body"s, client_bp_cb);

REQUIRE(server_bp_cb.wait());
REQUIRE(client_bp_cb.wait());
Expand All @@ -235,20 +239,22 @@ namespace oxen::quic::test
if (msg)
{
log::info(log_cat, "Server bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

auto client_bp_cb = callback_waiter{[&](message msg) {
if (msg)
{
log::info(log_cat, "Client bparser received: {}", msg.view());
msg.respond(msg.rid(), "test_response"s);
msg.respond("test_response"s);
}
}};

stream_constructor_callback server_constructor = [&](Connection& c, Endpoint& e, std::optional<int64_t>) {
return std::make_shared<BTRequestStream>(c, e, server_bp_cb);
auto s = std::make_shared<BTRequestStream>(c, e);
s->register_command("test_endpoint"s, server_bp_cb);
return s;
};

auto server_endpoint = test_net.endpoint(server_local);
Expand All @@ -259,9 +265,9 @@ namespace oxen::quic::test
auto client_endpoint = test_net.endpoint(client_local);
auto conn_interface = client_endpoint->connect(client_remote, client_tls);

auto client_bp = conn_interface->get_new_stream<BTRequestStream>(client_bp_cb);
auto client_bp = conn_interface->get_new_stream<BTRequestStream>();

client_bp->request("test_endpoint"s, "test_request_body"s);
client_bp->command("test_endpoint"s, "test_request_body"s, client_bp_cb);

REQUIRE(server_bp_cb.wait());
REQUIRE(client_bp_cb.wait());
Expand Down