Skip to content

Commit

Permalink
Add support for executing async queries.
Browse files Browse the repository at this point in the history
Making the logging correctly handle multiple threads writing to the same stream at the same time will be handled separately later.
  • Loading branch information
Pentarctagon committed Oct 15, 2020
1 parent 90e979a commit 467e431
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 88 deletions.
18 changes: 12 additions & 6 deletions CMakeLists.txt
@@ -1,5 +1,4 @@

# set minimum version
# set minimum version
cmake_minimum_required(VERSION 3.1)

project(wesnoth)
Expand Down Expand Up @@ -58,6 +57,13 @@ option(ENABLE_TESTS "Build unit tests")
option(ENABLE_NLS "Enable building of translations" ${ENABLE_GAME})
option(ENABLE_HISTORY "Enable using GNU history for history in lua console" ON)

# boost::asio::post is new with 1.66
if(ENABLE_MYSQL)
set(BOOST_VERSION "1.66")
else()
set(BOOST_VERSION "1.56")
endif(ENABLE_MYSQL)

# set what std version to use
if(NOT CXX_STD)
set(CXX_STD "14")
Expand All @@ -80,7 +86,7 @@ else()
set(CRYPTO_LIBRARY "-framework Security")
endif()

find_package(Boost 1.56 REQUIRED COMPONENTS iostreams program_options regex system thread random)
find_package(Boost ${BOOST_VERSION} REQUIRED COMPONENTS iostreams program_options regex system thread random)

# no, gettext executables are not required when NLS is deactivated
find_package(Gettext)
Expand Down Expand Up @@ -535,7 +541,7 @@ if(ENABLE_GAME OR ENABLE_TESTS)
endif(ENABLE_GAME OR ENABLE_TESTS)

if(ENABLE_TESTS)
find_package( Boost 1.56 REQUIRED COMPONENTS unit_test_framework )
find_package( Boost ${BOOST_VERSION} REQUIRED COMPONENTS unit_test_framework )
endif(ENABLE_TESTS)

if(ENABLE_GAME)
Expand Down Expand Up @@ -567,8 +573,8 @@ if(ENABLE_GAME)
endif(ENABLE_HISTORY AND HISTORY_FOUND)
endif(ENABLE_GAME)

find_package(Boost 1.56 REQUIRED COMPONENTS filesystem)
find_package(Boost 1.56 REQUIRED COMPONENTS locale)
find_package(Boost ${BOOST_VERSION} REQUIRED COMPONENTS filesystem)
find_package(Boost ${BOOST_VERSION} REQUIRED COMPONENTS locale)

if(ENABLE_POT_UPDATE_TARGET)
find_package(TranslationTools REQUIRED)
Expand Down
6 changes: 5 additions & 1 deletion SConstruct
Expand Up @@ -184,7 +184,11 @@ if env['distcc']:

if env['ccache']: env.Tool('ccache')

boost_version = '1.56.0'
# boost::asio::post is new with 1.66
if env["forum_user_handler"]:
boost_version = "1.66"
else:
boost_version = "1.56"


def SortHelpText(a, b):
Expand Down
115 changes: 67 additions & 48 deletions src/server/common/dbconn.cpp
Expand Up @@ -40,7 +40,8 @@ dbconn::dbconn(const config& c)
account_ = mariadb::account::create(c["db_host"].str(), c["db_user"].str(), c["db_password"].str());
account_->set_connect_option(mysql_option::MYSQL_SET_CHARSET_NAME, std::string("utf8mb4"));
account_->set_schema(c["db_name"].str());
connection_ = mariadb::connection::create(account_);
// initialize sync query connection
connection_ = create_connection();
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -55,14 +56,33 @@ void dbconn::log_sql_exception(const std::string& text, const mariadb::exception
<< "error id: " << e.error_id() << std::endl;
}

mariadb::connection_ref dbconn::create_connection()
{
return mariadb::connection::create(account_);
}

//
// queries
//
/* simple test async query that will taken a noticeable amount of time to complete */
int dbconn::async_test_query(int limit)
{
std::string sql = "with recursive TEST(T) as "
"( "
"select 1 "
"union all "
"select T+1 from TEST where T < ? "
") "
"select count(*) from TEST";
int t = get_single_long(create_connection(), sql, limit);
return t;
}

std::string dbconn::get_uuid()
{
try
{
return get_single_string("SELECT UUID()");
return get_single_string(connection_, "SELECT UUID()");
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -81,7 +101,7 @@ std::string dbconn::get_tournaments()
try
{
tournaments t;
get_complex_results(t, db_tournament_query_);
get_complex_results(connection_, t, db_tournament_query_);
return t.str();
}
catch(const mariadb::exception::base& e)
Expand All @@ -95,7 +115,7 @@ bool dbconn::user_exists(const std::string& name)
{
try
{
return exists("SELECT 1 FROM `"+db_users_table_+"` WHERE UPPER(username)=UPPER(?)", name);
return exists(connection_, "SELECT 1 FROM `"+db_users_table_+"` WHERE UPPER(username)=UPPER(?)", name);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -108,7 +128,7 @@ bool dbconn::extra_row_exists(const std::string& name)
{
try
{
return exists("SELECT 1 FROM `"+db_extra_table_+"` WHERE UPPER(username)=UPPER(?)", name);
return exists(connection_, "SELECT 1 FROM `"+db_extra_table_+"` WHERE UPPER(username)=UPPER(?)", name);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -121,7 +141,8 @@ bool dbconn::is_user_in_group(const std::string& name, int group_id)
{
try
{
return exists("SELECT 1 FROM `"+db_users_table_+"` u, `"+db_user_group_table_+"` ug WHERE UPPER(u.username)=UPPER(?) AND u.USER_ID = ug.USER_ID AND ug.GROUP_ID = ?", name, group_id);
return exists(connection_, "SELECT 1 FROM `"+db_users_table_+"` u, `"+db_user_group_table_+"` ug WHERE UPPER(u.username)=UPPER(?) AND u.USER_ID = ug.USER_ID AND ug.GROUP_ID = ?",
name, group_id);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -136,7 +157,8 @@ ban_check dbconn::get_ban_info(const std::string& name, const std::string& ip)
{
// selected ban_type value must be part of user_handler::BAN_TYPE
ban_check b;
get_complex_results(b, "select ban_userid, ban_email, case when ban_ip != '' then 1 when ban_userid != 0 then 2 when ban_email != '' then 3 end as ban_type, ban_end from `"+db_banlist_table_+"` where (ban_ip = ? or ban_userid = (select user_id from `"+db_users_table_+"` where UPPER(username) = UPPER(?)) or UPPER(ban_email) = (select UPPER(user_email) from `"+db_users_table_+"` where UPPER(username) = UPPER(?))) AND ban_exclude = 0 AND (ban_end = 0 OR ban_end >= ?)", ip, name, name, std::time(nullptr));
get_complex_results(connection_, b, "select ban_userid, ban_email, case when ban_ip != '' then 1 when ban_userid != 0 then 2 when ban_email != '' then 3 end as ban_type, ban_end from `"+db_banlist_table_+"` where (ban_ip = ? or ban_userid = (select user_id from `"+db_users_table_+"` where UPPER(username) = UPPER(?)) or UPPER(ban_email) = (select UPPER(user_email) from `"+db_users_table_+"` where UPPER(username) = UPPER(?))) AND ban_exclude = 0 AND (ban_end = 0 OR ban_end >= ?)",
ip, name, name, std::time(nullptr));
return b;
}
catch(const mariadb::exception::base& e)
Expand All @@ -150,7 +172,7 @@ std::string dbconn::get_user_string(const std::string& table, const std::string&
{
try
{
return get_single_string("SELECT `"+column+"` from `"+table+"` WHERE UPPER(username)=UPPER(?)", name);
return get_single_string(connection_, "SELECT `"+column+"` from `"+table+"` WHERE UPPER(username)=UPPER(?)", name);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -162,7 +184,7 @@ int dbconn::get_user_int(const std::string& table, const std::string& column, co
{
try
{
return get_single_int("SELECT `"+column+"` from `"+table+"` WHERE UPPER(username)=UPPER(?)", name);
return static_cast<int>(get_single_long(connection_, "SELECT `"+column+"` from `"+table+"` WHERE UPPER(username)=UPPER(?)", name));
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -176,9 +198,9 @@ void dbconn::write_user_int(const std::string& column, const std::string& name,
{
if(!extra_row_exists(name))
{
modify("INSERT INTO `"+db_extra_table_+"` VALUES(?,?,'0')", name, value);
modify(connection_, "INSERT INTO `"+db_extra_table_+"` VALUES(?,?,'0')", name, value);
}
modify("UPDATE `"+db_extra_table_+"` SET "+column+"=? WHERE UPPER(username)=UPPER(?)", value, name);
modify(connection_, "UPDATE `"+db_extra_table_+"` SET "+column+"=? WHERE UPPER(username)=UPPER(?)", value, name);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -190,8 +212,8 @@ void dbconn::insert_game_info(const std::string& uuid, int game_id, const std::s
{
try
{
modify("INSERT INTO `"+db_game_info_table_+"`(INSTANCE_UUID, GAME_ID, INSTANCE_VERSION, GAME_NAME, MAP_NAME, ERA_NAME, RELOAD, OBSERVERS, PUBLIC, PASSWORD, MAP_SOURCE_ADDON, MAP_VERSION, ERA_SOURCE_ADDON, ERA_VERSION) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
uuid, game_id, version, name, map_name, era_name, reload, observers, is_public, has_password, map_source, map_version, era_source, era_version);
modify(connection_, "INSERT INTO `"+db_game_info_table_+"`(INSTANCE_UUID, GAME_ID, INSTANCE_VERSION, GAME_NAME, MAP_NAME, ERA_NAME, RELOAD, OBSERVERS, PUBLIC, PASSWORD, MAP_SOURCE_ADDON, MAP_VERSION, ERA_SOURCE_ADDON, ERA_VERSION) VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
uuid, game_id, version, name, map_name, era_name, reload, observers, is_public, has_password, map_source, map_version, era_source, era_version);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -202,8 +224,8 @@ void dbconn::update_game_end(const std::string& uuid, int game_id, const std::st
{
try
{
modify("UPDATE `"+db_game_info_table_+"` SET END_TIME = CURRENT_TIMESTAMP, REPLAY_NAME = ? WHERE INSTANCE_UUID = ? AND GAME_ID = ?",
replay_location, uuid, game_id);
modify(connection_, "UPDATE `"+db_game_info_table_+"` SET END_TIME = CURRENT_TIMESTAMP, REPLAY_NAME = ? WHERE INSTANCE_UUID = ? AND GAME_ID = ?",
replay_location, uuid, game_id);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -214,8 +236,8 @@ void dbconn::insert_game_player_info(const std::string& uuid, int game_id, const
{
try
{
modify("INSERT INTO `"+db_game_player_info_table_+"`(INSTANCE_UUID, GAME_ID, USER_ID, SIDE_NUMBER, IS_HOST, FACTION, CLIENT_VERSION, CLIENT_SOURCE, USER_NAME) VALUES(?, ?, IFNULL((SELECT user_id FROM `"+db_users_table_+"` WHERE username = ?), -1), ?, ?, ?, ?, ?, ?)",
uuid, game_id, username, side_number, is_host, faction, version, source, current_user);
modify(connection_, "INSERT INTO `"+db_game_player_info_table_+"`(INSTANCE_UUID, GAME_ID, USER_ID, SIDE_NUMBER, IS_HOST, FACTION, CLIENT_VERSION, CLIENT_SOURCE, USER_NAME) VALUES(?, ?, IFNULL((SELECT user_id FROM `"+db_users_table_+"` WHERE username = ?), -1), ?, ?, ?, ?, ?, ?)",
uuid, game_id, username, side_number, is_host, faction, version, source, current_user);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -226,8 +248,8 @@ void dbconn::insert_modification_info(const std::string& uuid, int game_id, cons
{
try
{
modify("INSERT INTO `"+db_game_modification_info_table_+"`(INSTANCE_UUID, GAME_ID, MODIFICATION_NAME, SOURCE_ADDON, VERSION) VALUES(?, ?, ?, ?, ?)",
uuid, game_id, modification_name, modification_source, modification_version);
modify(connection_, "INSERT INTO `"+db_game_modification_info_table_+"`(INSTANCE_UUID, GAME_ID, MODIFICATION_NAME, SOURCE_ADDON, VERSION) VALUES(?, ?, ?, ?, ?)",
uuid, game_id, modification_name, modification_source, modification_version);
}
catch(const mariadb::exception::base& e)
{
Expand All @@ -238,7 +260,7 @@ void dbconn::set_oos_flag(const std::string& uuid, int game_id)
{
try
{
modify("UPDATE `"+db_game_info_table_+"` SET OOS = 1 WHERE INSTANCE_UUID = ? AND GAME_ID = ?",
modify(connection_, "UPDATE `"+db_game_info_table_+"` SET OOS = 1 WHERE INSTANCE_UUID = ? AND GAME_ID = ?",
uuid, game_id);
}
catch(const mariadb::exception::base& e)
Expand All @@ -252,18 +274,18 @@ void dbconn::set_oos_flag(const std::string& uuid, int game_id)
// therefore for queries that can return multiple rows of multiple columns, implement a class to define how the results should be read
//
template<typename... Args>
void dbconn::get_complex_results(rs_base& base, const std::string& sql, Args&&... args)
void dbconn::get_complex_results(mariadb::connection_ref connection, rs_base& base, const std::string& sql, Args&&... args)
{
mariadb::result_set_ref rslt = select(sql, args...);
mariadb::result_set_ref rslt = select(connection, sql, args...);
base.read(rslt);
}
//
// get single values
//
template<typename... Args>
std::string dbconn::get_single_string(const std::string& sql, Args&&... args)
std::string dbconn::get_single_string(mariadb::connection_ref connection, const std::string& sql, Args&&... args)
{
mariadb::result_set_ref rslt = select(sql, args...);
mariadb::result_set_ref rslt = select(connection, sql, args...);
if(rslt->next())
{
return rslt->get_string(0);
Expand All @@ -274,9 +296,9 @@ std::string dbconn::get_single_string(const std::string& sql, Args&&... args)
}
}
template<typename... Args>
int dbconn::get_single_int(const std::string& sql, Args&&... args)
long dbconn::get_single_long(mariadb::connection_ref connection, const std::string& sql, Args&&... args)
{
mariadb::result_set_ref rslt = select(sql, args...);
mariadb::result_set_ref rslt = select(connection, sql, args...);
if(rslt->next())
{
// mariadbpp checks for strict integral equivalence, but we don't care
Expand All @@ -292,58 +314,55 @@ int dbconn::get_single_int(const std::string& sql, Args&&... args)
case mariadb::value::type::unsigned32:
case mariadb::value::type::signed32:
return rslt->get_signed32(0);
case mariadb::value::type::unsigned64:
case mariadb::value::type::signed64:
return rslt->get_signed64(0);
default:
throw mariadb::exception::base("Value retrieved was not an int!");
throw mariadb::exception::base("Value retrieved was not a long!");
}
}
else
{
throw mariadb::exception::base("No int value found in the database!");
throw mariadb::exception::base("No long value found in the database!");
}
}
template<typename... Args>
bool dbconn::exists(const std::string& sql, Args&&... args)
bool dbconn::exists(mariadb::connection_ref connection, const std::string& sql, Args&&... args)
{
mariadb::result_set_ref rslt = select(sql, args...);
mariadb::result_set_ref rslt = select(connection, sql, args...);
return rslt->next();
}

//
// select or modify values
//
template<typename... Args>
mariadb::result_set_ref dbconn::select(const std::string& sql, Args&&... args)
mariadb::result_set_ref dbconn::select(mariadb::connection_ref connection, const std::string& sql, Args&&... args)
{
try
{
mariadb::statement_ref stmt = query(sql, args...);
return mariadb::result_set_ref(stmt->query());
mariadb::statement_ref stmt = query(connection, sql, args...);
mariadb::result_set_ref rslt = mariadb::result_set_ref(stmt->query());
return rslt;
}
catch(const mariadb::exception::base& e)
{
if(!connection_->connected())
{
ERR_SQL << "Connection is invalid!" << std::endl;
}
ERR_SQL << "SQL query failed for query: `"+sql+"`!" << std::endl;
ERR_SQL << "SQL query failed for query: `"+sql+"`" << std::endl;
throw e;
}
}
template<typename... Args>
int dbconn::modify(const std::string& sql, Args&&... args)
int dbconn::modify(mariadb::connection_ref connection, const std::string& sql, Args&&... args)
{
try
{
mariadb::statement_ref stmt = query(sql, args...);
return stmt->insert();
mariadb::statement_ref stmt = query(connection, sql, args...);
int count = stmt->insert();
return count;
}
catch(const mariadb::exception::base& e)
{
if(!connection_->connected())
{
ERR_SQL << "Connection is invalid!" << std::endl;
}
ERR_SQL << "SQL query failed for query: `"+sql+"`!" << std::endl;
ERR_SQL << "SQL query failed for query: `"+sql+"`" << std::endl;
throw e;
}
}
Expand All @@ -352,9 +371,9 @@ int dbconn::modify(const std::string& sql, Args&&... args)
// start of recursive unpacking of variadic template in order to be able to call correct parameterized setters on query
//
template<typename... Args>
mariadb::statement_ref dbconn::query(const std::string& sql, Args&&... args)
mariadb::statement_ref dbconn::query(mariadb::connection_ref connection, const std::string& sql, Args&&... args)
{
mariadb::statement_ref stmt = connection_->create_statement(sql);
mariadb::statement_ref stmt = connection->create_statement(sql);
prepare(stmt, 0, args...);
return stmt;
}
Expand Down

0 comments on commit 467e431

Please sign in to comment.