Skip to content

Commit

Permalink
Fix dropbox#50 executemany failed bug, add unit tests as well
Browse files Browse the repository at this point in the history
  • Loading branch information
taogeYT committed May 15, 2020
1 parent 333d0d1 commit 7eaa0d6
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 20 deletions.
30 changes: 10 additions & 20 deletions pyhive/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from itertools import islice
import re

# Recognises a single-row ``INSERT ... VALUES (...)`` statement whose VALUES tuple consists
# only of ``%s`` or ``%(name)s`` placeholders.
#   group(1): everything from ``INSERT`` through ``VALUE``/``VALUES`` and trailing whitespace
#   group(2): the parenthesised placeholder tuple
# The adjacent raw-string pieces below are concatenated by the parser into one pattern.
MATCH_INSERT_SQL = re.compile(
    r"\s*"
    r"((?:INSERT)\b.+\bVALUES?\s*)"
    r"(\(\s*(?:%s|%\(.+\)s)\s*"
    r"(?:,\s*(?:%s|%\(.+\)s)\s*)*\))",
    flags=re.IGNORECASE | re.DOTALL,
)


class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
"""Base class for some common DB-API logic"""
Expand Down Expand Up @@ -86,26 +90,12 @@ def executemany(self, operation, seq_of_parameters):
Return values are not defined.
"""
from .hive import _escaper
match_insert_sql = re.compile(
r"\s*((?:INSERT)\b.+\bVALUES?\s*)(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))",
re.IGNORECASE | re.DOTALL)
match = match_insert_sql.match(operation)
if match:
part1, part2 = match.group(1), match.group(2).rstrip()
values = []
for parameter in seq_of_parameters:
record = part2 % _escaper.escape_args(parameter)
values.append(record)
sql = part1 + ",".join(values)
self.execute(sql)
else:
for parameters in seq_of_parameters[:-1]:
self.execute(operation, parameters)
while self._state != self._STATE_FINISHED:
self._fetch_more()
if seq_of_parameters:
self.execute(operation, seq_of_parameters[-1])
for parameters in seq_of_parameters[:-1]:
self.execute(operation, parameters)
while self._state != self._STATE_FINISHED:
self._fetch_more()
if seq_of_parameters:
self.execute(operation, seq_of_parameters[-1])

def fetchone(self):
"""Fetch the next row of a query result set, returning a single sequence, or ``None`` when
Expand Down
20 changes: 20 additions & 0 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,26 @@ def execute(self, operation, parameters=None, **kwargs):
_check_status(response)
self._operationHandle = response.operationHandle

def executemany(self, operation, seq_of_parameters):
    """Prepare a database operation (query or command) and then execute it against all parameter
    sequences or mappings found in the sequence ``seq_of_parameters``.

    When ``operation`` matches ``common.MATCH_INSERT_SQL`` (an ``INSERT ... VALUES (...)``
    statement whose tuple holds only placeholders), every parameter set is rendered into its
    own ``VALUES`` tuple and all tuples are sent in one multi-row ``INSERT`` through a single
    ``execute`` call. Any other operation is delegated to the base-class ``executemany``.

    If the operation matches but ``seq_of_parameters`` yields nothing, no statement is
    executed.

    Only the final result set is retained.

    Return values are not defined.
    """
    match = common.MATCH_INSERT_SQL.match(operation)
    if not match:
        return super(Cursor, self).executemany(operation, seq_of_parameters)

    # NOTE(review): ``match`` only anchors at the start of ``operation``; any text following
    # the matched VALUES tuple is not carried over into the batched statement.
    insert_prefix, row_template = match.group(1), match.group(2).rstrip()
    rows = [
        row_template % _escaper.escape_args(parameters)
        for parameters in seq_of_parameters
    ]
    if not rows:
        # Without this guard we would send "INSERT ... VALUES " with no tuples, which is
        # not a valid statement.
        return
    self.execute(insert_prefix + ",".join(rows))

def cancel(self):
req = ttypes.TCancelOperationReq(
operationHandle=self._operationHandle,
Expand Down
20 changes: 20 additions & 0 deletions pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,26 @@ def execute(self, operation, parameters=None):
url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs)
self._process_response(response)

def executemany(self, operation, seq_of_parameters):
    """Prepare a database operation (query or command) and then execute it against all parameter
    sequences or mappings found in the sequence ``seq_of_parameters``.

    When ``operation`` matches ``common.MATCH_INSERT_SQL`` (an ``INSERT ... VALUES (...)``
    statement whose tuple holds only placeholders), every parameter set is rendered into its
    own ``VALUES`` tuple and all tuples are sent in one multi-row ``INSERT`` through a single
    ``execute`` call. Any other operation is delegated to the base-class ``executemany``.

    If the operation matches but ``seq_of_parameters`` yields nothing, no statement is
    executed.

    Only the final result set is retained.

    Return values are not defined.
    """
    match = common.MATCH_INSERT_SQL.match(operation)
    if not match:
        return super(Cursor, self).executemany(operation, seq_of_parameters)

    # NOTE(review): ``match`` only anchors at the start of ``operation``; any text following
    # the matched VALUES tuple is not carried over into the batched statement.
    insert_prefix, row_template = match.group(1), match.group(2).rstrip()
    rows = [
        row_template % _escaper.escape_args(parameters)
        for parameters in seq_of_parameters
    ]
    if not rows:
        # Without this guard we would send "INSERT ... VALUES " with no tuples, which is
        # not a valid statement.
        return
    self.execute(insert_prefix + ",".join(rows))

def cancel(self):
if self._state == self._STATE_NONE:
raise ProgrammingError("No query yet")
Expand Down
8 changes: 8 additions & 0 deletions pyhive/tests/dbapi_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ def test_executemany_none(self, cursor):
self.assertIsNone(cursor.description)
self.assertRaises(exc.ProgrammingError, cursor.fetchone)

@with_cursor
def test_executemany_batch_insert(self, cursor):
    """Insert ten identical rows via ``executemany`` and read them all back."""
    expected_rows = [("one", 1) for _ in range(10)]
    cursor.execute("create temporary table test_executemany_batch_insert (foo string, bar int)")
    cursor.executemany("INSERT INTO TABLE test_executemany_batch_insert VALUES (%s, %s)", expected_rows)
    cursor.execute('SELECT * FROM test_executemany_batch_insert')
    fetched_rows = cursor.fetchall()
    self.assertEqual(fetched_rows, expected_rows)

@with_cursor
def test_fetchone_no_data(self, cursor):
self.assertRaises(exc.ProgrammingError, cursor.fetchone)
Expand Down

0 comments on commit 7eaa0d6

Please sign in to comment.