Browse files

Database transactions

  • Loading branch information...
1 parent fb86e48 commit f54e168756d5aa7247acc30e14dea60919bd7695 @anandology anandology committed Jan 4, 2008
Showing with 152 additions and 59 deletions.
  1. +62 −2 test/db.py
  2. +90 −57 web/db.py
View
64 test/db.py
@@ -5,24 +5,84 @@
class DBTest(webtest.TestCase):
dbname = 'postgres'
- def setUpAll(self):
+ def setUp(self):
self.db = webtest.setup_database(self.dbname)
+ self.db.printing = True
self.db.query("CREATE TABLE person (name text, email text)")
- def tearDownAll(self):
+ def tearDown(self):
# there might be some error with the current connection, delete from a new connection
self.db = webtest.setup_database(self.dbname)
self.db.query('DROP TABLE person')
def testUnicode(self):
"""Bug#177265: unicode queries throw errors"""
self.db.select('person', where='name=$name', vars={'name': u'\xf4'})
+
+ def assertRows(self, n):
+ result = self.db.select('person')
+ self.assertEquals(len(list(result)), n)
+
+ def testCommit(self):
+ t = self.db.transaction()
+ self.db.insert('person', False, name='user1')
+ t.commit()
+
+ t = self.db.transaction()
+ self.db.insert('person', False, name='user2')
+ self.db.insert('person', False, name='user3')
+ t.commit()
+
+ self.assertRows(3)
+
+ def testRollback(self):
+ t = self.db.transaction()
+ self.db.insert('person', False, name='user1')
+ self.db.insert('person', False, name='user2')
+ self.db.insert('person', False, name='user3')
+ t.rollback()
+ self.assertRows(0)
+
+ def testWrongQuery(self):
+ # It should be possible to run a correct query after getting an error from a wrong query.
+ try:
+ self.db.select('notthere')
+ except:
+ pass
+ self.db.select('person')
+
+ def testNestedTransactions(self):
+ t1 = self.db.transaction()
+ self.db.insert('person', False, name='user1')
+ self.assertRows(1)
+
+ t2 = self.db.transaction()
+ self.db.insert('person', False, name='user2')
+ self.assertRows(2)
+ t2.rollback()
+ self.assertRows(1)
+ t3 = self.db.transaction()
+ self.db.insert('person', False, name='user3')
+ self.assertRows(2)
+ t3.commit()
+ t1.commit()
+ self.assertRows(2)
class SqliteTest(DBTest):
dbname = "sqlite"
+ def testNestedTransactions(self):
+ #nested transactions does not work with sqlite
+ pass
+
class MySQLTest(DBTest):
dbname = "mysql"
+
+ def setUp(self):
+ self.db = webtest.setup_database(self.dbname)
+ self.db.printing = True
+ # In mysql, transactions are supported only with INNODB engine.
+ self.db.query("CREATE TABLE person (name text, email text) ENGINE=INNODB")
if __name__ == '__main__':
webtest.main()
View
147 web/db.py
@@ -319,7 +319,71 @@ def sqlquote(a):
<sql: "WHERE x = 't' AND y = 3">
"""
return sqlparam(a).sqlquery()
-
+
+class Transaction:
+ """Database transaction."""
+ def __init__(self, ctx):
+ self.ctx = ctx
+ self.transaction_count = transaction_count = len(ctx.transactions)
+
+ class transaction_engine:
+ def do_transact(self):
+ ctx.db.commit()
+
+ def do_commit(self):
+ ctx.db.commit()
+
+ def do_rollback(self):
+ ctx.db.rollback()
+
+ class subtransaction_engine:
+ def query(self, q):
+ db_cursor = ctx.db.cursor()
+ ctx.db_execute(db_cursor, SQLQuery(q % transaction_count))
+
+ def do_transact(self):
+ self.query('SAVEPOINT webpy_sp_%s')
+
+ def do_commit(self):
+ self.query('RELEASE SAVEPOINT webpy_sp_%s')
+
+ def do_rollback(self):
+ self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')
+
+ class dummy_engine:
+ do_transact = do_commit = do_rollback = lambda self: None
+
+ if self.transaction_count:
+ # nested transactions are not supported in some databases
+ if self.ctx.get('ignore_nested_transactions'):
+ self.engine = dummy_engine()
+ else:
+ self.engine = subtransaction_engine()
+ else:
+ self.engine = transaction_engine()
+
+ self.engine.do_transact()
+ self.ctx.transactions.append(self)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exctype, excvalue, traceback):
+ if exctype is not None:
+ self.rollback()
+ else:
+ self.commit()
+
+ def commit(self):
+ if len(self.ctx.transactions) > self.transaction_count:
+ self.engine.do_commit()
+ self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
+
+ def rollback(self):
+ if len(self.ctx.transactions) > self.transaction_count:
+ self.engine.do_rollback()
+ self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
+
class DB:
"""Database"""
def __init__(self):
@@ -330,18 +394,22 @@ def __init__(self):
def _getctx(self):
if not self._ctx.get('db'):
- self._ctx.dbq_count = 0
- self._ctx.db_transaction = 0
- self._ctx.db = self.db_module.connect(**self.keywords)
- if not hasattr(self._ctx.db, 'commit'):
- self._ctx.db.commit = lambda: None
-
- if not hasattr(self._ctx.db, 'rollback'):
- self._ctx.db.rollback = lambda: None
-
+ self._load_context()
return self._ctx
ctx = property(_getctx)
-
+
+ def _load_context(self):
+ self._ctx.dbq_count = 0
+ self._ctx.transactions = [] # stack of transactions
+ self._ctx.db = self.db_module.connect(**self.keywords)
+ self._ctx.db_execute = self._db_execute
+
+ if not hasattr(self._ctx.db, 'commit'):
+ self._ctx.db.commit = lambda: None
+
+ if not hasattr(self._ctx.db, 'rollback'):
+ self._ctx.db.rollback = lambda: None
+
def _db_cursor(self):
return self.ctx.db.cursor()
@@ -369,7 +437,11 @@ def _db_execute(self, cur, sql_query, dorollback=True):
except:
if self.printing:
print >> debug, 'ERR:', str(sql_query)
- if dorollback: self.rollback(care=False)
+ if dorollback:
+ if self.ctx.transactions:
+ self.ctx.transactions[-1].rollback()
+ else:
+ self.ctx.db.rollback()
raise
if self.printing:
@@ -426,7 +498,7 @@ def iterwrapper():
else:
out = db_cursor.rowcount
- if not self.ctx.db_transaction: self.ctx.db.commit()
+ if not self.ctx.transactions: self.ctx.db.commit()
return out
def select(self, tables, vars=None, what='*', where=None, order=None, group=None,
@@ -523,7 +595,7 @@ def q(x): return "(" + x + ")"
except Exception:
out = None
- if not self.ctx.db_transaction: self.ctx.db.commit()
+ if not self.ctx.transactions: self.ctx.db.commit()
return out
def update(self, tables, where, vars=None, _test=False, **values):
@@ -554,7 +626,7 @@ def update(self, tables, where, vars=None, _test=False, **values):
db_cursor = self._db_cursor()
self._db_execute(db_cursor, query)
- if not self.ctx.db_transaction: self.ctx.db.commit()
+ if not self.ctx.transactions: self.ctx.db.commit()
return db_cursor.rowcount
def delete(self, table, where=None, using=None, vars=None, _test=False):
@@ -577,53 +649,15 @@ def delete(self, table, where=None, using=None, vars=None, _test=False):
db_cursor = self._db_cursor()
self._db_execute(db_cursor, q)
- if not self.ctx.db_transaction: self.ctx.db.commit()
+ if not self.ctx.transactions: self.ctx.db.commit()
return db_cursor.rowcount
def _process_insert_query(self, query, tablename, seqname):
return query
- def transact(self):
+ def transaction(self):
"""Start a transaction."""
- if not self.ctx.db_transaction:
- # commit everything up to now, so we don't rollback it later
- self.ctx.db.commit()
- else:
- db_cursor = self._db_cursor()
- self._db_execute(db_cursor,
- SQLQuery("SAVEPOINT webpy_sp_%s" % self.ctx.db_transaction))
- self.ctx.db_transaction += 1
-
- def commit(self):
- """Commits a transaction."""
- self.ctx.db_transaction -= 1
- if self.ctx.db_transaction < 0:
- raise TransactionError, "not in a transaction"
-
- if not self.ctx.db_transaction:
- self.ctx.db.commit()
- else:
- db_cursor = self._db_cursor()
- self._db_execute(db_cursor,
- SQLQuery("RELEASE SAVEPOINT webpy_sp_%s" % self.ctx.db_transaction))
-
- def rollback(self, care=True):
- """Rolls back a transaction."""
- self.ctx.db_transaction -= 1
- if self.ctx.db_transaction < 0:
- self.db_transaction = 0
- if care:
- raise TransactionError, "not in a transaction"
- else:
- return
-
- if not self.ctx.db_transaction:
- self.ctx.db.rollback()
- else:
- db_cursor = self._db_cursor()
- self._db_execute(db_cursor,
- SQLQuery("ROLLBACK TO SAVEPOINT webpy_sp_%s" % self.ctx.db_transaction),
- dorollback=False)
+ return Transaction(self.ctx)
class PostgresDB(DB):
"""Postgres driver."""
@@ -655,7 +689,6 @@ def _process_insert_query(self, query, tablename, seqname):
seqname = tablename + "_id_seq"
return query + "; SELECT currval('%s')" % seqname
-
class MySQLDB(DB):
def __init__(self, **keywords):
DB.__init__(self)

0 comments on commit f54e168

Please sign in to comment.