Skip to content

Commit

Permalink
Add connection established and closed callbacks
Browse files Browse the repository at this point in the history
Connection established is called when the connection has confirmed the
TLS handshake; server does this on ngtcp2 handshake "complete" after
sending the final packet for the handshake to the client which will then
call ngtcp2 handshake "confirmed".

Connection closed callback is called when a connection closes for any
reason, including that it timed out and was never actually established
(thus acting as a "connection attempt failed" signal as well).

In order to have a connection timeout call connection closed, a
handshake timeout needed to be set in ngtcp2 (the default was
funtionally infinite).  5s should be sufficient, but this can be made
configurable at a later time.

Tests which were confirming connection establishment or failure have
been updated to use the new callbacks.
  • Loading branch information
tewinget committed Sep 6, 2023
1 parent 2425041 commit 26d1fd7
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 208 deletions.
3 changes: 3 additions & 0 deletions include/quic/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ namespace oxen::quic
virtual bool datagrams_enabled() const = 0;
virtual bool packet_splitting_enabled() const = 0;
virtual const ConnectionID& scid() const = 0;
virtual bool close_cb_called() = 0; // return old value and set true

// WIP functions: these are meant to expose specific aspects of the internal state of connection
// and the datagram IO object for debugging and application (user) utilization.
Expand Down Expand Up @@ -168,6 +169,7 @@ namespace oxen::quic

const ConnectionID& scid() const override { return _source_cid; }
const ConnectionID& dcid() const { return _dest_cid; }
bool close_cb_called() override; // return old value and set true

const Path& path() const { return _path; }
const Address& local() const { return _path.local; }
Expand Down Expand Up @@ -211,6 +213,7 @@ namespace oxen::quic
const bool _datagrams_enabled{false};
const bool _packet_splitting{false};
std::atomic<bool> _congested{false};
bool close_cb_was_called{false};

struct connection_deleter
{
Expand Down
17 changes: 17 additions & 0 deletions include/quic/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,27 @@ extern "C"

namespace oxen::quic
{
struct connection_established_callback : public std::function<void(connection_interface& conn)>
{
using std::function<void(connection_interface& conn)>::function;
};
struct connection_closed_callback : public std::function<void(connection_interface& conn)> // do we care about reason?
{
using std::function<void(connection_interface& conn)>::function;
};

class Endpoint : std::enable_shared_from_this<Endpoint>
{
private:
void handle_ep_opt(opt::enable_datagrams dc);
void handle_ep_opt(dgram_data_callback dgram_cb);
void handle_ep_opt(connection_established_callback conn_established_cb);
void handle_ep_opt(connection_closed_callback conn_closed_cb);

public:
connection_established_callback on_connection_established;
connection_closed_callback on_connection_closed;

// Non-movable/non-copyable; you must always hold a Endpoint in a shared_ptr
Endpoint(const Endpoint&) = delete;
Endpoint& operator=(const Endpoint&) = delete;
Expand Down Expand Up @@ -190,6 +204,9 @@ namespace oxen::quic
void delete_connection(const ConnectionID& cid);
void drain_connection(Connection& conn);

void connection_established(connection_interface& conn);
void connection_closed(connection_interface& conn);

int _rbufsize{4096};

private:
Expand Down
39 changes: 39 additions & 0 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,36 @@ namespace oxen::quic
return 0;
}

int on_handshake_completed([[maybe_unused]] ngtcp2_conn *conn, void *user_data)
{
auto* conn_ptr = static_cast<Connection*>(user_data);
auto dir_str = conn_ptr->is_inbound() ? "server"s : "client"s;

log::trace(log_cat, "HANDSHAKE COMPLETED on {} connection", dir_str);

// server considers handshake complete and confirmed, and connection established at this point
if (conn_ptr->is_inbound())
conn_ptr->endpoint().connection_established(*conn_ptr);

return 0;
}

int on_handshake_confirmed([[maybe_unused]] ngtcp2_conn *conn, void *user_data)
{
auto* conn_ptr = static_cast<Connection*>(user_data);
auto dir_str = conn_ptr->is_inbound() ? "server"s : "client"s;

log::trace(log_cat, "HANDSHAKE CONFIRMED on {} connection", dir_str);

// server should never call this, as it "confirms" on handshake completed
assert(conn_ptr->is_outbound());

// client considers handshake complete and confirmed, and connection established at this point
conn_ptr->endpoint().connection_established(*conn_ptr);

return 0;
}

void rand_cb(uint8_t* dest, size_t destlen, const ngtcp2_rand_ctx* rand_ctx)
{
(void)rand_ctx;
Expand Down Expand Up @@ -396,6 +426,12 @@ namespace oxen::quic
return context->stream_data_cb;
}

bool Connection::close_cb_called() {
bool b = close_cb_was_called;
close_cb_was_called = true;
return b;
}

void Connection::on_packet_io_ready()
{
auto ts = get_time();
Expand Down Expand Up @@ -1061,6 +1097,8 @@ namespace oxen::quic
callbacks.get_path_challenge_data = ngtcp2_crypto_get_path_challenge_data_cb;
callbacks.version_negotiation = ngtcp2_crypto_version_negotiation_cb;
callbacks.stream_open = on_stream_open;
callbacks.handshake_completed = on_handshake_completed;
callbacks.handshake_confirmed = on_handshake_confirmed;

ngtcp2_settings_default(&settings);

Expand All @@ -1073,6 +1111,7 @@ namespace oxen::quic
settings.initial_rtt = NGTCP2_DEFAULT_INITIAL_RTT;
settings.max_window = 24_Mi;
settings.max_stream_window = 16_Mi;
settings.handshake_timeout = std::chrono::nanoseconds(5s).count();

ngtcp2_transport_params_default(&params);

Expand Down
41 changes: 41 additions & 0 deletions src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ namespace oxen::quic
dgram_recv_cb = std::move(func);
}

void Endpoint::handle_ep_opt(connection_established_callback conn_established_cb)
{
log::trace(log_cat, "Endpoint given connection established callback");
on_connection_established = std::move(conn_established_cb);
}

void Endpoint::handle_ep_opt(connection_closed_callback conn_closed_cb)
{
log::trace(log_cat, "Endpoint given connection closed callback");
on_connection_closed = std::move(conn_closed_cb);
}

void Endpoint::_init_internals()
{
log::debug(log_cat, "Starting new UDP socket on {}", _local);
Expand Down Expand Up @@ -102,6 +114,8 @@ namespace oxen::quic

void Endpoint::drain_connection(Connection& conn)
{
connection_closed(conn);

if (conn.is_draining())
return;

Expand Down Expand Up @@ -158,6 +172,8 @@ namespace oxen::quic
{
log::debug(log_cat, "Closing connection (CID: {})", *conn.scid().data);

connection_closed(conn);

if (conn.is_closing() || conn.is_draining())
return;

Expand Down Expand Up @@ -227,6 +243,7 @@ namespace oxen::quic
{
if (auto itr = conns.find(cid); itr != conns.end())
{
connection_closed(*(itr->second));
itr->second->call_close_cb();

conns.erase(itr);
Expand All @@ -236,6 +253,30 @@ namespace oxen::quic
log::warning(log_cat, "Error: could not delete connection [ID: {}]; could not find", *cid.data);
}

void Endpoint::connection_established(connection_interface& conn)
{
log::trace(log_cat, "Connection established, calling user callback [ID: {}]", conn.scid());
if (on_connection_established)
on_connection_established(conn);
}

// closing, "is closed", "is draining", etc are a little messy and calling
// the user callback for a close more than once feels bad, so this first
// checks if it has been called on that connection
void Endpoint::connection_closed(connection_interface& conn)
{
if (conn.close_cb_called())
return;

if (on_connection_closed)
{
log::trace(log_cat, "Connection closed, calling user callback [ID: {}]", conn.scid());
on_connection_closed(conn);
}
else
log::trace(log_cat, "Connection closed, but not calling user callback (not set) [ID: {}]", conn.scid());
}

std::optional<ConnectionID> Endpoint::handle_packet_connid(const Packet& pkt)
{
ngtcp2_version_cid vid;
Expand Down
72 changes: 30 additions & 42 deletions tests/001-handshake.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,34 +114,27 @@ namespace oxen::quic::test
{
Network test_net{};

std::promise<bool> tls;
std::future<bool> tls_future = tls.get_future();

gnutls_callback outbound_tls_cb =
[&](gnutls_session_t, unsigned int, unsigned int, unsigned int, const gnutls_datum_t*) {
log::debug(log_cat, "Calling client TLS callback... handshake completed...");

tls.set_value(true);
return 0;
};

auto server_tls = GNUTLSCreds::make("./serverkey.pem"s, "./servercert.pem"s, "./clientcert.pem"s);
auto client_tls = GNUTLSCreds::make("./clientkey.pem"s, "./clientcert.pem"s, "./servercert.pem"s);
client_tls->set_client_tls_policy(outbound_tls_cb);

opt::local_addr server_local{};
opt::local_addr client_local{};

auto server_endpoint = test_net.endpoint(server_local);
bool_waiter<connection_established_callback> server_established;
auto server_endpoint = test_net.endpoint(server_local, server_established.func());
REQUIRE(server_endpoint->listen(server_tls));

opt::remote_addr client_remote{"::1"s, server_endpoint->local().port()};

auto client_endpoint = test_net.endpoint(client_local);
bool_waiter<connection_established_callback> client_established;
auto client_endpoint = test_net.endpoint(client_local, client_established.func());

REQUIRE_NOTHROW(client_endpoint->connect(client_remote, client_tls));

REQUIRE(tls_future.get());
REQUIRE(client_established.wait_ready());
REQUIRE(server_established.is_ready());
REQUIRE(server_established.get() == true);
REQUIRE(client_established.get() == true);
};
};

Expand All @@ -151,66 +144,61 @@ namespace oxen::quic::test
{
Network test_net{};

std::promise<bool> tls;
std::future<bool> tls_future = tls.get_future();
std::atomic<bool> success = false;
std::promise<bool> conn_closed;
auto f = conn_closed.get_future();

gnutls_callback outbound_tls_cb =
[&](gnutls_session_t, unsigned int, unsigned int, unsigned int, const gnutls_datum_t*) {
log::debug(log_cat, "Calling client TLS callback... handshake completed...");
connection_closed_callback closed_cb = [&conn_closed](auto&&...){
conn_closed.set_value(true);
};

tls.set_value(true);
return 0;
};
connection_established_callback established_cb = [&success](auto&&...){
success = true;
};

auto server_tls = GNUTLSCreds::make("./serverkey.pem"s, "./servercert.pem"s, "./clientcert.pem"s);
auto client_tls = GNUTLSCreds::make("./clientkey.pem"s, "./clientcert.pem"s, "./servercert.pem"s);
client_tls->set_client_tls_policy(outbound_tls_cb);

opt::local_addr server_local{};
opt::local_addr client_local{};

auto server_endpoint = test_net.endpoint(server_local);
auto server_endpoint = test_net.endpoint(server_local, established_cb);
REQUIRE_THROWS(server_endpoint->listen());

opt::remote_addr client_remote{"127.0.0.1"s, server_endpoint->local().port()};

auto client_endpoint = test_net.endpoint(client_local);
auto client_endpoint = test_net.endpoint(client_local, std::move(closed_cb), established_cb);
auto conn_interface = client_endpoint->connect(client_remote, client_tls);

REQUIRE(tls_future.valid());
REQUIRE(f.wait_for(10s) == std::future_status::ready);
REQUIRE(f.get() == true);
REQUIRE(success == false);
};

SECTION("Successful TLS handshake")
{
Network test_net{};

std::promise<bool> tls;
std::future<bool> tls_future = tls.get_future();

gnutls_callback outbound_tls_cb =
[&](gnutls_session_t, unsigned int, unsigned int, unsigned int, const gnutls_datum_t*) {
log::debug(log_cat, "Calling client TLS callback... handshake completed...");

tls.set_value(true);
return 0;
};

auto server_tls = GNUTLSCreds::make("./serverkey.pem"s, "./servercert.pem"s, "./clientcert.pem"s);
auto client_tls = GNUTLSCreds::make("./clientkey.pem"s, "./clientcert.pem"s, "./servercert.pem"s);
client_tls->set_client_tls_policy(outbound_tls_cb);

opt::local_addr server_local{};
opt::local_addr client_local{};

auto server_endpoint = test_net.endpoint(server_local);
bool_waiter<connection_established_callback> server_established;
auto server_endpoint = test_net.endpoint(server_local, server_established.func());
REQUIRE(server_endpoint->listen(server_tls));

opt::remote_addr client_remote{"127.0.0.1"s, server_endpoint->local().port()};

auto client_endpoint = test_net.endpoint(client_local);
bool_waiter<connection_established_callback> client_established;
auto client_endpoint = test_net.endpoint(client_local, client_established.func());
auto conn_interface = client_endpoint->connect(client_remote, client_tls);

REQUIRE(tls_future.get());
REQUIRE(client_established.wait_ready());
REQUIRE(server_established.is_ready());
REQUIRE(server_established.get() == true);
REQUIRE(client_established.get() == true);
};
};
} // namespace oxen::quic::test

0 comments on commit 26d1fd7

Please sign in to comment.