Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MySQL handle leak #4722

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 26 additions & 32 deletions src/database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#include <mysql/errmsg.h>

static bool connectToDatabase(MYSQL*& handle, const bool retryIfError)
static tfs::detail::Mysql_ptr connectToDatabase(const bool retryIfError)
{
bool isFirstAttemptToConnect = true;

Expand All @@ -19,29 +19,26 @@
}
isFirstAttemptToConnect = false;

// close the connection handle
mysql_close(handle);
// connection handle initialization
handle = mysql_init(nullptr);
tfs::detail::Mysql_ptr handle{mysql_init(nullptr)};
if (!handle) {
std::cout << std::endl << "Failed to initialize MySQL connection handle." << std::endl;
goto error;
}
// connects to database
if (!mysql_real_connect(handle, getString(ConfigManager::MYSQL_HOST).c_str(),
if (!mysql_real_connect(handle.get(), getString(ConfigManager::MYSQL_HOST).c_str(),
getString(ConfigManager::MYSQL_USER).c_str(), getString(ConfigManager::MYSQL_PASS).c_str(),
getString(ConfigManager::MYSQL_DB).c_str(), getNumber(ConfigManager::SQL_PORT),
getString(ConfigManager::MYSQL_SOCK).c_str(), 0)) {
std::cout << std::endl << "MySQL Error Message: " << mysql_error(handle) << std::endl;
std::cout << std::endl << "MySQL Error Message: " << mysql_error(handle.get()) << std::endl;
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
goto error;
}
return true;
return handle;

error:
if (retryIfError) {
goto retry;
}
return false;
return nullptr;
}

static bool isLostConnectionError(const unsigned error)
Expand All @@ -50,27 +47,28 @@
error == 1053 /*ER_SERVER_SHUTDOWN*/ || error == CR_CONNECTION_ERROR;
}

static bool executeQuery(MYSQL*& handle, std::string_view query, const bool retryIfLostConnection)
static bool executeQuery(tfs::detail::Mysql_ptr& handle, std::string_view query, const bool retryIfLostConnection)
{
while (mysql_real_query(handle, query.data(), query.length()) != 0) {
while (mysql_real_query(handle.get(), query.data(), query.length()) != 0) {
std::cout << "[Error - mysql_real_query] Query: " << query.substr(0, 256) << std::endl
<< "Message: " << mysql_error(handle) << std::endl;
const unsigned error = mysql_errno(handle);
<< "Message: " << mysql_error(handle.get()) << std::endl;
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
const unsigned error = mysql_errno(handle.get());
if (!isLostConnectionError(error) || !retryIfLostConnection) {
return false;
}
connectToDatabase(handle, true);
handle = connectToDatabase(true);
}
return true;
}

Database::~Database() { mysql_close(handle); }

bool Database::connect()
{
if (!connectToDatabase(handle, false)) {
auto newHandle = connectToDatabase(false);
if (!newHandle) {
return false;
}

handle = std::move(newHandle);
DBResult_ptr result = storeQuery("SHOW VARIABLES LIKE 'max_allowed_packet'");
if (result) {
maxPacketSize = result->getNumber<uint64_t>("Value");
Expand Down Expand Up @@ -122,19 +120,19 @@

// we should call that every time as someone would call executeQuery('SELECT...')
// as it is described in MySQL manual: "it doesn't hurt" :P
MYSQL_RES* res = mysql_store_result(handle);
tfs::detail::MysqlResult_ptr res{mysql_store_result(handle.get())};
if (!res) {
std::cout << "[Error - mysql_store_result] Query: " << query << std::endl
<< "Message: " << mysql_error(handle) << std::endl;
const unsigned error = mysql_errno(handle);
<< "Message: " << mysql_error(handle.get()) << std::endl;
Fixed Show fixed Hide fixed
const unsigned error = mysql_errno(handle.get());
if (!isLostConnectionError(error) || !retryQueries) {
return nullptr;
}
goto retry;
}

// retrieving results of query
DBResult_ptr result = std::make_shared<DBResult>(res);
DBResult_ptr result = std::make_shared<DBResult>(std::move(res));
if (!result->hasNext()) {
return nullptr;
}
Expand All @@ -152,7 +150,7 @@

if (length != 0) {
char* output = new char[maxLength];
mysql_real_escape_string(handle, output, s, length);
mysql_real_escape_string(handle.get(), output, s, length);
escaped.append(output);
delete[] output;
}
Expand All @@ -161,23 +159,19 @@
return escaped;
}

DBResult::DBResult(MYSQL_RES* res)
DBResult::DBResult(tfs::detail::MysqlResult_ptr&& res) : handle{std::move(res)}
{
handle = res;

size_t i = 0;

MYSQL_FIELD* field = mysql_fetch_field(handle);
MYSQL_FIELD* field = mysql_fetch_field(handle.get());
while (field) {
listNames[field->name] = i++;
field = mysql_fetch_field(handle);
field = mysql_fetch_field(handle.get());
}

row = mysql_fetch_row(handle);
row = mysql_fetch_row(handle.get());
}

DBResult::~DBResult() { mysql_free_result(handle); }

std::string_view DBResult::getString(std::string_view column) const
{
auto it = listNames.find(column);
Expand All @@ -191,15 +185,15 @@
return {};
}

auto size = mysql_fetch_lengths(handle)[it->second];
auto size = mysql_fetch_lengths(handle.get())[it->second];
return {row[it->second], size};
}

bool DBResult::hasNext() const { return row; }

bool DBResult::next()
{
row = mysql_fetch_row(handle);
row = mysql_fetch_row(handle.get());
return row;
}

Expand Down
29 changes: 17 additions & 12 deletions src/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,22 @@
class DBResult;
using DBResult_ptr = std::shared_ptr<DBResult>;

class Database
namespace tfs::detail {

struct MysqlDeleter
{
public:
Database() = default;
~Database();
void operator()(MYSQL* handle) const { mysql_close(handle); }
void operator()(MYSQL_RES* handle) const { mysql_free_result(handle); }
};

// non-copyable
Database(const Database&) = delete;
Database& operator=(const Database&) = delete;
using Mysql_ptr = std::unique_ptr<MYSQL, MysqlDeleter>;
using MysqlResult_ptr = std::unique_ptr<MYSQL_RES, MysqlDeleter>;

} // namespace tfs::detail

class Database
{
public:
/**
* Singleton implementation.
*
Expand Down Expand Up @@ -84,7 +90,7 @@ class Database
* @return id on success, 0 if last query did not result on any rows with
* auto_increment keys
*/
uint64_t getLastInsertId() const { return static_cast<uint64_t>(mysql_insert_id(handle)); }
uint64_t getLastInsertId() const { return static_cast<uint64_t>(mysql_insert_id(handle.get())); }

/**
* Get database engine version
Expand All @@ -108,7 +114,7 @@ class Database
bool rollback();
bool commit();

MYSQL* handle = nullptr;
tfs::detail::Mysql_ptr handle = nullptr;
std::recursive_mutex databaseLock;
uint64_t maxPacketSize = 1048576;
// Do not retry queries if we are in the middle of a transaction
Expand All @@ -120,8 +126,7 @@ class Database
class DBResult
{
public:
explicit DBResult(MYSQL_RES* res);
~DBResult();
explicit DBResult(tfs::detail::MysqlResult_ptr&& res);

// non-copyable
DBResult(const DBResult&) = delete;
Expand Down Expand Up @@ -150,7 +155,7 @@ class DBResult
bool next();

private:
MYSQL_RES* handle;
tfs::detail::MysqlResult_ptr handle;
MYSQL_ROW row;

std::map<std::string_view, size_t> listNames;
Expand Down
Loading