Skip to content

Commit

Permalink
Add TLS support to addon client
Browse files Browse the repository at this point in the history
  • Loading branch information
loonycyborg committed Feb 17, 2021
1 parent b1532ea commit 09ea0da
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 26 deletions.
116 changes: 92 additions & 24 deletions src/network_asio.cpp
Expand Up @@ -66,8 +66,10 @@ using boost::system::system_error;

connection::connection(const std::string& host, const std::string& service)
: io_context_()
, host_(host)
, service_(service)
, resolver_(io_context_)
, socket_(io_context_)
, socket_(raw_socket{io_context_})
, done_(false)
, write_buf_()
, read_buf_()
Expand All @@ -78,12 +80,15 @@ connection::connection(const std::string& host, const std::string& service)
, bytes_to_read_(0)
, bytes_read_(0)
{
#if BOOST_VERSION >= 106600
resolver_.async_resolve(host, service,
#else
resolver_.async_resolve(boost::asio::ip::tcp::resolver::query(host, service),
#endif
std::bind(&connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2));
boost::system::error_code ec;
auto result = resolver_.resolve(host, service, boost::asio::ip::resolver_query_base::numeric_host, ec);
if(!ec) { // if numeric resolve succeeds then we got raw ip address so TLS host name validation would never pass
use_tls_ = false;
boost::asio::post(io_context_, [this, ec, result](){ handle_resolve(ec, { result } ); } );
} else {
resolver_.async_resolve(host, service,
std::bind(&connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2));
}

LOG_NW << "Resolving hostname: " << host << '\n';
}
Expand All @@ -94,7 +99,7 @@ void connection::handle_resolve(const boost::system::error_code& ec, results_typ
throw system_error(ec);
}

boost::asio::async_connect(socket_, results,
boost::asio::async_connect(utils::get<raw_socket>(socket_), results,
std::bind(&connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2));
}

Expand All @@ -109,28 +114,87 @@ void connection::handle_connect(const boost::system::error_code& ec, endpoint en
#else
LOG_NW << "Connected to " << endpoint->endpoint().address() << '\n';
#endif
if(endpoint.address().is_loopback()) {
use_tls_ = false;
}
handshake();
}
}

void connection::handshake()
{
static const uint32_t handshake = 0;
static const uint32_t tls_handshake = htonl(uint32_t(1));

boost::asio::async_write(socket_, boost::asio::buffer(reinterpret_cast<const char*>(&handshake), 4),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
boost::asio::async_write(
utils::get<raw_socket>(socket_),
boost::asio::buffer(use_tls_ ? reinterpret_cast<const char*>(&tls_handshake) : reinterpret_cast<const char*>(&handshake), 4),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)
);

boost::asio::async_read(socket_, boost::asio::buffer(&handshake_response_.binary, 4),
boost::asio::async_read(utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
std::bind(&connection::handle_handshake, this, std::placeholders::_1));
}

void connection::handle_handshake(const boost::system::error_code& ec)
{
if(ec) {
if(ec == boost::asio::error::eof && use_tls_) {
// immediate disconnect likely means old server not supporting TLS handshake code
fallback_to_unencrypted();
return;
}

throw system_error(ec);
}

if(use_tls_) {
if(handshake_response_.num == 0xFFFFFFFFU) {
use_tls_ = false;
handle_handshake(ec);
return;
}

done_ = true;
if(handshake_response_.num == 0x00000000) {
tls_context_.set_default_verify_paths();
raw_socket s { std::move(utils::get<raw_socket>(socket_)) };
socket_.emplace<1>(std::move(s), tls_context_);

auto& socket { utils::get<tls_socket>(socket_) };

socket.set_verify_mode(
boost::asio::ssl::verify_peer |
boost::asio::ssl::verify_fail_if_no_peer_cert
);

socket.set_verify_callback(boost::asio::ssl::host_name_verification(host_));

socket.async_handshake(boost::asio::ssl::stream_base::client, [this](const boost::system::error_code& ec) {
if(ec) {
throw system_error(ec);
}

done_ = true;
});
return;
}

fallback_to_unencrypted();
} else {
done_ = true;
}
}

void connection::fallback_to_unencrypted()
{
assert(use_tls_ == true);
use_tls_ = false;

boost::asio::ip::tcp::endpoint endpoint { utils::get<raw_socket>(socket_).remote_endpoint() };
utils::get<raw_socket>(socket_).close();

utils::get<raw_socket>(socket_).async_connect(endpoint,
std::bind(&connection::handle_connect, this, std::placeholders::_1, endpoint));
}

void connection::transfer(const config& request, config& response)
Expand All @@ -154,35 +218,39 @@ void connection::transfer(const config& request, config& response)
auto bufs = split_buffer(write_buf_->data());
bufs.push_front(boost::asio::buffer(reinterpret_cast<const char*>(&payload_size_), 4));

boost::asio::async_write(socket_, bufs,
std::bind(&connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
utils::visit([this, &bufs, &response](auto&& socket) {
boost::asio::async_write(socket, bufs,
std::bind(&connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));

boost::asio::async_read(socket_, *read_buf_,
std::bind(&connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_read, this, std::placeholders::_1, std::placeholders::_2, std::ref(response)));
boost::asio::async_read(socket, *read_buf_,
std::bind(&connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_read, this, std::placeholders::_1, std::placeholders::_2, std::ref(response)));
}, socket_);
}

void connection::cancel()
{
if(socket_.is_open()) {
boost::system::error_code ec;
utils::visit([](auto&& socket) {
if(socket.lowest_layer().is_open()) {
boost::system::error_code ec;

#ifdef _MSC_VER
// Silence warning about boost::asio::basic_socket<Protocol>::cancel always
// returning an error on XP, which we don't support anymore.
#pragma warning(push)
#pragma warning(disable:4996)
#endif
socket_.cancel(ec);
socket.lowest_layer().cancel(ec);
#ifdef _MSC_VER
#pragma warning(pop)
#endif

if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
}
}
}
}, socket_);
bytes_to_write_ = 0;
bytes_written_ = 0;
bytes_to_read_ = 0;
Expand Down
15 changes: 13 additions & 2 deletions src/network_asio.hpp
Expand Up @@ -31,6 +31,7 @@
#endif

#include "exceptions.hpp"
#include "utils/variant.hpp"

#if BOOST_VERSION >= 106600
#include <boost/asio/io_context.hpp>
Expand All @@ -39,6 +40,7 @@
#endif
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/streambuf.hpp>
#include <boost/asio/ssl.hpp>

class config;

Expand Down Expand Up @@ -130,11 +132,18 @@ class connection
boost::asio::io_service io_context_;
#endif

std::string host_;
const std::string service_;
typedef boost::asio::ip::tcp::resolver resolver;
resolver resolver_;

typedef boost::asio::ip::tcp::socket socket;
socket socket_;
boost::asio::ssl::context tls_context_ { boost::asio::ssl::context::sslv23 };

typedef boost::asio::ip::tcp::socket raw_socket;
typedef boost::asio::ssl::stream<raw_socket> tls_socket;
typedef utils::variant<raw_socket, tls_socket> any_socket;
bool use_tls_ = true;
any_socket socket_;

bool done_;

Expand All @@ -157,6 +166,8 @@ class connection

data_union handshake_response_;

void fallback_to_unencrypted();

std::size_t is_write_complete(const boost::system::error_code& error, std::size_t bytes_transferred);
void handle_write(const boost::system::error_code& ec, std::size_t bytes_transferred);

Expand Down

0 comments on commit 09ea0da

Please sign in to comment.