Skip to content

Commit

Permalink
Hold sockets in unique_ptrs to avoid the need for variant::emplace()
Browse files Browse the repository at this point in the history
  • Loading branch information
loonycyborg committed Mar 11, 2021
1 parent 9e45434 commit dd0a322
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 33 deletions.
29 changes: 15 additions & 14 deletions src/network_asio.cpp
Expand Up @@ -69,7 +69,7 @@ connection::connection(const std::string& host, const std::string& service)
, host_(host)
, service_(service)
, resolver_(io_context_)
, socket_(raw_socket{io_context_})
, socket_(raw_socket(new raw_socket::element_type{io_context_}))
, done_(false)
, write_buf_()
, read_buf_()
Expand Down Expand Up @@ -99,7 +99,7 @@ void connection::handle_resolve(const boost::system::error_code& ec, results_typ
throw system_error(ec);
}

boost::asio::async_connect(utils::get<raw_socket>(socket_), results,
boost::asio::async_connect(*utils::get<raw_socket>(socket_), results,
std::bind(&connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2));
}

Expand Down Expand Up @@ -127,12 +127,12 @@ void connection::handshake()
static const uint32_t tls_handshake = htonl(uint32_t(1));

boost::asio::async_write(
utils::get<raw_socket>(socket_),
*utils::get<raw_socket>(socket_),
boost::asio::buffer(use_tls_ ? reinterpret_cast<const char*>(&tls_handshake) : reinterpret_cast<const char*>(&handshake), 4),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)
);

boost::asio::async_read(utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
boost::asio::async_read(*utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
std::bind(&connection::handle_handshake, this, std::placeholders::_1));
}

Expand All @@ -158,9 +158,10 @@ void connection::handle_handshake(const boost::system::error_code& ec)
if(handshake_response_.num == 0x00000000) {
tls_context_.set_default_verify_paths();
raw_socket s { std::move(utils::get<raw_socket>(socket_)) };
socket_.emplace<1>(std::move(s), tls_context_);

auto& socket { utils::get<tls_socket>(socket_) };
tls_socket ts { new tls_socket::element_type { std::move(*s), tls_context_ } };
socket_ = std::move(ts);

auto& socket { *utils::get<tls_socket>(socket_) };

socket.set_verify_mode(
boost::asio::ssl::verify_peer |
Expand Down Expand Up @@ -194,10 +195,10 @@ void connection::fallback_to_unencrypted()
assert(use_tls_ == true);
use_tls_ = false;

boost::asio::ip::tcp::endpoint endpoint { utils::get<raw_socket>(socket_).remote_endpoint() };
utils::get<raw_socket>(socket_).close();
boost::asio::ip::tcp::endpoint endpoint { utils::get<raw_socket>(socket_)->remote_endpoint() };
utils::get<raw_socket>(socket_)->close();

utils::get<raw_socket>(socket_).async_connect(endpoint,
utils::get<raw_socket>(socket_)->async_connect(endpoint,
std::bind(&connection::handle_connect, this, std::placeholders::_1, endpoint));
}

Expand All @@ -223,11 +224,11 @@ void connection::transfer(const config& request, config& response)
bufs.push_front(boost::asio::buffer(reinterpret_cast<const char*>(&payload_size_), 4));

utils::visit([this, &bufs, &response](auto&& socket) {
boost::asio::async_write(socket, bufs,
boost::asio::async_write(*socket, bufs,
std::bind(&connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));

boost::asio::async_read(socket, *read_buf_,
boost::asio::async_read(*socket, *read_buf_,
std::bind(&connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_read, this, std::placeholders::_1, std::placeholders::_2, std::ref(response)));
}, socket_);
Expand All @@ -236,7 +237,7 @@ void connection::transfer(const config& request, config& response)
void connection::cancel()
{
utils::visit([](auto&& socket) {
if(socket.lowest_layer().is_open()) {
if(socket->lowest_layer().is_open()) {
boost::system::error_code ec;

#ifdef _MSC_VER
Expand All @@ -245,7 +246,7 @@ void connection::cancel()
#pragma warning(push)
#pragma warning(disable:4996)
#endif
socket.lowest_layer().cancel(ec);
socket->lowest_layer().cancel(ec);
#ifdef _MSC_VER
#pragma warning(pop)
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/network_asio.hpp
Expand Up @@ -139,8 +139,8 @@ class connection

boost::asio::ssl::context tls_context_ { boost::asio::ssl::context::sslv23 };

typedef boost::asio::ip::tcp::socket raw_socket;
typedef boost::asio::ssl::stream<raw_socket> tls_socket;
typedef std::unique_ptr<boost::asio::ip::tcp::socket> raw_socket;
typedef std::unique_ptr<boost::asio::ssl::stream<raw_socket::element_type>> tls_socket;
typedef utils::variant<raw_socket, tls_socket> any_socket;
bool use_tls_ = true;
any_socket socket_;
Expand Down
31 changes: 16 additions & 15 deletions src/wesnothd_connection.cpp
Expand Up @@ -63,7 +63,7 @@ wesnothd_connection::wesnothd_connection(const std::string& host, const std::str
, tls_context_(boost::asio::ssl::context::sslv23)
, host_(host)
, service_(service)
, socket_(raw_socket{io_context_})
, socket_(raw_socket{ new raw_socket::element_type{io_context_} })
, last_error_()
, last_error_mutex_()
, handshake_finished_()
Expand Down Expand Up @@ -117,10 +117,10 @@ wesnothd_connection::~wesnothd_connection()
if(auto socket = utils::get_if<tls_socket>(&socket_)) {
error_code ec;
// this sends close_notify for secure connection shutdown
socket->async_shutdown([](const error_code&) {} );
(*socket)->async_shutdown([](const error_code&) {} );
const char buffer[] = "";
// this write is needed to trigger immediate close instead of waiting for other side's close_notify
boost::asio::write(*socket, boost::asio::buffer(buffer, 0), ec);
boost::asio::write(**socket, boost::asio::buffer(buffer, 0), ec);
}
// Stop the io_service and wait for the worker thread to terminate.
stop();
Expand All @@ -136,7 +136,7 @@ void wesnothd_connection::handle_resolve(const error_code& ec, results_type resu
throw system_error(ec);
}

boost::asio::async_connect(utils::get<raw_socket>(socket_), results,
boost::asio::async_connect(*utils::get<raw_socket>(socket_), results,
std::bind(&wesnothd_connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2));
}

Expand Down Expand Up @@ -167,9 +167,9 @@ void wesnothd_connection::handshake()
static const uint32_t handshake = 0;
static const uint32_t tls_handshake = htonl(uint32_t(1));

boost::asio::async_write(utils::get<raw_socket>(socket_), boost::asio::buffer(use_tls_ ? reinterpret_cast<const char*>(&tls_handshake) : reinterpret_cast<const char*>(&handshake), 4),
boost::asio::async_write(*utils::get<raw_socket>(socket_), boost::asio::buffer(use_tls_ ? reinterpret_cast<const char*>(&tls_handshake) : reinterpret_cast<const char*>(&handshake), 4),
[](const error_code& ec, std::size_t) { if(ec) { throw system_error(ec); } });
boost::asio::async_read(utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
boost::asio::async_read(*utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
std::bind(&wesnothd_connection::handle_handshake, this, std::placeholders::_1));
}

Expand Down Expand Up @@ -197,9 +197,10 @@ void wesnothd_connection::handle_handshake(const error_code& ec)
if(handshake_response_.num == 0x00000000) {
tls_context_.set_default_verify_paths();
raw_socket s { std::move(utils::get<raw_socket>(socket_)) };
socket_.emplace<1>(std::move(s), tls_context_);
tls_socket ts { new tls_socket::element_type{std::move(*s), tls_context_} };
socket_ = std::move(ts);

auto& socket { utils::get<tls_socket>(socket_) };
auto& socket { *utils::get<tls_socket>(socket_) };

socket.set_verify_mode(
boost::asio::ssl::verify_peer |
Expand Down Expand Up @@ -237,10 +238,10 @@ void wesnothd_connection::fallback_to_unencrypted()
assert(use_tls_ == true);
use_tls_ = false;

boost::asio::ip::tcp::endpoint endpoint { utils::get<raw_socket>(socket_).remote_endpoint() };
utils::get<raw_socket>(socket_).close();
boost::asio::ip::tcp::endpoint endpoint { utils::get<raw_socket>(socket_)->remote_endpoint() };
utils::get<raw_socket>(socket_)->close();

utils::get<raw_socket>(socket_).async_connect(endpoint,
utils::get<raw_socket>(socket_)->async_connect(endpoint,
std::bind(&wesnothd_connection::handle_connect, this, std::placeholders::_1, endpoint));
}

Expand Down Expand Up @@ -300,7 +301,7 @@ void wesnothd_connection::cancel()
{
MPTEST_LOG;
utils::visit([](auto&& socket) {
if(socket.lowest_layer().is_open()) {
if(socket->lowest_layer().is_open()) {
boost::system::error_code ec;

#ifdef _MSC_VER
Expand All @@ -309,7 +310,7 @@ void wesnothd_connection::cancel()
#pragma warning(push)
#pragma warning(disable:4996)
#endif
socket.lowest_layer().cancel(ec);
socket->lowest_layer().cancel(ec);
#ifdef _MSC_VER
#pragma warning(pop)
#endif
Expand Down Expand Up @@ -463,7 +464,7 @@ void wesnothd_connection::send()
};

utils::visit([this, &bufs](auto&& socket) {
boost::asio::async_write(socket, bufs,
boost::asio::async_write(*socket, bufs,
std::bind(&wesnothd_connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
}, socket_);
Expand All @@ -475,7 +476,7 @@ void wesnothd_connection::recv()
MPTEST_LOG;

utils::visit([this](auto&& socket) {
boost::asio::async_read(socket, read_buf_,
boost::asio::async_read(*socket, read_buf_,
std::bind(&wesnothd_connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_read, this, std::placeholders::_1, std::placeholders::_2));
}, socket_);
Expand Down
4 changes: 2 additions & 2 deletions src/wesnothd_connection.hpp
Expand Up @@ -152,8 +152,8 @@ class wesnothd_connection

std::string host_;
std::string service_;
typedef boost::asio::ip::tcp::socket raw_socket;
typedef boost::asio::ssl::stream<raw_socket> tls_socket;
typedef std::unique_ptr<boost::asio::ip::tcp::socket> raw_socket;
typedef std::unique_ptr<boost::asio::ssl::stream<raw_socket::element_type>> tls_socket;
typedef utils::variant<raw_socket, tls_socket> any_socket;
bool use_tls_ = true;
any_socket socket_;
Expand Down

0 comments on commit dd0a322

Please sign in to comment.