Skip to content

Commit

Permalink
Handle types for 'INSERT SELECT %s' queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
jamadden committed Jul 19, 2019
1 parent 923689d commit 7ea3f22
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 14 deletions.
45 changes: 32 additions & 13 deletions src/relstorage/adapters/_sql.py
Expand Up @@ -45,6 +45,7 @@
from zope.interface import implementer

from relstorage._compat import NStringIO
from relstorage._compat import intern
from relstorage._util import CachedIn
from .interfaces import IDBDialect

Expand Down Expand Up @@ -168,6 +169,11 @@ def _col_list(self):
def __compile_visit__(self, compiler):
compiler.visit_csv(self._columns)

def has_bind_param(self):
return any(
isinstance(c, (_BindParam, _OrderedBindParam))
for c in self._columns
)

_ColumnList = _Columns

Expand Down Expand Up @@ -344,11 +350,11 @@ def __init__(self, lhs, rhs):

def __compile_visit__(self, compiler):
compiler.visit(self.lhs)
compiler.emit(' JOIN ')
compiler.emit_keyword('JOIN')
compiler.visit(self.rhs)
# careful with USING clause in a join: Oracle doesn't allow such
# columns to have a prefix.
compiler.emit_keyword(' USING')
compiler.emit_keyword('USING')
compiler.visit_grouped(self._join_columns)


Expand Down Expand Up @@ -500,16 +506,28 @@ def _quote_query_for_prepare(self, query):
def _find_datatypes_for_prepared_query(self):
# Deduce the datatypes based on the types of the columns
# we're sending as params.
if isinstance(self.root, Insert) and self.root.values and self.root.column_list:
# If we're sending in a list of values, those have to
# exactly match the columns, so we can easily get a list
# of datatypes.
#
# TODO: We should be able to do this for an `INSERT (col) SELECT $1` too,
# by matching the parameter to the column name.
if isinstance(self.root, Insert):
root = self.root
dialect = root.context
# TODO: Should probably delegate this to the node.
column_list = self.root.column_list
datatypes = self.root.context.datatypes_for_columns(column_list)
if root.values and root.column_list:
# If we're sending in a list of values, those have to
# exactly match the columns, so we can easily get a list
# of datatypes.
column_list = root.column_list
datatypes = dialect.datatypes_for_columns(column_list)
elif root.select and root.select.column_list.has_bind_param():
targets = root.column_list
sources = root.select.column_list
# TODO: This doesn't support bind params anywhere except the
# select list!
columns_with_params = [
target
for target, source in zip(targets, sources)
if isinstance(source, _OrderedBindParam)
]
assert len(self.placeholders) == len(columns_with_params)
datatypes = dialect.datatypes_for_columns(columns_with_params)
return datatypes
return ()

Expand Down Expand Up @@ -550,6 +568,7 @@ def prepare(self):
conjunction=self._PREPARED_CONJUNCTION,
)


if placeholder_to_number:
execute = 'EXECUTE {name}({params})'.format(
name=name,
Expand All @@ -574,10 +593,10 @@ def convert(d):
params[ix - 1] = d[placeholder_name]
return params

return stmt, execute, convert
return intern(stmt), intern(execute), convert

def finalize(self):
return self.buf.getvalue().strip(), {v: k for k, v in self.placeholders.items()}
return intern(self.buf.getvalue().strip()), {v: k for k, v in self.placeholders.items()}

def visit(self, node):
node.__compile_visit__(self)
Expand Down
27 changes: 26 additions & 1 deletion src/relstorage/adapters/tests/test__sql.py
Expand Up @@ -49,7 +49,7 @@

objects = HistoryVariantTable(
current_object,
object_state,
object_state,
)

object_and_state = HistoryVariantTable(
Expand Down Expand Up @@ -217,3 +217,28 @@ def test_prepared_insert_values(self):
stmt._prepare_stmt,
r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*"
)

def test_prepared_insert_select_with_param(self):
stmt = current_object.insert().from_select(
(current_object.c.zoid,
current_object.c.tid),
object_state.select(
object_state.c.zoid,
object_state.orderedbindparam()
)
)
self.assertEqual(
str(stmt),
'INSERT INTO current_object(zoid, tid) SELECT zoid, %s FROM object_state'
)

stmt = stmt.prepared()
self.assertTrue(
str(stmt).startswith('EXECUTE rs_prep_stmt')
)

stmt = stmt.compiled()
self.assertRegex(
stmt._prepare_stmt,
r"PREPARE rs_prep_stmt_[0-9]* \(BIGINT\) AS.*"
)

0 comments on commit 7ea3f22

Please sign in to comment.