Permalink
Browse files

Added transaction support.

  • Loading branch information...
1 parent 967d5b3 commit 7d6f2564739e941155503d26139b23240dd83379 @ovidiucp committed Aug 29, 2011
Showing with 121 additions and 14 deletions.
  1. +24 −0 README
  2. +85 −10 adb.py
  3. +12 −4 threadpool.py
View
24 README
@@ -23,6 +23,16 @@ adisp modules:
from adisp import process
import adb
+ def __init__(self):
+ self.adb = Database(driver="psycopg2",
+ host='DATABASE_HOST',
+ database='DATABASE_DB',
+ user='DATABASE_USER',
+ password='DATABASE_PASSWD',
+ num_threads=3,
+ tx_connection_pool_size=2,
+ queue_timeout=0.001)
+
@process
def someFunctionInvokedFromIOLoop(self):
@@ -50,6 +60,20 @@ import adb
print 'inserted %s records, time taken = %s seconds' % \
(rows, end_time - start_time)
+You can also use transactions:
+
+ @process
+ def transactions(self):
+ txId = yield self.adb.beginTransaction()
+ yield self.adb.runOperation(
+ "insert into mytable (userid, data) values (%s, %s)",
+ (1, "test"),
+ txId)
+ yield self.adb.commitTransaction(txId)
+
+To rollback a transaction, use rollbackTransaction(txId) instead of
+commitTransaction().
+
The command line options for the benchmark are the following:
View
95 adb.py
@@ -1,10 +1,11 @@
-# Asynchronous database interface for Tornado
+# Asynchronous database interface for Tornado with transaction support.
#
# Author: Ovidiu Predescu
# Date: August 2011
from functools import partial
import psycopg2
+from collections import deque
import tornado.ioloop
@@ -26,6 +27,7 @@ def __init__(self,
host='localhost',
ioloop=tornado.ioloop.IOLoop.instance(),
num_threads=10,
+ tx_connection_pool_size=5,
queue_timeout=1):
if not(driver):
raise ValueError("Missing 'driver' argument")
@@ -36,10 +38,18 @@ def __init__(self,
self._host = host
self._threadpool = ThreadPool(
per_thread_init_func=self.create_connection,
+ per_thread_close_func=self.close_connection,
num_threads=num_threads,
queue_timeout=queue_timeout)
self._ioloop = ioloop
+ # Connection pool for transactions
+ self._connection_pool = []
+ for i in xrange(tx_connection_pool_size):
+ conn = self.create_connection()
+ self._connection_pool.append(conn)
+ self._waiting_on_connection = deque()
+
def create_connection(self):
"""This method is executed in a worker thread.
@@ -69,41 +79,101 @@ def create_connection(self):
raise ValueError("Unknown driver %s" % self._driver)
return conn
+ def close_connection(self, conn):
+ conn.close()
+
def stop(self):
self._threadpool.stop()
+ for conn in self._connection_pool:
+ conn.close()
+
+ @async
+ def beginTransaction(self, callback):
+ """Begins a transaction. Picks up a transaction from the pool
+ and passes it to the callback. If none is available, adds the
+ callback to `_waiting_on_connection'.
+ """
+ if self._connection_pool:
+ conn = self._connection_pool.pop()
+ callback(conn)
+ else:
+ self._waiting_on_connection.append(callback)
@async
- def runQuery(self, query, args=None, callback=None):
+ def commitTransaction(self, connection, callback):
+ self._threadpool.add_task(
+ partial(self._commitTransaction, connection, callback))
+
+ def _commitTransaction(self, conn, callback, thread_state=None):
+ """Invoked in a worker thread.
+ """
+ conn.commit()
+ self._ioloop.add_callback(
+ partial(self._releaseConnectionInvokeCallback, conn, callback))
+
+ @async
+ def rollbackTransaction(self, connection, callback):
+ self._threadpool.add_task(
+ partial(self._rollbackTransaction, connection, callback))
+
+ def _rollbackTransaction(self, conn, callback, thread_state=None):
+ """Invoked in a worker thread.
+ """
+ conn.rollback()
+ self._ioloop.add_callback(
+ partial(self._releaseConnectionInvokeCallback, conn, callback))
+
+ def _releaseConnectionInvokeCallback(self, conn, callback):
+ """Release the connection back in the connection pool and
+ invoke the callback. Invokes any waiting callbacks before
+ releasing the connection into the pool.
+ """
+ # First invoke the callback to let the program know we're done
+ # with the transaction.
+ callback(conn)
+ # Now check to see if we have any pending clients. If so pass
+ # them the newly released connection.
+ if self._waiting_on_connection:
+ callback = self._waiting_on_connection.popleft()
+ callback(conn)
+ else:
+ self._connection_pool.append(conn)
+
+ @async
+ def runQuery(self, query, args=None, conn=None, callback=None):
"""Send a SELECT query to the database.
The callback is invoked with all the rows in the result.
"""
- self._threadpool.add_task(partial(self._query, query, args), callback)
+ self._threadpool.add_task(
+ partial(self._query, query, args, conn), callback)
- def _query(self, query, args, thread_state=None):
+ def _query(self, query, args, conn=None, thread_state=None):
"""This method is called in a worker thread.
Execute the query and return the result so it can be passed as
argument to the callback.
"""
- conn = thread_state
+ if not conn:
+ conn = thread_state
cursor = conn.cursor()
cursor.execute(query, args)
rows = cursor.fetchall()
cursor.close()
return rows
@async
- def runOperation(self, stmt, args=None, callback=None):
+ def runOperation(self, stmt, args=None, conn=None, callback=None):
"""Execute a SQL statement other than a SELECT.
The statement is committed immediately. The number of rows
affected by the statement is passed as argument to the
callback.
"""
- self._threadpool.add_task(partial(self._execute, stmt, args), callback)
+ self._threadpool.add_task(
+ partial(self._execute, stmt, args, conn), callback)
- def _execute(self, stmt, args, thread_state=None):
+ def _execute(self, stmt, args, conn=None, thread_state=None):
"""This method is called in a worker thread.
Executes the statement.
@@ -113,10 +183,15 @@ def _execute(self, stmt, args, thread_state=None):
if isinstance(stmt, tuple):
args = stmt[1]
stmt = stmt[0]
- conn = thread_state
+ if not conn:
+ conn = thread_state
+ should_commit = True
+ else:
+ should_commit = False
cursor = conn.cursor()
cursor.execute(stmt, args)
- conn.commit()
+ if should_commit:
+ conn.commit()
rowcount = cursor.rowcount
cursor.close()
return rowcount
View
16 threadpool.py
@@ -1,4 +1,4 @@
-# Thread pool to be used with Tornado
+# Thread pool to be used with Tornado.
#
# Author: Ovidiu Predescu
# Date: August 2011
@@ -9,6 +9,7 @@
from Queue import Queue, Empty
from functools import partial
import tornado.ioloop
+import time
class ThreadPool:
"""Creates a thread pool containing `num_threads' worker threads.
@@ -58,7 +59,10 @@ def func(thread_state=None):
environment).
"""
- def __init__(self, per_thread_init_func=None, num_threads=10,
+ def __init__(self,
+ per_thread_init_func=None,
+ per_thread_close_func=None,
+ num_threads=10,
queue_timeout=1,
ioloop=tornado.ioloop.IOLoop.instance()):
self._ioloop = ioloop
@@ -68,7 +72,7 @@ def __init__(self, per_thread_init_func=None, num_threads=10,
self._threads = []
self._running = True
for i in xrange(num_threads):
- t = WorkerThread(self, per_thread_init_func)
+ t = WorkerThread(self, per_thread_init_func, per_thread_close_func)
t.start()
self._threads.append(t)
@@ -78,12 +82,14 @@ def add_task(self, func, callback=None):
def stop(self):
self._running = False
+ map(lambda t: t.join(), self._threads)
class WorkerThread(Thread):
- def __init__(self, pool, per_thread_init_func):
+ def __init__(self, pool, per_thread_init_func, per_thread_close_func):
Thread.__init__(self)
self._pool = pool
self._per_thread_init_func = per_thread_init_func
+ self._per_thread_close_func = per_thread_close_func
def run(self):
if self._per_thread_init_func:
@@ -100,3 +106,5 @@ def run(self):
self._pool._ioloop.add_callback(partial(callback, result))
except Empty:
pass
+ if self._per_thread_close_func:
+ self._per_thread_close_func(thread_state)

0 comments on commit 7d6f256

Please sign in to comment.