Permalink
Browse files

new: db.multiple_insert

  • Loading branch information...
Anand
Anand committed May 1, 2008
1 parent 0ee302d commit 780f52f6ff5a61ba0a6ea4a147e5fd57348eafd1
Showing with 67 additions and 4 deletions.
  1. +8 −1 test/db.py
  2. +3 −1 test/webtest.py
  3. +56 −2 web/db.py
View
@@ -72,7 +72,14 @@ def testPooling(self):
db.hasPooling = True
import DBUtils
self.assertTrue(isinstance(db.ctx.db, DBUtils.PooledDB.PooledDB))
-
+
+ def test_multiple_insert(self):
+ db = webtest.setup_database(self.dbname)
+ db.multiple_insert('person', [dict(name='a'), dict(name='b')], seqname=False)
+
+ assert db.select("person", where="name='a'")
+ assert db.select("person", where="name='b'")
+
class SqliteTest(DBTest):
dbname = "sqlite"
View
@@ -85,7 +85,9 @@ def runTests(suite):
def main(suite=None):
if not suite:
main_module = __import__('__main__')
- suite = module_suite(main_module, sys.argv[1:] or None)
+ # allow command line switches
+ args = [a for a in sys.argv[1:] if not a.startswith('-')]
+ suite = module_suite(main_module, args or None)
result = runTests(suite)
sys.exit(not result.wasSuccessful())
View
@@ -400,6 +400,7 @@ def __init__(self):
# flag to enable/disable printing queries
self.printing = False
self.hasPooling = False
+ self.supports_multiple_insert = False
def _getctx(self):
if not self._ctx.get('db'):
@@ -633,6 +634,58 @@ def q(x): return "(" + x + ")"
if not self.ctx.transactions: self.ctx.db.commit()
return out
+
+ def multiple_insert(self, tablename, values, seqname=None, _test=False):
+ if not values:
+ return []
+
+ if not self.supports_multiple_insert:
+ out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
+ if seqname is False:
+ return None
+ else:
+ return out
+
+ keys = values[0].keys()
+ #@@ make sure all keys are valid
+
+ # make sure all rows have same keys.
+ for v in values:
+ if v.keys() != keys:
+ raise ValueError, 'Bad data'
+
+ sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))
+
+ data = []
+ for row in values:
+ d = SQLQuery.join([SQLParam(row[k]) for k in keys], ', ')
+ data.append('(' + d + ')')
+ sql_query += SQLQuery.join(data, ', ')
+
+ if _test: return sql_query
+
+ db_cursor = self._db_cursor()
+ if seqname is not False:
+ sql_query = self._process_insert_query(sql_query, tablename, seqname)
+
+ if isinstance(sql_query, tuple):
+ # for some databases, a separate query has to be made to find
+ # the id of the inserted row.
+ q1, q2 = sql_query
+ self._db_execute(db_cursor, q1)
+ self._db_execute(db_cursor, q2)
+ else:
+ self._db_execute(db_cursor, sql_query)
+
+ try:
+ out = db_cursor.fetchone()[0]
+ out = range(out-len(values)+1, out+1)
+ except Exception:
+ out = None
+
+ if not self.ctx.transactions: self.ctx.db.commit()
+ return out
+
def update(self, tables, where, vars=None, _test=False, **values):
"""
@@ -709,6 +762,7 @@ def __init__(self, **keywords):
self.db_module = self.get_db_module()
self.paramstyle = self.db_module.paramstyle
self.keywords = keywords
+ self.supports_multiple_insert = True
def get_db_module(self):
try:
@@ -743,11 +797,11 @@ def __init__(self, **keywords):
self.db_module = db
self.keywords = keywords
self.dbname = "mysql"
-
+ self.supports_multiple_insert = True
+
def _process_insert_query(self, query, tablename, seqname):
return query, SQLQuery('SELECT last_insert_id();')
-
class SqliteDB(DB):
def __init__(self, **keywords):
DB.__init__(self)

0 comments on commit 780f52f

Please sign in to comment.