Permalink
Browse files

test DB classes with each driver separately.

  • Loading branch information...
1 parent 06a9549 commit 88c46699bc89adaec2d9f48122d902d1c6688257 @anandology anandology committed May 8, 2009
Showing with 58 additions and 31 deletions.
  1. +31 −10 test/db.py
  2. +27 −21 web/db.py
View
@@ -4,26 +4,27 @@
class DBTest(webtest.TestCase):
dbname = 'postgres'
+ driver = None
def setUp(self):
- self.db = webtest.setup_database(self.dbname)
+ self.db = webtest.setup_database(self.dbname, driver=self.driver)
self.db.query("CREATE TABLE person (name text, email text)")
def tearDown(self):
# there might be some error with the current connection, delete from a new connection
- self.db = webtest.setup_database(self.dbname)
+ self.db = webtest.setup_database(self.dbname, driver=self.driver)
self.db.query('DROP TABLE person')
def _testable(self):
try:
- webtest.setup_database(self.dbname)
+ webtest.setup_database(self.dbname, driver=self.driver)
return True
except ImportError, e:
- print >> web.debug, str(e), "(ignoring the %s tests)" % self.dbname
+ print >> web.debug, str(e), "(ignoring %s)" % self.__class__.__name__
return False
def testUnicode(self):
- """Bug#177265: unicode queries throw errors"""
+ # Bug#177265: unicode queries throw errors
self.db.select('person', where='name=$name', vars={'name': u'\xf4'})
def assertRows(self, n):
@@ -87,19 +88,33 @@ def test_multiple_insert(self):
assert db.select("person", where="name='a'")
assert db.select("person", where="name='b'")
- def testUnicode(self):
+ def test_result_is_unicode(self):
db = webtest.setup_database(self.dbname)
self.db.insert('person', False, name='user')
name = db.select('person')[0].name
self.assertEquals(type(name), unicode)
+class PostgresTest(DBTest):
+ dbname = "postgres"
+ driver = "psycopg2"
+
+class PostgresTest_psycopg(PostgresTest):
+ driver = "psycopg"
+
+class PostgresTest_pgdb(PostgresTest):
+ driver = "pgdb"
+
class SqliteTest(DBTest):
dbname = "sqlite"
+ driver = "sqlite3"
def testNestedTransactions(self):
#nested transactions does not work with sqlite
pass
-
+
+class SqliteTest_pysqlite2(SqliteTest):
+ driver = "pysqlite2.dbapi2"
+
class MySQLTest(DBTest):
dbname = "mysql"
@@ -108,12 +123,18 @@ def setUp(self):
# In mysql, transactions are supported only with INNODB engine.
self.db.query("CREATE TABLE person (name text, email text) ENGINE=INNODB")
+del DBTest
+
+def is_test(cls):
+ import inspect
+ return inspect.isclass(cls) and webtest.TestCase in inspect.getmro(cls)
# ignore db tests when the required db adapter is not found.
-for t in [DBTest, MySQLTest, SqliteTest]:
- if not t('_testable')._testable():
+for t in globals().values():
+ if is_test(t) and not t('_testable')._testable():
del globals()[t.__name__]
- pass
+
+del t
if __name__ == '__main__':
webtest.main()
View
@@ -435,8 +435,13 @@ class DB:
def __init__(self, db_module, keywords):
"""Creates a database.
"""
+ # some DB implementaions take optional paramater `driver` to use a specific driver modue
+ # but it should not be passed to connect
+ keywords.pop('driver', None)
+
self.db_module = db_module
self.keywords = keywords
+
self._ctx = threadeddict()
# flag to enable/disable printing queries
@@ -869,25 +874,17 @@ def __init__(self, **keywords):
keywords['password'] = keywords['pw']
del keywords['pw']
- db_module = self.get_db_module()
+ db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None))
+ if db_module.__name__ == "psycopg2":
+ import psycopg2.extensions
+ psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
+
keywords['database'] = keywords.pop('db')
self.dbname = "postgres"
self.paramstyle = db_module.paramstyle
DB.__init__(self, db_module, keywords)
self.supports_multiple_insert = True
- def get_db_module(self):
- try:
- import psycopg2 as db
- import psycopg2.extensions
- psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
- except ImportError:
- try:
- import psycopg as db
- except ImportError:
- import pgdb as db
- return db
-
def _process_insert_query(self, query, tablename, seqname):
if seqname is None:
seqname = tablename + "_id_seq"
@@ -923,17 +920,26 @@ def __init__(self, **keywords):
def _process_insert_query(self, query, tablename, seqname):
return query, SQLQuery('SELECT last_insert_id();')
+def import_driver(drivers, preferred=None):
+ """Import the first available driver or preferred driver.
+ """
+ if preferred:
+ drivers = [preferred]
+
+ for d in drivers:
+ try:
+ return __import__(d, None, None, ['x'])
+ except ImportError:
+ pass
+ raise ImportError("Unable to import " + "or ".join(drivers))
+
class SqliteDB(DB):
def __init__(self, **keywords):
- try:
- import sqlite3 as db
+ db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None))
+
+ if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]:
db.paramstyle = 'qmark'
- except ImportError:
- try:
- from pysqlite2 import dbapi2 as db
- db.paramstyle = 'qmark'
- except ImportError:
- import sqlite as db
+
self.paramstyle = db.paramstyle
keywords['database'] = keywords.pop('db')
self.dbname = "sqlite"

0 comments on commit 88c4669

Please sign in to comment.