diff --git a/src/network_asio.cpp b/src/network_asio.cpp index 74ac52f68c13..041fcb550698 100644 --- a/src/network_asio.cpp +++ b/src/network_asio.cpp @@ -66,8 +66,10 @@ using boost::system::system_error; connection::connection(const std::string& host, const std::string& service) : io_context_() + , host_(host) + , service_(service) , resolver_(io_context_) - , socket_(io_context_) + , socket_(raw_socket{io_context_}) , done_(false) , write_buf_() , read_buf_() @@ -78,12 +80,15 @@ connection::connection(const std::string& host, const std::string& service) , bytes_to_read_(0) , bytes_read_(0) { -#if BOOST_VERSION >= 106600 - resolver_.async_resolve(host, service, -#else - resolver_.async_resolve(boost::asio::ip::tcp::resolver::query(host, service), -#endif - std::bind(&connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2)); + boost::system::error_code ec; + auto result = resolver_.resolve(host, service, boost::asio::ip::resolver_query_base::numeric_host, ec); + if(!ec) { // if numeric resolve succeeds then we got raw ip address so TLS host name validation would never pass + use_tls_ = false; + boost::asio::post(io_context_, [this, ec, result](){ handle_resolve(ec, { result } ); } ); + } else { + resolver_.async_resolve(host, service, + std::bind(&connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2)); + } LOG_NW << "Resolving hostname: " << host << '\n'; } @@ -94,7 +99,7 @@ void connection::handle_resolve(const boost::system::error_code& ec, results_typ throw system_error(ec); } - boost::asio::async_connect(socket_, results, + boost::asio::async_connect(utils::get(socket_), results, std::bind(&connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2)); } @@ -109,6 +114,9 @@ void connection::handle_connect(const boost::system::error_code& ec, endpoint en #else LOG_NW << "Connected to " << endpoint->endpoint().address() << '\n'; #endif + if(endpoint.address().is_loopback()) { + use_tls_ = false; + } handshake(); } } @@ -116,21 +124,77 @@ void connection::handle_connect(const boost::system::error_code& ec, endpoint en void connection::handshake() { static const uint32_t handshake = 0; + static const uint32_t tls_handshake = htonl(uint32_t(1)); - boost::asio::async_write(socket_, boost::asio::buffer(reinterpret_cast(&handshake), 4), - std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)); + boost::asio::async_write( + utils::get(socket_), + boost::asio::buffer(use_tls_ ? reinterpret_cast(&tls_handshake) : reinterpret_cast(&handshake), 4), + std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2) + ); - boost::asio::async_read(socket_, boost::asio::buffer(&handshake_response_.binary, 4), + boost::asio::async_read(utils::get(socket_), boost::asio::buffer(&handshake_response_.binary, 4), std::bind(&connection::handle_handshake, this, std::placeholders::_1)); } void connection::handle_handshake(const boost::system::error_code& ec) { if(ec) { + if(ec == boost::asio::error::eof && use_tls_) { + // immediate disconnect likely means old server not supporting TLS handshake code + fallback_to_unencrypted(); + return; + } + throw system_error(ec); } + + if(use_tls_) { + if(handshake_response_.num == 0xFFFFFFFFU) { + use_tls_ = false; + handle_handshake(ec); + return; + } - done_ = true; + if(handshake_response_.num == 0x00000000) { + tls_context_.set_default_verify_paths(); + raw_socket s { std::move(utils::get(socket_)) }; + socket_.emplace<1>(std::move(s), tls_context_); + + auto& socket { utils::get(socket_) }; + + socket.set_verify_mode( + boost::asio::ssl::verify_peer | + boost::asio::ssl::verify_fail_if_no_peer_cert + ); + + socket.set_verify_callback(boost::asio::ssl::host_name_verification(host_)); + + socket.async_handshake(boost::asio::ssl::stream_base::client, [this](const boost::system::error_code& ec) { + if(ec) { + throw system_error(ec); + } + + done_ = true; + }); + return; + } + + fallback_to_unencrypted(); + } else { + done_ = true; + } +} + +void connection::fallback_to_unencrypted() +{ + assert(use_tls_ == true); + use_tls_ = false; + + boost::asio::ip::tcp::endpoint endpoint { utils::get(socket_).remote_endpoint() }; + utils::get(socket_).close(); + + utils::get(socket_).async_connect(endpoint, + std::bind(&connection::handle_connect, this, std::placeholders::_1, endpoint)); } void connection::transfer(const config& request, config& response) @@ -154,19 +218,22 @@ void connection::transfer(const config& request, config& response) auto bufs = split_buffer(write_buf_->data()); bufs.push_front(boost::asio::buffer(reinterpret_cast(&payload_size_), 4)); - boost::asio::async_write(socket_, bufs, - std::bind(&connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2), - std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)); + utils::visit([this, &bufs, &response](auto&& socket) { + boost::asio::async_write(socket, bufs, + std::bind(&connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2), + std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)); - boost::asio::async_read(socket_, *read_buf_, - std::bind(&connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2), - std::bind(&connection::handle_read, this, std::placeholders::_1, std::placeholders::_2, std::ref(response))); + boost::asio::async_read(socket, *read_buf_, + std::bind(&connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2), + std::bind(&connection::handle_read, this, std::placeholders::_1, std::placeholders::_2, std::ref(response))); + }, socket_); } void connection::cancel() { - if(socket_.is_open()) { - boost::system::error_code ec; + utils::visit([](auto&& socket) { + if(socket.lowest_layer().is_open()) { + boost::system::error_code ec; #ifdef _MSC_VER // Silence warning about boost::asio::basic_socket::cancel always @@ -174,15 +241,16 @@ void connection::cancel() #pragma warning(push) #pragma warning(disable:4996) #endif - socket_.cancel(ec); + socket.lowest_layer().cancel(ec); #ifdef _MSC_VER #pragma warning(pop) #endif - if(ec) { - WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl; + if(ec) { + WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl; + } } - } + }, socket_); bytes_to_write_ = 0; bytes_written_ = 0; bytes_to_read_ = 0; diff --git a/src/network_asio.hpp b/src/network_asio.hpp index 7099676b034b..92c0108d91f8 100644 --- a/src/network_asio.hpp +++ b/src/network_asio.hpp @@ -31,6 +31,7 @@ #endif #include "exceptions.hpp" +#include "utils/variant.hpp" #if BOOST_VERSION >= 106600 #include @@ -39,6 +40,7 @@ #endif #include #include +#include class config; @@ -130,11 +132,18 @@ class connection boost::asio::io_service io_context_; #endif + std::string host_; + const std::string service_; typedef boost::asio::ip::tcp::resolver resolver; resolver resolver_; - typedef boost::asio::ip::tcp::socket socket; - socket socket_; + boost::asio::ssl::context tls_context_ { boost::asio::ssl::context::sslv23 }; + + typedef boost::asio::ip::tcp::socket raw_socket; + typedef boost::asio::ssl::stream tls_socket; + typedef utils::variant any_socket; + bool use_tls_ = true; + any_socket socket_; bool done_; @@ -157,6 +166,8 @@ class connection data_union handshake_response_; + void fallback_to_unencrypted(); + std::size_t is_write_complete(const boost::system::error_code& error, std::size_t bytes_transferred); void handle_write(const boost::system::error_code& ec, std::size_t bytes_transferred);