Permalink
Browse files

Fix problem where prepared statements would get mixed up after a DB r…

…econnect
  • Loading branch information...
1 parent 82a3870 commit 00772db78af6e60dd1e14abce9b400b67e88c0d7 @rajkosto committed Nov 17, 2012
@@ -79,7 +79,7 @@ class Database
//allocate index for prepared statement with SQL request 'fmt'
virtual SqlStatement* CreateStatement(SqlStatementID& index, const std::string& fmt) = 0;
//get prepared statement format string
- virtual std::string GetStmtString(const int stmtId) const = 0;
+ virtual std::string GetStmtString(UInt32 stmtId) const = 0;
virtual operator bool () const = 0;
@@ -32,7 +32,7 @@ static const size_t MAX_CONNECTION_POOL_SIZE = 16;
//////////////////////////////////////////////////////////////////////////
ConcreteDatabase::ConcreteDatabase() : m_pAsyncConn(NULL), m_pResultQueue(NULL), m_threadBody(NULL), m_delayThread(NULL),
- m_logSQL(false), m_pingIntervallms(0), m_nQueryConnPoolSize(1), m_bAllowAsyncTransactions(false), m_iStmtIndex(-1), _logger(nullptr)
+ m_logSQL(false), m_pingIntervallms(0), m_nQueryConnPoolSize(1), m_bAllowAsyncTransactions(false), _logger(nullptr)
{
m_nQueryCounter = -1;
}
@@ -462,7 +462,7 @@ bool ConcreteDatabase::ExecuteStmt( const SqlStatementID& id, SqlStmtParameters*
if(pTrans)
{
//add SQL request to trans queue
- pTrans->DelayExecute(new SqlPreparedRequest(id.ID(), params));
+ pTrans->DelayExecute(new SqlPreparedRequest(id, params));
}
else
{
@@ -471,7 +471,7 @@ bool ConcreteDatabase::ExecuteStmt( const SqlStatementID& id, SqlStmtParameters*
return DirectExecuteStmt(id, params);
// Simple sql statement
- m_threadBody->Delay(new SqlPreparedRequest(id.ID(), params));
+ m_threadBody->Delay(new SqlPreparedRequest(id, params));
}
return true;
@@ -483,27 +483,17 @@ bool ConcreteDatabase::DirectExecuteStmt( const SqlStatementID& id, SqlStmtParam
std::auto_ptr<SqlStmtParameters> p(params);
//execute statement
SqlConnection::Lock _guard(getAsyncConnection());
- return _guard->ExecuteStmt(id.ID(), *params);
+ return _guard->ExecuteStmt(id, *params);
}
SqlStatement* ConcreteDatabase::CreateStatement(SqlStatementID& index, const std::string& fmt )
{
- int nId = -1;
- //check if statement ID is initialized
- if(!index.initialized())
+ //initialize the statement if its not
+ if(!index.isInitialized())
{
//count input parameters
- int nParams = std::count(fmt.begin(), fmt.end(), '?');
- //find existing or add a new record in registry
- LOCK_GUARD _guard(m_stmtGuard);
- PreparedStmtRegistry::const_iterator iter = m_stmtRegistry.find(fmt);
- if(iter == m_stmtRegistry.end())
- {
- nId = ++m_iStmtIndex;
- m_stmtRegistry[fmt] = nId;
- }
- else
- nId = iter->second;
+ size_t nParams = std::count(fmt.begin(), fmt.end(), '?');
+ UInt32 nId = m_prepStmtRegistry.getStmtId(fmt);
//save initialized statement index info
index.init(nId, nParams);
@@ -512,21 +502,9 @@ SqlStatement* ConcreteDatabase::CreateStatement(SqlStatementID& index, const std
return new SqlStatementImpl(index, *this);
}
-std::string ConcreteDatabase::GetStmtString(const int stmtId) const
+std::string ConcreteDatabase::GetStmtString(UInt32 stmtId) const
{
- LOCK_GUARD _guard(m_stmtGuard);
-
- if(stmtId == -1 || stmtId > m_iStmtIndex)
- return std::string();
-
- PreparedStmtRegistry::const_iterator iter_last = m_stmtRegistry.end();
- for(PreparedStmtRegistry::const_iterator iter = m_stmtRegistry.begin(); iter != iter_last; ++iter)
- {
- if(iter->second == stmtId)
- return iter->first;
- }
-
- return std::string();
+ return m_prepStmtRegistry.getStmtString(stmtId);
}
bool ConcreteDatabase::DoDelay( const char* sql, QueryCallback callback )
@@ -572,3 +550,46 @@ void ConcreteDatabase::TransHelper::reset()
m_pTrans = NULL;
}
}
+
+namespace
+{
+ UInt32 GetUniqueId()
+ {
+ UInt64 cpuTimer = __rdtsc();
+ UInt32 mixed = cpuTimer & 0xFFFFFFFF;
+ mixed ^= (UInt64(cpuTimer >> 32)) & 0xFFFFFFFF;
+ return mixed;
+ }
+};
+
+UInt32 ConcreteDatabase::PreparedStmtRegistry::getStmtId( const std::string& fmt )
+{
+ UInt32 nId;
+
+ RegistryGuardType _guard(_lock);
+ StatementMap::const_iterator iter = _map.find(fmt);
+ if(iter == _map.end())
+ {
+ nId = GetUniqueId();
+ _map[fmt] = nId;
+ }
+ else
+ nId = iter->second;
+
+ return nId;
+}
+
+std::string ConcreteDatabase::PreparedStmtRegistry::getStmtString( UInt32 stmtId ) const
+{
+ if(stmtId == 0)
+ return std::string();
+
+ RegistryGuardType _guard(_lock);
+ for(StatementMap::const_iterator iter = _map.begin(); iter != _map.end(); ++iter)
+ {
+ if(iter->second == stmtId)
+ return iter->first;
+ }
+
+ return std::string();
+}
@@ -80,7 +80,7 @@ class ConcreteDatabase : public Database
//allocate index for prepared statement with SQL request 'fmt'
SqlStatement* CreateStatement(SqlStatementID& index, const std::string& fmt) override;
//get prepared statement format string
- std::string GetStmtString(const int stmtId) const override;
+ std::string GetStmtString(UInt32 stmtId) const override;
operator bool () const override { return (m_pQueryConnections.size() && m_pAsyncConn != 0); }
@@ -157,32 +157,37 @@ class ConcreteDatabase : public Database
bool DirectExecuteStmt(const SqlStatementID& id, SqlStmtParameters* params);
//connection helper counters
- int m_nQueryConnPoolSize; //current size of query connection pool
+ int m_nQueryConnPoolSize; //current size of query connection pool
Poco::AtomicCounter m_nQueryCounter; //counter for connection selection
//lets use pool of connections for sync queries
- typedef std::vector< SqlConnection* > SqlConnectionContainer;
+ typedef std::vector<SqlConnection*> SqlConnectionContainer;
SqlConnectionContainer m_pQueryConnections;
//only one single DB connection for transactions
SqlConnection* m_pAsyncConn;
- SqlResultQueue* m_pResultQueue; ///< Transaction queues from diff. threads
- SqlDelayThread* m_threadBody; ///< Pointer to delay sql executer (owned by m_delayThread)
- Poco::Thread* m_delayThread; ///< Pointer to executer thread
+ SqlResultQueue* m_pResultQueue; ///< Transaction queues from diff. threads
+ SqlDelayThread* m_threadBody; ///< Pointer to delay sql executer (owned by m_delayThread)
+ Poco::Thread* m_delayThread; ///< Pointer to executer thread
- bool m_bAllowAsyncTransactions; ///< flag which specifies if async transactions are enabled
+ bool m_bAllowAsyncTransactions; ///< flag which specifies if async transactions are enabled
//PREPARED STATEMENT REGISTRY
- typedef Poco::FastMutex LOCK_TYPE;
- typedef Poco::ScopedLock<LOCK_TYPE> LOCK_GUARD;
-
- mutable LOCK_TYPE m_stmtGuard;
-
- typedef boost::unordered_map<std::string, int> PreparedStmtRegistry;
- PreparedStmtRegistry m_stmtRegistry; ///<
+ class PreparedStmtRegistry
+ {
+ public:
+ //find existing or add a new record in registry
+ UInt32 getStmtId(const std::string& fmt);
+ std::string getStmtString(UInt32 stmtId) const;
+ private:
+ typedef Poco::FastMutex RegistryLockType;
+ typedef Poco::ScopedLock<RegistryLockType> RegistryGuardType;
+ mutable RegistryLockType _lock; //guards _map
- int m_iStmtIndex;
+ typedef boost::unordered_map<std::string, UInt32> StatementMap;
+ StatementMap _map;
+ } m_prepStmtRegistry;
class Poco::Logger* _logger;
private:
@@ -106,8 +106,8 @@ MySQLConnection::MySQLConnection( Database& db ) : SqlConnection(db), mMysql(NUL
MySQLConnection::~MySQLConnection()
{
- FreePreparedStatements();
- mysql_close(mMysql);
+ if (mMysql)
+ mysql_close(mMysql);
}
bool MySQLConnection::Initialize(const std::string& infoString)
@@ -186,6 +186,9 @@ int MySQLConnection::_Connect(MYSQL* mysqlInit)
reconnecting = true;
}
+ //remove any state from previous session
+ this->clear();
+
for(;;)
{
const char* unix_socket = NULL;
@@ -26,33 +26,25 @@ SqlPreparedStatement* SqlConnection::CreateStatement( const std::string& fmt )
return new SqlPlainPreparedStatement(fmt, *this);
}
-void SqlConnection::FreePreparedStatements()
+void SqlConnection::clear()
{
SqlConnection::Lock guard(this);
-
- size_t nStmts = m_holder.size();
- for (size_t i = 0; i < nStmts; ++i)
- delete m_holder[i];
-
- m_holder.clear();
+ m_stmtHolder.clear();
}
-SqlPreparedStatement* SqlConnection::GetStmt( int nIndex )
+SqlPreparedStatement* SqlConnection::GetStmt( const SqlStatementID& stId )
{
- if(nIndex < 0)
+ if(!stId.isInitialized())
return NULL;
- //resize stmt container
- if(m_holder.size() <= nIndex)
- m_holder.resize(nIndex + 1, NULL);
-
- SqlPreparedStatement * pStmt = NULL;
+ UInt32 stmtId = stId.getId();
+ SqlPreparedStatement* pStmt = m_stmtHolder.getPrepStmtObj(stmtId);
- //create stmt if needed
- if(m_holder[nIndex] == NULL)
+ //create stmt obj if needed
+ if(pStmt == NULL)
{
//obtain SQL request string
- std::string fmt = m_db.GetStmtString(nIndex);
+ std::string fmt = m_db.GetStmtString(stmtId);
poco_assert(fmt.length());
//allocate SQlPreparedStatement object
pStmt = CreateStatement(fmt);
@@ -64,23 +56,21 @@ SqlPreparedStatement* SqlConnection::GetStmt( int nIndex )
}
//save statement in internal registry
- m_holder[nIndex] = pStmt;
+ m_stmtHolder.insertPrepStmtObj(stmtId,pStmt);
}
- else
- pStmt = m_holder[nIndex];
return pStmt;
}
-bool SqlConnection::ExecuteStmt(int nIndex, const SqlStmtParameters& id )
+bool SqlConnection::ExecuteStmt(const SqlStatementID& stId, const SqlStmtParameters& params )
{
- if(nIndex == -1)
+ if(!stId.isInitialized())
return false;
//get prepared statement object
- SqlPreparedStatement* pStmt = GetStmt(nIndex);
+ SqlPreparedStatement* pStmt = GetStmt(stId);
//bind parameters
- pStmt->bind(id);
+ pStmt->bind(params);
//execute statement
return pStmt->execute();
}
@@ -105,3 +95,29 @@ bool SqlConnection::RollbackTransaction()
{
return true;
}
+
+void SqlConnection::StmtHolder::clear()
+{
+ for (StatementObjMap::iterator it=_map.begin();it!=_map.end();++it)
+ delete it->second;
+
+ _map.clear();
+}
+
+SqlPreparedStatement* SqlConnection::StmtHolder::getPrepStmtObj( UInt32 stmtId ) const
+{
+ StatementObjMap::const_iterator iter = _map.find(stmtId);
+ if(iter == _map.end())
+ return NULL;
+ else
+ return iter->second;
+}
+
+void SqlConnection::StmtHolder::insertPrepStmtObj( UInt32 stmtId, SqlPreparedStatement* stmtObj )
+{
+ StatementObjMap::iterator iter = _map.find(stmtId);
+ if(iter != _map.end())
+ delete iter->second;
+
+ _map[stmtId] = stmtObj;
+}
@@ -26,6 +26,7 @@ class Database;
class QueryResult;
class QueryNamedResult;
class SqlPreparedStatement;
+class SqlStatementID;
class SqlStmtParameters;
//
@@ -53,7 +54,7 @@ class SqlConnection
virtual bool RollbackTransaction();
//methods to work with prepared statements
- bool ExecuteStmt(int nIndex, const SqlStmtParameters& id);
+ bool ExecuteStmt(const SqlStatementID& stId, const SqlStmtParameters& id);
//SqlConnection object lock
class Lock
@@ -75,18 +76,31 @@ class SqlConnection
SqlConnection(Database& db) : m_db(db) {}
virtual SqlPreparedStatement* CreateStatement(const std::string& fmt);
- //allocate prepared statement and return statement ID
- SqlPreparedStatement* GetStmt(int nIndex);
+ //allocate and return prepared statement object
+ SqlPreparedStatement* GetStmt(const SqlStatementID& stId);
Database& m_db;
- //free prepared statements objects
- void FreePreparedStatements();
+ //clear any state (because of reconnects, etc)
+ virtual void clear();
private:
- typedef Poco::Mutex LOCK_TYPE;
- LOCK_TYPE m_mutex;
+ typedef Poco::FastMutex ConnLockType;
+ ConnLockType m_mutex;
- typedef std::vector<SqlPreparedStatement* > StmtHolder;
- StmtHolder m_holder;
+ class StmtHolder
+ {
+ public:
+ StmtHolder() {}
+ ~StmtHolder() { clear(); }
+
+ //remove all statements
+ void clear();
+
+ SqlPreparedStatement* getPrepStmtObj(UInt32 stmtId) const;
+ void insertPrepStmtObj(UInt32 stmtId, SqlPreparedStatement* stmtObj);
+ private:
+ typedef boost::unordered_map<UInt32,SqlPreparedStatement*> StatementObjMap;
+ StatementObjMap _map;
+ } m_stmtHolder;
};
Oops, something went wrong.

0 comments on commit 00772db

Please sign in to comment.