From 8145267ea952c8afbde9d3fc6f604976028a205c Mon Sep 17 00:00:00 2001 From: loonycyborg Date: Sun, 7 Feb 2021 02:10:52 +0300 Subject: [PATCH] Add TLS support to multiplayer client --- src/game_initialization/multiplayer.cpp | 4 +- src/server/common/server_base.cpp | 62 ++++++++++-------- src/wesnothd_connection.cpp | 85 +++++++++++++++++++------ src/wesnothd_connection.hpp | 15 ++++- 4 files changed, 115 insertions(+), 51 deletions(-) diff --git a/src/game_initialization/multiplayer.cpp b/src/game_initialization/multiplayer.cpp index df3ea7825cca1..5d393ca51f56e 100644 --- a/src/game_initialization/multiplayer.cpp +++ b/src/game_initialization/multiplayer.cpp @@ -240,7 +240,7 @@ std::unique_ptr mp_manager::open_connection(std::string hos gui2::dialogs::loading_screen::progress(loading_stage::connect_to_server); // Initializes the connection to the server. - auto conn = std::make_unique(addr->first, addr->second); + auto conn = std::make_unique(addr->first, addr->second, true); // First, spin until we get a handshake from the server. conn->wait_for_handshake(); @@ -288,7 +288,7 @@ std::unique_ptr mp_manager::open_connection(std::string hos // Open a new connection with the new host and port. conn.reset(); - conn = std::make_unique(redirect_host, redirect_port); + conn = std::make_unique(redirect_host, redirect_port, true); // Wait for new handshake. conn->wait_for_handshake(); diff --git a/src/server/common/server_base.cpp b/src/server/common/server_base.cpp index fef8bcb50f696..354caef061c85 100644 --- a/src/server/common/server_base.cpp +++ b/src/server/common/server_base.cpp @@ -99,7 +99,6 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp:: } socket_ptr socket = std::make_shared(io_service_); - bool use_tls { false }; boost::system::error_code error; acceptor.async_accept(socket->lowest_layer(), yield[error]); @@ -135,6 +134,13 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp:: char buf[4]; } protocol_version; + union { + uint32_t number; + char buf[4]; + } handshake_response; + + any_socket_ptr final_socket; + async_read(*socket, boost::asio::buffer(protocol_version.buf), yield[error]); if(check_error(error, socket)) return; @@ -143,14 +149,25 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp:: case 0: async_write(*socket, boost::asio::buffer(handshake_response_.buf, 4), yield[error]); if(check_error(error, socket)) return; + final_socket = socket; break; case 1: if(!tls_enabled_) { ERR_SERVER << client_address(socket) << "\tTLS requested by client but not enabled on server\n"; - async_send_error(socket, "TLS support disabled on server."); + handshake_response.number = 0xFFFFFFFFU; + } else { + handshake_response.number = 0x00000000; + } + + async_write(*socket, boost::asio::buffer(handshake_response.buf, 4), yield[error]); + if(check_error(error, socket) || !tls_enabled_) return; + + final_socket = tls_socket_ptr { new tls_socket_ptr::element_type(std::move(*socket), tls_context_) }; + utils::get(final_socket)->async_handshake(boost::asio::ssl::stream_base::server, yield[error]); + if(error) { + ERR_SERVER << "TLS handshake failed: " << error.message() << "\n"; return; } - use_tls = true; break; default: @@ -158,34 +175,27 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp:: return; } - const std::string ip = client_address(socket); + utils::visit([this](auto&& socket) { + const std::string ip = client_address(socket); - const std::string reason = is_ip_banned(ip); - if (!reason.empty()) { - LOG_SERVER << ip << "\trejected banned user. Reason: " << reason << "\n"; - async_send_error(socket, "You are banned. Reason: " + reason); - return; - } else if (ip_exceeds_connection_limit(ip)) { - LOG_SERVER << ip << "\trejected ip due to excessive connections\n"; - async_send_error(socket, "Too many connections from your IP."); - return; - } else { - if(use_tls) { - async_send_warning(socket, "Go TLS."); - tls_socket_ptr tls_socket { new tls_socket_ptr::element_type(std::move(*socket), tls_context_) }; - tls_socket->async_handshake(boost::asio::ssl::stream_base::server, yield[error]); - if(error) { - ERR_SERVER << "TLS handshake failed: " << error.message() << "\n"; + const std::string reason = is_ip_banned(ip); + if (!reason.empty()) { + LOG_SERVER << ip << "\trejected banned user. Reason: " << reason << "\n"; + async_send_error(socket, "You are banned. Reason: " + reason); return; - } - - DBG_SERVER << ip << "\tnew encrypted connection fully accepted\n"; - this->handle_new_client(tls_socket); + } else if (ip_exceeds_connection_limit(ip)) { + LOG_SERVER << ip << "\trejected ip due to excessive connections\n"; + async_send_error(socket, "Too many connections from your IP."); + return; } else { - DBG_SERVER << ip << "\tnew connection fully accepted\n"; + if constexpr (std::is_same_v) { + DBG_SERVER << ip << "\tnew encrypted connection fully accepted\n"; + } else { + DBG_SERVER << ip << "\tnew connection fully accepted\n"; + } this->handle_new_client(socket); } - } + }, final_socket); } #ifndef _WIN32 diff --git a/src/wesnothd_connection.cpp b/src/wesnothd_connection.cpp index 062affbdfae48..6bb6c618ef617 100644 --- a/src/wesnothd_connection.cpp +++ b/src/wesnothd_connection.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -56,11 +57,15 @@ using boost::system::error_code; using boost::system::system_error; // main thread -wesnothd_connection::wesnothd_connection(const std::string& host, const std::string& service) +wesnothd_connection::wesnothd_connection(const std::string& host, const std::string& service, bool tls) : worker_thread_() , io_context_() , resolver_(io_context_) - , socket_(io_context_) + , tls_context_(boost::asio::ssl::context::sslv23) + , host_(host) + , service_(service) + , use_tls_(tls) + , socket_(raw_socket{io_context_}) , last_error_() , last_error_mutex_() , handshake_finished_() @@ -120,7 +125,7 @@ void wesnothd_connection::handle_resolve(const error_code& ec, results_type resu throw system_error(ec); } - boost::asio::async_connect(socket_, results, + boost::asio::async_connect(utils::get(socket_), results, std::bind(&wesnothd_connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2)); } @@ -146,11 +151,11 @@ void wesnothd_connection::handshake() { MPTEST_LOG; static const uint32_t handshake = 0; + static const uint32_t tls_handshake = htonl(uint32_t(1)); - boost::asio::async_write(socket_, boost::asio::buffer(reinterpret_cast(&handshake), 4), + boost::asio::async_write(utils::get(socket_), boost::asio::buffer(use_tls_ ? reinterpret_cast(&tls_handshake) : reinterpret_cast(&handshake), 4), [](const error_code& ec, std::size_t) { if(ec) { throw system_error(ec); } }); - - boost::asio::async_read(socket_, boost::asio::buffer(&handshake_response_.binary, 4), + boost::asio::async_read(utils::get(socket_), boost::asio::buffer(&handshake_response_.binary, 4), std::bind(&wesnothd_connection::handle_handshake, this, std::placeholders::_1)); } @@ -162,9 +167,43 @@ void wesnothd_connection::handle_handshake(const error_code& ec) LOG_NW << __func__ << " Throwing: " << ec << "\n"; throw system_error(ec); } + + if(use_tls_) { + if(handshake_response_.num == 0xFFFFFFFFU) { + throw std::runtime_error("The server doesn't support TLS"); + } - handshake_finished_.set_value(); - recv(); + if(handshake_response_.num == 0x00000000) { + tls_context_.set_default_verify_paths(); + raw_socket s { std::move(utils::get(socket_)) }; + socket_.emplace<1>(std::move(s), tls_context_); + + auto& socket { utils::get(socket_) }; + + socket.set_verify_mode( + boost::asio::ssl::verify_peer | + boost::asio::ssl::verify_fail_if_no_peer_cert + ); + + socket.set_verify_callback(boost::asio::ssl::host_name_verification(host_)); + + socket.async_handshake(boost::asio::ssl::stream_base::client, [this](const error_code& ec) { + if(ec) { + LOG_NW << __func__ << " Throwing: " << ec << "\n"; + throw system_error(ec); + } + + handshake_finished_.set_value(); + recv(); + }); + return; + } + + throw std::runtime_error("Invalid handshake"); + } else { + handshake_finished_.set_value(); + recv(); + } } // main thread @@ -222,8 +261,9 @@ void wesnothd_connection::send_data(const configr_of& request) void wesnothd_connection::cancel() { MPTEST_LOG; - if(socket_.is_open()) { - boost::system::error_code ec; + utils::visit([](auto&& socket) { + if(socket.lowest_layer().is_open()) { + boost::system::error_code ec; #ifdef _MSC_VER // Silence warning about boost::asio::basic_socket::cancel always @@ -231,15 +271,16 @@ void wesnothd_connection::cancel() #pragma warning(push) #pragma warning(disable:4996) #endif - socket_.cancel(ec); + socket.lowest_layer().cancel(ec); #ifdef _MSC_VER #pragma warning(pop) #endif - if(ec) { - WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl; + if(ec) { + WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl; + } } - } + }, socket_); } // main thread @@ -383,9 +424,11 @@ void wesnothd_connection::send() buf.data() }; - boost::asio::async_write(socket_, bufs, - std::bind(&wesnothd_connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2), - std::bind(&wesnothd_connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)); + utils::visit([this, &bufs](auto&& socket) { + boost::asio::async_write(socket, bufs, + std::bind(&wesnothd_connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2), + std::bind(&wesnothd_connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)); + }, socket_); } // worker thread @@ -393,9 +436,11 @@ void wesnothd_connection::recv() { MPTEST_LOG; - boost::asio::async_read(socket_, read_buf_, - std::bind(&wesnothd_connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2), - std::bind(&wesnothd_connection::handle_read, this, std::placeholders::_1, std::placeholders::_2)); + utils::visit([this](auto&& socket) { + boost::asio::async_read(socket, read_buf_, + std::bind(&wesnothd_connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2), + std::bind(&wesnothd_connection::handle_read, this, std::placeholders::_1, std::placeholders::_2)); + }, socket_); } // main thread diff --git a/src/wesnothd_connection.hpp b/src/wesnothd_connection.hpp index 57b6b91aa4fdd..0a8365d861a9b 100644 --- a/src/wesnothd_connection.hpp +++ b/src/wesnothd_connection.hpp @@ -40,6 +40,7 @@ #endif #include #include +#include #include #include @@ -72,8 +73,9 @@ class wesnothd_connection * * @param host Name of the host to connect to * @param service Service identifier such as "80" or "http" + * @param tls Whether we want to use TLS to make connection encrypted */ - wesnothd_connection(const std::string& host, const std::string& service); + wesnothd_connection(const std::string& host, const std::string& service, bool tls); /** * Queues the given data to be sent to the server. @@ -147,8 +149,15 @@ class wesnothd_connection typedef boost::asio::ip::tcp::resolver resolver; resolver resolver_; - typedef boost::asio::ip::tcp::socket socket; - socket socket_; + boost::asio::ssl::context tls_context_; + + std::string host_; + std::string service_; + typedef boost::asio::ip::tcp::socket raw_socket; + typedef boost::asio::ssl::stream tls_socket; + typedef utils::variant any_socket; + bool use_tls_; + any_socket socket_; boost::system::error_code last_error_;