Skip to content

Commit

Permalink
Add TLS support to multiplayer client
Browse files Browse the repository at this point in the history
  • Loading branch information
loonycyborg committed Feb 13, 2021
1 parent 5036522 commit a2115bd
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 51 deletions.
4 changes: 2 additions & 2 deletions src/game_initialization/multiplayer.cpp
Expand Up @@ -240,7 +240,7 @@ std::unique_ptr<wesnothd_connection> mp_manager::open_connection(std::string hos
gui2::dialogs::loading_screen::progress(loading_stage::connect_to_server);

// Initializes the connection to the server.
auto conn = std::make_unique<wesnothd_connection>(addr->first, addr->second);
auto conn = std::make_unique<wesnothd_connection>(addr->first, addr->second, true);

// First, spin until we get a handshake from the server.
conn->wait_for_handshake();
Expand Down Expand Up @@ -288,7 +288,7 @@ std::unique_ptr<wesnothd_connection> mp_manager::open_connection(std::string hos

// Open a new connection with the new host and port.
conn.reset();
conn = std::make_unique<wesnothd_connection>(redirect_host, redirect_port);
conn = std::make_unique<wesnothd_connection>(redirect_host, redirect_port, true);

// Wait for new handshake.
conn->wait_for_handshake();
Expand Down
62 changes: 36 additions & 26 deletions src/server/common/server_base.cpp
Expand Up @@ -99,7 +99,6 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp::
}

socket_ptr socket = std::make_shared<socket_ptr::element_type>(io_service_);
bool use_tls { false };

boost::system::error_code error;
acceptor.async_accept(socket->lowest_layer(), yield[error]);
Expand Down Expand Up @@ -135,6 +134,13 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp::
char buf[4];
} protocol_version;

union {
uint32_t number;
char buf[4];
} handshake_response;

any_socket_ptr final_socket;

async_read(*socket, boost::asio::buffer(protocol_version.buf), yield[error]);
if(check_error(error, socket))
return;
Expand All @@ -143,49 +149,53 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp::
case 0:
async_write(*socket, boost::asio::buffer(handshake_response_.buf, 4), yield[error]);
if(check_error(error, socket)) return;
final_socket = socket;
break;
case 1:
if(!tls_enabled_) {
ERR_SERVER << client_address(socket) << "\tTLS requested by client but not enabled on server\n";
async_send_error(socket, "TLS support disabled on server.");
handshake_response.number = 0xFFFFFFFFU;
} else {
handshake_response.number = 0x00000000;
}

async_write(*socket, boost::asio::buffer(handshake_response.buf, 4), yield[error]);
if(check_error(error, socket) || !tls_enabled_) return;

final_socket = tls_socket_ptr { new tls_socket_ptr::element_type(std::move(*socket), tls_context_) };
utils::get<tls_socket_ptr>(final_socket)->async_handshake(boost::asio::ssl::stream_base::server, yield[error]);
if(error) {
ERR_SERVER << "TLS handshake failed: " << error.message() << "\n";
return;
}
use_tls = true;

break;
default:
ERR_SERVER << client_address(socket) << "\tincorrect handshake\n";
return;
}

const std::string ip = client_address(socket);
utils::visit([this](auto&& socket) {
const std::string ip = client_address(socket);

const std::string reason = is_ip_banned(ip);
if (!reason.empty()) {
LOG_SERVER << ip << "\trejected banned user. Reason: " << reason << "\n";
async_send_error(socket, "You are banned. Reason: " + reason);
return;
} else if (ip_exceeds_connection_limit(ip)) {
LOG_SERVER << ip << "\trejected ip due to excessive connections\n";
async_send_error(socket, "Too many connections from your IP.");
return;
} else {
if(use_tls) {
async_send_warning(socket, "Go TLS.");
tls_socket_ptr tls_socket { new tls_socket_ptr::element_type(std::move(*socket), tls_context_) };
tls_socket->async_handshake(boost::asio::ssl::stream_base::server, yield[error]);
if(error) {
ERR_SERVER << "TLS handshake failed: " << error.message() << "\n";
const std::string reason = is_ip_banned(ip);
if (!reason.empty()) {
LOG_SERVER << ip << "\trejected banned user. Reason: " << reason << "\n";
async_send_error(socket, "You are banned. Reason: " + reason);
return;
}

DBG_SERVER << ip << "\tnew encrypted connection fully accepted\n";
this->handle_new_client(tls_socket);
} else if (ip_exceeds_connection_limit(ip)) {
LOG_SERVER << ip << "\trejected ip due to excessive connections\n";
async_send_error(socket, "Too many connections from your IP.");
return;
} else {
DBG_SERVER << ip << "\tnew connection fully accepted\n";
if constexpr (std::is_same_v<decltype(socket), tls_socket_ptr>) {
DBG_SERVER << ip << "\tnew encrypted connection fully accepted\n";
} else {
DBG_SERVER << ip << "\tnew connection fully accepted\n";
}
this->handle_new_client(socket);
}
}
}, final_socket);
}

#ifndef _WIN32
Expand Down
85 changes: 65 additions & 20 deletions src/wesnothd_connection.cpp
Expand Up @@ -23,6 +23,7 @@
#include <boost/asio/connect.hpp>
#include <boost/asio/read.hpp>
#include <boost/asio/write.hpp>
#include <boost/asio/ssl/host_name_verification.hpp>

#include <cstdint>
#include <deque>
Expand Down Expand Up @@ -56,11 +57,15 @@ using boost::system::error_code;
using boost::system::system_error;

// main thread
wesnothd_connection::wesnothd_connection(const std::string& host, const std::string& service)
wesnothd_connection::wesnothd_connection(const std::string& host, const std::string& service, bool tls)
: worker_thread_()
, io_context_()
, resolver_(io_context_)
, socket_(io_context_)
, tls_context_(boost::asio::ssl::context::sslv23)
, host_(host)
, service_(service)
, use_tls_(tls)
, socket_(raw_socket{io_context_})
, last_error_()
, last_error_mutex_()
, handshake_finished_()
Expand Down Expand Up @@ -121,7 +126,7 @@ void wesnothd_connection::handle_resolve(const error_code& ec, results_type resu
throw system_error(ec);
}

boost::asio::async_connect(socket_, results,
boost::asio::async_connect(utils::get<raw_socket>(socket_), results,
std::bind(&wesnothd_connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2));
}

Expand All @@ -147,11 +152,11 @@ void wesnothd_connection::handshake()
{
MPTEST_LOG;
static const uint32_t handshake = 0;
static const uint32_t tls_handshake = htonl(uint32_t(1));

boost::asio::async_write(socket_, boost::asio::buffer(reinterpret_cast<const char*>(&handshake), 4),
boost::asio::async_write(utils::get<raw_socket>(socket_), boost::asio::buffer(use_tls_ ? reinterpret_cast<const char*>(&tls_handshake) : reinterpret_cast<const char*>(&handshake), 4),
[](const error_code& ec, std::size_t) { if(ec) { throw system_error(ec); } });

boost::asio::async_read(socket_, boost::asio::buffer(&handshake_response_.binary, 4),
boost::asio::async_read(utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
std::bind(&wesnothd_connection::handle_handshake, this, std::placeholders::_1));
}

Expand All @@ -163,9 +168,43 @@ void wesnothd_connection::handle_handshake(const error_code& ec)
LOG_NW << __func__ << " Throwing: " << ec << "\n";
throw system_error(ec);
}

if(use_tls_) {
if(handshake_response_.num == 0xFFFFFFFFU) {
throw std::runtime_error("The server doesn't support TLS");
}

handshake_finished_.set_value();
recv();
if(handshake_response_.num == 0x00000000) {
tls_context_.set_default_verify_paths();
raw_socket s { std::move(utils::get<raw_socket>(socket_)) };
socket_.emplace<1>(std::move(s), tls_context_);

auto& socket { utils::get<tls_socket>(socket_) };

socket.set_verify_mode(
boost::asio::ssl::verify_peer |
boost::asio::ssl::verify_fail_if_no_peer_cert
);

socket.set_verify_callback(boost::asio::ssl::host_name_verification(host_));

socket.async_handshake(boost::asio::ssl::stream_base::client, [this](const error_code& ec) {
if(ec) {
LOG_NW << __func__ << " Throwing: " << ec << "\n";
throw system_error(ec);
}

handshake_finished_.set_value();
recv();
});
return;
}

throw std::runtime_error("Invalid handshake");
} else {
handshake_finished_.set_value();
recv();
}
}

// main thread
Expand Down Expand Up @@ -223,24 +262,26 @@ void wesnothd_connection::send_data(const configr_of& request)
void wesnothd_connection::cancel()
{
MPTEST_LOG;
if(socket_.is_open()) {
boost::system::error_code ec;
utils::visit([](auto&& socket) {
if(socket.lowest_layer().is_open()) {
boost::system::error_code ec;

#ifdef _MSC_VER
// Silence warning about boost::asio::basic_socket<Protocol>::cancel always
// returning an error on XP, which we don't support anymore.
#pragma warning(push)
#pragma warning(disable:4996)
#endif
socket_.cancel(ec);
socket.lowest_layer().cancel(ec);
#ifdef _MSC_VER
#pragma warning(pop)
#endif

if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
}
}
}
}, socket_);
}

// main thread
Expand Down Expand Up @@ -384,19 +425,23 @@ void wesnothd_connection::send()
buf.data()
};

boost::asio::async_write(socket_, bufs,
std::bind(&wesnothd_connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
utils::visit([this, &bufs](auto&& socket) {
boost::asio::async_write(socket, bufs,
std::bind(&wesnothd_connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
}, socket_);
}

// worker thread
void wesnothd_connection::recv()
{
MPTEST_LOG;

boost::asio::async_read(socket_, read_buf_,
std::bind(&wesnothd_connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_read, this, std::placeholders::_1, std::placeholders::_2));
utils::visit([this](auto&& socket) {
boost::asio::async_read(socket, read_buf_,
std::bind(&wesnothd_connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_read, this, std::placeholders::_1, std::placeholders::_2));
}, socket_);
}

// main thread
Expand Down
15 changes: 12 additions & 3 deletions src/wesnothd_connection.hpp
Expand Up @@ -40,6 +40,7 @@
#endif
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/streambuf.hpp>
#include <boost/asio/ssl.hpp>

#include <condition_variable>
#include <deque>
Expand Down Expand Up @@ -72,8 +73,9 @@ class wesnothd_connection
*
* @param host Name of the host to connect to
* @param service Service identifier such as "80" or "http"
* @param tls Whether we want to use TLS to make connection encrypted
*/
wesnothd_connection(const std::string& host, const std::string& service);
wesnothd_connection(const std::string& host, const std::string& service, bool tls);

/**
* Queues the given data to be sent to the server.
Expand Down Expand Up @@ -147,8 +149,15 @@ class wesnothd_connection
typedef boost::asio::ip::tcp::resolver resolver;
resolver resolver_;

typedef boost::asio::ip::tcp::socket socket;
socket socket_;
boost::asio::ssl::context tls_context_;

std::string host_;
std::string service_;
typedef boost::asio::ip::tcp::socket raw_socket;
typedef boost::asio::ssl::stream<raw_socket> tls_socket;
typedef utils::variant<raw_socket, tls_socket> any_socket;
bool use_tls_;
any_socket socket_;

boost::system::error_code last_error_;

Expand Down

0 comments on commit a2115bd

Please sign in to comment.