diff --git a/src/database.cpp b/src/database.cpp index aaf2361318..20bf4273f5 100644 --- a/src/database.cpp +++ b/src/database.cpp @@ -9,7 +9,7 @@ #include -static bool connectToDatabase(MYSQL*& handle, const bool retryIfError) +static tfs::detail::Mysql_ptr connectToDatabase(const bool retryIfError) { bool isFirstAttemptToConnect = true; @@ -19,29 +19,26 @@ static bool connectToDatabase(MYSQL*& handle, const bool retryIfError) } isFirstAttemptToConnect = false; - // close the connection handle - mysql_close(handle); - // connection handle initialization - handle = mysql_init(nullptr); + tfs::detail::Mysql_ptr handle{mysql_init(nullptr)}; if (!handle) { std::cout << std::endl << "Failed to initialize MySQL connection handle." << std::endl; goto error; } // connects to database - if (!mysql_real_connect(handle, getString(ConfigManager::MYSQL_HOST).c_str(), + if (!mysql_real_connect(handle.get(), getString(ConfigManager::MYSQL_HOST).c_str(), getString(ConfigManager::MYSQL_USER).c_str(), getString(ConfigManager::MYSQL_PASS).c_str(), getString(ConfigManager::MYSQL_DB).c_str(), getNumber(ConfigManager::SQL_PORT), getString(ConfigManager::MYSQL_SOCK).c_str(), 0)) { - std::cout << std::endl << "MySQL Error Message: " << mysql_error(handle) << std::endl; + std::cout << std::endl << "MySQL Error Message: " << mysql_error(handle.get()) << std::endl; goto error; } - return true; + return handle; error: if (retryIfError) { goto retry; } - return false; + return nullptr; } static bool isLostConnectionError(const unsigned error) @@ -50,27 +47,28 @@ static bool isLostConnectionError(const unsigned error) error == 1053 /*ER_SERVER_SHUTDOWN*/ || error == CR_CONNECTION_ERROR; } -static bool executeQuery(MYSQL*& handle, std::string_view query, const bool retryIfLostConnection) +static bool executeQuery(tfs::detail::Mysql_ptr& handle, std::string_view query, const bool retryIfLostConnection) { - while (mysql_real_query(handle, query.data(), query.length()) != 0) { + while (mysql_real_query(handle.get(), query.data(), query.length()) != 0) { std::cout << "[Error - mysql_real_query] Query: " << query.substr(0, 256) << std::endl - << "Message: " << mysql_error(handle) << std::endl; - const unsigned error = mysql_errno(handle); + << "Message: " << mysql_error(handle.get()) << std::endl; + const unsigned error = mysql_errno(handle.get()); if (!isLostConnectionError(error) || !retryIfLostConnection) { return false; } - connectToDatabase(handle, true); + handle = connectToDatabase(true); } return true; } -Database::~Database() { mysql_close(handle); } - bool Database::connect() { - if (!connectToDatabase(handle, false)) { + auto newHandle = connectToDatabase(false); + if (!newHandle) { return false; } + + handle = std::move(newHandle); DBResult_ptr result = storeQuery("SHOW VARIABLES LIKE 'max_allowed_packet'"); if (result) { maxPacketSize = result->getNumber("Value"); @@ -122,11 +120,11 @@ DBResult_ptr Database::storeQuery(std::string_view query) // we should call that every time as someone would call executeQuery('SELECT...') // as it is described in MySQL manual: "it doesn't hurt" :P - MYSQL_RES* res = mysql_store_result(handle); + tfs::detail::MysqlResult_ptr res{mysql_store_result(handle.get())}; if (!res) { std::cout << "[Error - mysql_store_result] Query: " << query << std::endl - << "Message: " << mysql_error(handle) << std::endl; - const unsigned error = mysql_errno(handle); + << "Message: " << mysql_error(handle.get()) << std::endl; + const unsigned error = mysql_errno(handle.get()); if (!isLostConnectionError(error) || !retryQueries) { return nullptr; } @@ -134,7 +132,7 @@ DBResult_ptr Database::storeQuery(std::string_view query) } // retrieving results of query - DBResult_ptr result = std::make_shared(res); + DBResult_ptr result = std::make_shared(std::move(res)); if (!result->hasNext()) { return nullptr; } @@ -152,7 +150,7 @@ std::string Database::escapeBlob(const char* s, uint32_t length) const if (length != 0) { char* output = new char[maxLength]; - mysql_real_escape_string(handle, output, s, length); + mysql_real_escape_string(handle.get(), output, s, length); escaped.append(output); delete[] output; } @@ -161,23 +159,19 @@ std::string Database::escapeBlob(const char* s, uint32_t length) const return escaped; } -DBResult::DBResult(MYSQL_RES* res) +DBResult::DBResult(tfs::detail::MysqlResult_ptr&& res) : handle{std::move(res)} { - handle = res; - size_t i = 0; - MYSQL_FIELD* field = mysql_fetch_field(handle); + MYSQL_FIELD* field = mysql_fetch_field(handle.get()); while (field) { listNames[field->name] = i++; - field = mysql_fetch_field(handle); + field = mysql_fetch_field(handle.get()); } - row = mysql_fetch_row(handle); + row = mysql_fetch_row(handle.get()); } -DBResult::~DBResult() { mysql_free_result(handle); } - std::string_view DBResult::getString(std::string_view column) const { auto it = listNames.find(column); @@ -191,7 +185,7 @@ std::string_view DBResult::getString(std::string_view column) const return {}; } - auto size = mysql_fetch_lengths(handle)[it->second]; + auto size = mysql_fetch_lengths(handle.get())[it->second]; return {row[it->second], size}; } @@ -199,7 +193,7 @@ bool DBResult::hasNext() const { return row; } bool DBResult::next() { - row = mysql_fetch_row(handle); + row = mysql_fetch_row(handle.get()); return row; } diff --git a/src/database.h b/src/database.h index f5633fee02..199b69a7d6 100644 --- a/src/database.h +++ b/src/database.h @@ -9,16 +9,22 @@ class DBResult; using DBResult_ptr = std::shared_ptr; -class Database +namespace tfs::detail { + +struct MysqlDeleter { -public: - Database() = default; - ~Database(); + void operator()(MYSQL* handle) const { mysql_close(handle); } + void operator()(MYSQL_RES* handle) const { mysql_free_result(handle); } +}; - // non-copyable - Database(const Database&) = delete; - Database& operator=(const Database&) = delete; +using Mysql_ptr = std::unique_ptr; +using MysqlResult_ptr = std::unique_ptr; +} // namespace tfs::detail + +class Database +{ +public: /** * Singleton implementation. * @@ -84,7 +90,7 @@ class Database * @return id on success, 0 if last query did not result on any rows with * auto_increment keys */ - uint64_t getLastInsertId() const { return static_cast(mysql_insert_id(handle)); } + uint64_t getLastInsertId() const { return static_cast(mysql_insert_id(handle.get())); } /** * Get database engine version @@ -108,7 +114,7 @@ class Database bool rollback(); bool commit(); - MYSQL* handle = nullptr; + tfs::detail::Mysql_ptr handle = nullptr; std::recursive_mutex databaseLock; uint64_t maxPacketSize = 1048576; // Do not retry queries if we are in the middle of a transaction @@ -120,8 +126,7 @@ class Database class DBResult { public: - explicit DBResult(MYSQL_RES* res); - ~DBResult(); + explicit DBResult(tfs::detail::MysqlResult_ptr&& res); // non-copyable DBResult(const DBResult&) = delete; @@ -150,7 +155,7 @@ class DBResult bool next(); private: - MYSQL_RES* handle; + tfs::detail::MysqlResult_ptr handle; MYSQL_ROW row; std::map listNames;