Skip to content

Commit

Permalink
Fix MySQL handle leak (#4722)
Browse files Browse the repository at this point in the history
The first leak in #4288 is related to handle leaking, I couldn't really reproduce how mysql_init is called without mysql_close not being called on the line before, but using a unique_ptr with a custom deleter fixes the leak.

Using unique_ptr for MYSQL_RES also fixes eventual leaks that happen on queries.
  • Loading branch information
ranisalt committed Jun 5, 2024
1 parent e3cae60 commit e80e408
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 44 deletions.
58 changes: 26 additions & 32 deletions src/database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#include <mysql/errmsg.h>

static bool connectToDatabase(MYSQL*& handle, const bool retryIfError)
static tfs::detail::Mysql_ptr connectToDatabase(const bool retryIfError)
{
bool isFirstAttemptToConnect = true;

Expand All @@ -19,29 +19,26 @@ static bool connectToDatabase(MYSQL*& handle, const bool retryIfError)
}
isFirstAttemptToConnect = false;

// close the connection handle
mysql_close(handle);
// connection handle initialization
handle = mysql_init(nullptr);
tfs::detail::Mysql_ptr handle{mysql_init(nullptr)};
if (!handle) {
std::cout << std::endl << "Failed to initialize MySQL connection handle." << std::endl;
goto error;
}
// connects to database
if (!mysql_real_connect(handle, getString(ConfigManager::MYSQL_HOST).c_str(),
if (!mysql_real_connect(handle.get(), getString(ConfigManager::MYSQL_HOST).c_str(),
getString(ConfigManager::MYSQL_USER).c_str(), getString(ConfigManager::MYSQL_PASS).c_str(),
getString(ConfigManager::MYSQL_DB).c_str(), getNumber(ConfigManager::SQL_PORT),
getString(ConfigManager::MYSQL_SOCK).c_str(), 0)) {
std::cout << std::endl << "MySQL Error Message: " << mysql_error(handle) << std::endl;
std::cout << std::endl << "MySQL Error Message: " << mysql_error(handle.get()) << std::endl;
goto error;
}
return true;
return handle;

error:
if (retryIfError) {
goto retry;
}
return false;
return nullptr;
}

static bool isLostConnectionError(const unsigned error)
Expand All @@ -50,27 +47,28 @@ static bool isLostConnectionError(const unsigned error)
error == 1053 /*ER_SERVER_SHUTDOWN*/ || error == CR_CONNECTION_ERROR;
}

static bool executeQuery(MYSQL*& handle, std::string_view query, const bool retryIfLostConnection)
static bool executeQuery(tfs::detail::Mysql_ptr& handle, std::string_view query, const bool retryIfLostConnection)
{
while (mysql_real_query(handle, query.data(), query.length()) != 0) {
while (mysql_real_query(handle.get(), query.data(), query.length()) != 0) {
std::cout << "[Error - mysql_real_query] Query: " << query.substr(0, 256) << std::endl
<< "Message: " << mysql_error(handle) << std::endl;
const unsigned error = mysql_errno(handle);
<< "Message: " << mysql_error(handle.get()) << std::endl;
const unsigned error = mysql_errno(handle.get());
if (!isLostConnectionError(error) || !retryIfLostConnection) {
return false;
}
connectToDatabase(handle, true);
handle = connectToDatabase(true);
}
return true;
}

Database::~Database() { mysql_close(handle); }

bool Database::connect()
{
if (!connectToDatabase(handle, false)) {
auto newHandle = connectToDatabase(false);
if (!newHandle) {
return false;
}

handle = std::move(newHandle);
DBResult_ptr result = storeQuery("SHOW VARIABLES LIKE 'max_allowed_packet'");
if (result) {
maxPacketSize = result->getNumber<uint64_t>("Value");
Expand Down Expand Up @@ -122,19 +120,19 @@ DBResult_ptr Database::storeQuery(std::string_view query)

// we should call that every time as someone would call executeQuery('SELECT...')
// as it is described in MySQL manual: "it doesn't hurt" :P
MYSQL_RES* res = mysql_store_result(handle);
tfs::detail::MysqlResult_ptr res{mysql_store_result(handle.get())};
if (!res) {
std::cout << "[Error - mysql_store_result] Query: " << query << std::endl
<< "Message: " << mysql_error(handle) << std::endl;
const unsigned error = mysql_errno(handle);
<< "Message: " << mysql_error(handle.get()) << std::endl;
const unsigned error = mysql_errno(handle.get());
if (!isLostConnectionError(error) || !retryQueries) {
return nullptr;
}
goto retry;
}

// retrieving results of query
DBResult_ptr result = std::make_shared<DBResult>(res);
DBResult_ptr result = std::make_shared<DBResult>(std::move(res));
if (!result->hasNext()) {
return nullptr;
}
Expand All @@ -152,7 +150,7 @@ std::string Database::escapeBlob(const char* s, uint32_t length) const

if (length != 0) {
char* output = new char[maxLength];
mysql_real_escape_string(handle, output, s, length);
mysql_real_escape_string(handle.get(), output, s, length);
escaped.append(output);
delete[] output;
}
Expand All @@ -161,23 +159,19 @@ std::string Database::escapeBlob(const char* s, uint32_t length) const
return escaped;
}

DBResult::DBResult(MYSQL_RES* res)
DBResult::DBResult(tfs::detail::MysqlResult_ptr&& res) : handle{std::move(res)}
{
handle = res;

size_t i = 0;

MYSQL_FIELD* field = mysql_fetch_field(handle);
MYSQL_FIELD* field = mysql_fetch_field(handle.get());
while (field) {
listNames[field->name] = i++;
field = mysql_fetch_field(handle);
field = mysql_fetch_field(handle.get());
}

row = mysql_fetch_row(handle);
row = mysql_fetch_row(handle.get());
}

DBResult::~DBResult() { mysql_free_result(handle); }

std::string_view DBResult::getString(std::string_view column) const
{
auto it = listNames.find(column);
Expand All @@ -191,15 +185,15 @@ std::string_view DBResult::getString(std::string_view column) const
return {};
}

auto size = mysql_fetch_lengths(handle)[it->second];
auto size = mysql_fetch_lengths(handle.get())[it->second];
return {row[it->second], size};
}

bool DBResult::hasNext() const { return row; }

bool DBResult::next()
{
row = mysql_fetch_row(handle);
row = mysql_fetch_row(handle.get());
return row;
}

Expand Down
29 changes: 17 additions & 12 deletions src/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,22 @@
class DBResult;
using DBResult_ptr = std::shared_ptr<DBResult>;

class Database
namespace tfs::detail {

struct MysqlDeleter
{
public:
Database() = default;
~Database();
void operator()(MYSQL* handle) const { mysql_close(handle); }
void operator()(MYSQL_RES* handle) const { mysql_free_result(handle); }
};

// non-copyable
Database(const Database&) = delete;
Database& operator=(const Database&) = delete;
using Mysql_ptr = std::unique_ptr<MYSQL, MysqlDeleter>;
using MysqlResult_ptr = std::unique_ptr<MYSQL_RES, MysqlDeleter>;

} // namespace tfs::detail

class Database
{
public:
/**
* Singleton implementation.
*
Expand Down Expand Up @@ -84,7 +90,7 @@ class Database
* @return id on success, 0 if last query did not result on any rows with
* auto_increment keys
*/
uint64_t getLastInsertId() const { return static_cast<uint64_t>(mysql_insert_id(handle)); }
uint64_t getLastInsertId() const { return static_cast<uint64_t>(mysql_insert_id(handle.get())); }

/**
* Get database engine version
Expand All @@ -108,7 +114,7 @@ class Database
bool rollback();
bool commit();

MYSQL* handle = nullptr;
tfs::detail::Mysql_ptr handle = nullptr;
std::recursive_mutex databaseLock;
uint64_t maxPacketSize = 1048576;
// Do not retry queries if we are in the middle of a transaction
Expand All @@ -120,8 +126,7 @@ class Database
class DBResult
{
public:
explicit DBResult(MYSQL_RES* res);
~DBResult();
explicit DBResult(tfs::detail::MysqlResult_ptr&& res);

// non-copyable
DBResult(const DBResult&) = delete;
Expand Down Expand Up @@ -150,7 +155,7 @@ class DBResult
bool next();

private:
MYSQL_RES* handle;
tfs::detail::MysqlResult_ptr handle;
MYSQL_ROW row;

std::map<std::string_view, size_t> listNames;
Expand Down

0 comments on commit e80e408

Please sign in to comment.